add support for loras trained on kohya's scripts 0.4.0 (alphas)
This commit is contained in:
parent
e8c3d03f7d
commit
e407d1af89
@ -92,6 +92,15 @@ def load_lora(name, filename):
|
|||||||
keys_failed_to_match.append(key_diffusers)
|
keys_failed_to_match.append(key_diffusers)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
lora_module = lora.modules.get(key, None)
|
||||||
|
if lora_module is None:
|
||||||
|
lora_module = LoraUpDownModule()
|
||||||
|
lora.modules[key] = lora_module
|
||||||
|
|
||||||
|
if lora_key == "alpha":
|
||||||
|
lora_module.alpha = weight.item()
|
||||||
|
continue
|
||||||
|
|
||||||
if type(sd_module) == torch.nn.Linear:
|
if type(sd_module) == torch.nn.Linear:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d:
|
||||||
@ -104,17 +113,12 @@ def load_lora(name, filename):
|
|||||||
|
|
||||||
module.to(device=devices.device, dtype=devices.dtype)
|
module.to(device=devices.device, dtype=devices.dtype)
|
||||||
|
|
||||||
lora_module = lora.modules.get(key, None)
|
|
||||||
if lora_module is None:
|
|
||||||
lora_module = LoraUpDownModule()
|
|
||||||
lora.modules[key] = lora_module
|
|
||||||
|
|
||||||
if lora_key == "lora_up.weight":
|
if lora_key == "lora_up.weight":
|
||||||
lora_module.up = module
|
lora_module.up = module
|
||||||
elif lora_key == "lora_down.weight":
|
elif lora_key == "lora_down.weight":
|
||||||
lora_module.down = module
|
lora_module.down = module
|
||||||
else:
|
else:
|
||||||
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
|
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
||||||
|
|
||||||
if len(keys_failed_to_match) > 0:
|
if len(keys_failed_to_match) > 0:
|
||||||
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
||||||
@ -161,7 +165,7 @@ def lora_forward(module, input, res):
|
|||||||
for lora in loaded_loras:
|
for lora in loaded_loras:
|
||||||
module = lora.modules.get(lora_layer_name, None)
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
if module is not None:
|
if module is not None:
|
||||||
res = res + module.up(module.down(input)) * lora.multiplier
|
res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1]
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user