diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 94ec021b..83c1c6fd 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -166,8 +166,10 @@ def load_lora(name, filename): module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.MultiheadAttention: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.Conv2d: + elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') continue @@ -233,6 +235,8 @@ def lora_calc_updown(lora, module, target): if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) else: updown = up @ down