Lora: add an option to use old method of applying loras
This commit is contained in:
parent
083dc3c76a
commit
ec0da07236
@ -245,6 +245,19 @@ def lora_calc_updown(lora, module, target):
|
|||||||
return updown
|
return updown
|
||||||
|
|
||||||
|
|
||||||
|
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||||
|
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||||
|
|
||||||
|
if weights_backup is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
|
else:
|
||||||
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
|
||||||
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of Loras to the weights of torch layer self.
|
Applies the currently selected set of Loras to the weights of torch layer self.
|
||||||
@ -269,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
|
|||||||
self.lora_weights_backup = weights_backup
|
self.lora_weights_backup = weights_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
if weights_backup is not None:
|
lora_restore_weights_from_backup(self)
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
|
||||||
self.in_proj_weight.copy_(weights_backup[0])
|
|
||||||
self.out_proj.weight.copy_(weights_backup[1])
|
|
||||||
else:
|
|
||||||
self.weight.copy_(weights_backup)
|
|
||||||
|
|
||||||
for lora in loaded_loras:
|
for lora in loaded_loras:
|
||||||
module = lora.modules.get(lora_layer_name, None)
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
@ -305,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
|
|||||||
setattr(self, "lora_current_names", wanted_names)
|
setattr(self, "lora_current_names", wanted_names)
|
||||||
|
|
||||||
|
|
||||||
|
def lora_forward(module, input, original_forward):
|
||||||
|
"""
|
||||||
|
Old way of applying Lora by executing operations during layer's forward.
|
||||||
|
Stacking many loras this way results in big performance degradation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(loaded_loras) == 0:
|
||||||
|
return original_forward(module, input)
|
||||||
|
|
||||||
|
input = devices.cond_cast_unet(input)
|
||||||
|
|
||||||
|
lora_restore_weights_from_backup(module)
|
||||||
|
lora_reset_cached_weight(module)
|
||||||
|
|
||||||
|
res = original_forward(module, input)
|
||||||
|
|
||||||
|
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
||||||
|
for lora in loaded_loras:
|
||||||
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
|
if module is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module.up.to(device=devices.device)
|
||||||
|
module.down.to(device=devices.device)
|
||||||
|
|
||||||
|
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||||
setattr(self, "lora_current_names", ())
|
setattr(self, "lora_current_names", ())
|
||||||
setattr(self, "lora_weights_backup", None)
|
setattr(self, "lora_weights_backup", None)
|
||||||
|
|
||||||
|
|
||||||
def lora_Linear_forward(self, input):
|
def lora_Linear_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
|
||||||
|
|
||||||
lora_apply_weights(self)
|
lora_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Linear_forward_before_lora(self, input)
|
return torch.nn.Linear_forward_before_lora(self, input)
|
||||||
@ -323,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def lora_Conv2d_forward(self, input):
|
def lora_Conv2d_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
|
||||||
|
|
||||||
lora_apply_weights(self)
|
lora_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Conv2d_forward_before_lora(self, input)
|
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||||
|
@ -55,3 +55,8 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
|||||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
|
||||||
|
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
||||||
|
}))
|
||||||
|
Loading…
Reference in New Issue
Block a user