lora extension rework to include other types of networks
This commit is contained in:
parent
7d26c479ee
commit
b75b004fe6
@ -1,5 +1,5 @@
|
||||
from modules import extra_networks, shared
|
||||
import lora
|
||||
import networks
|
||||
|
||||
|
||||
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
@ -9,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_lora
|
||||
|
||||
if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
|
||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
@ -21,12 +21,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
names.append(params.items[0])
|
||||
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||
|
||||
lora.load_loras(names, multipliers)
|
||||
networks.load_networks(names, multipliers)
|
||||
|
||||
if shared.opts.lora_add_hashes_to_infotext:
|
||||
lora_hashes = []
|
||||
for item in lora.loaded_loras:
|
||||
shorthash = item.lora_on_disk.shorthash
|
||||
network_hashes = []
|
||||
for item in networks.loaded_networks:
|
||||
shorthash = item.network_on_disk.shorthash
|
||||
if not shorthash:
|
||||
continue
|
||||
|
||||
@ -36,10 +36,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
|
||||
alias = alias.replace(":", "").replace(",", "")
|
||||
|
||||
lora_hashes.append(f"{alias}: {shorthash}")
|
||||
network_hashes.append(f"{alias}: {shorthash}")
|
||||
|
||||
if lora_hashes:
|
||||
p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
|
||||
if network_hashes:
|
||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
||||
|
@ -1,537 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
from typing import Union
|
||||
|
||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache
|
||||
|
||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||
|
||||
re_digits = re.compile(r"\d+")
|
||||
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
||||
re_compiled = {}
|
||||
|
||||
suffix_conversion = {
|
||||
"attentions": {},
|
||||
"resnets": {
|
||||
"conv1": "in_layers_2",
|
||||
"conv2": "out_layers_3",
|
||||
"time_emb_proj": "emb_layers_1",
|
||||
"conv_shortcut": "skip_connection",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||
def match(match_list, regex_text):
|
||||
regex = re_compiled.get(regex_text)
|
||||
if regex is None:
|
||||
regex = re.compile(regex_text)
|
||||
re_compiled[regex_text] = regex
|
||||
|
||||
r = re.match(regex, key)
|
||||
if not r:
|
||||
return False
|
||||
|
||||
match_list.clear()
|
||||
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
||||
return True
|
||||
|
||||
m = []
|
||||
|
||||
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
|
||||
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
|
||||
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
|
||||
|
||||
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
|
||||
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
|
||||
|
||||
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
if is_sd2:
|
||||
if 'mlp_fc1' in m[1]:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||
elif 'mlp_fc2' in m[1]:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||
else:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||
|
||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||
|
||||
if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
if 'mlp_fc1' in m[1]:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||
elif 'mlp_fc2' in m[1]:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||
else:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||
|
||||
return key
|
||||
|
||||
|
||||
class LoraOnDisk:
|
||||
def __init__(self, name, filename):
|
||||
self.name = name
|
||||
self.filename = filename
|
||||
self.metadata = {}
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
|
||||
def read_metadata():
|
||||
metadata = sd_models.read_metadata_from_safetensors(filename)
|
||||
metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
|
||||
|
||||
return metadata
|
||||
|
||||
if self.is_safetensors:
|
||||
try:
|
||||
self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading lora {filename}")
|
||||
|
||||
if self.metadata:
|
||||
m = {}
|
||||
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
|
||||
m[k] = v
|
||||
|
||||
self.metadata = m
|
||||
|
||||
self.alias = self.metadata.get('ss_output_name', self.name)
|
||||
|
||||
self.hash = None
|
||||
self.shorthash = None
|
||||
self.set_hash(
|
||||
self.metadata.get('sshs_model_hash') or
|
||||
hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
|
||||
''
|
||||
)
|
||||
|
||||
def set_hash(self, v):
|
||||
self.hash = v
|
||||
self.shorthash = self.hash[0:12]
|
||||
|
||||
if self.shorthash:
|
||||
available_lora_hash_lookup[self.shorthash] = self
|
||||
|
||||
def read_hash(self):
|
||||
if not self.hash:
|
||||
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
|
||||
|
||||
def get_alias(self):
|
||||
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
|
||||
return self.name
|
||||
else:
|
||||
return self.alias
|
||||
|
||||
|
||||
class LoraModule:
|
||||
def __init__(self, name, lora_on_disk: LoraOnDisk):
|
||||
self.name = name
|
||||
self.lora_on_disk = lora_on_disk
|
||||
self.multiplier = 1.0
|
||||
self.modules = {}
|
||||
self.mtime = None
|
||||
|
||||
self.mentioned_name = None
|
||||
"""the text that was used to add lora to prompt - can be either name or an alias"""
|
||||
|
||||
|
||||
class LoraUpDownModule:
|
||||
def __init__(self):
|
||||
self.up = None
|
||||
self.down = None
|
||||
self.alpha = None
|
||||
|
||||
|
||||
def assign_lora_names_to_compvis_modules(sd_model):
|
||||
lora_layer_mapping = {}
|
||||
|
||||
if shared.sd_model.is_sdxl:
|
||||
for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
|
||||
if not hasattr(embedder, 'wrapped'):
|
||||
continue
|
||||
|
||||
for name, module in embedder.wrapped.named_modules():
|
||||
lora_name = f'{i}_{name.replace(".", "_")}'
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
else:
|
||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
|
||||
for name, module in shared.sd_model.model.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
|
||||
sd_model.lora_layer_mapping = lora_layer_mapping
|
||||
|
||||
|
||||
def load_lora(name, lora_on_disk):
|
||||
lora = LoraModule(name, lora_on_disk)
|
||||
lora.mtime = os.path.getmtime(lora_on_disk.filename)
|
||||
|
||||
sd = sd_models.read_state_dict(lora_on_disk.filename)
|
||||
|
||||
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
|
||||
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
|
||||
assign_lora_names_to_compvis_modules(shared.sd_model)
|
||||
|
||||
keys_failed_to_match = {}
|
||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||
|
||||
for key_lora, weight in sd.items():
|
||||
key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
|
||||
|
||||
key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
m = re_x_proj.match(key)
|
||||
if m:
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
|
||||
|
||||
# SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
|
||||
if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
|
||||
key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
|
||||
key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
keys_failed_to_match[key_lora] = key
|
||||
continue
|
||||
|
||||
lora_module = lora.modules.get(key, None)
|
||||
if lora_module is None:
|
||||
lora_module = LoraUpDownModule()
|
||||
lora.modules[key] = lora_module
|
||||
|
||||
if lora_key == "alpha":
|
||||
lora_module.alpha = weight.item()
|
||||
continue
|
||||
|
||||
if type(sd_module) == torch.nn.Linear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.MultiheadAttention:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||
else:
|
||||
print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||
continue
|
||||
raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")
|
||||
|
||||
with torch.no_grad():
|
||||
module.weight.copy_(weight)
|
||||
|
||||
module.to(device=devices.cpu, dtype=devices.dtype)
|
||||
|
||||
if lora_key == "lora_up.weight":
|
||||
lora_module.up = module
|
||||
elif lora_key == "lora_down.weight":
|
||||
lora_module.down = module
|
||||
else:
|
||||
raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")
|
||||
|
||||
if keys_failed_to_match:
|
||||
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
|
||||
|
||||
return lora
|
||||
|
||||
|
||||
def load_loras(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
|
||||
for lora in loaded_loras:
|
||||
if lora.name in names:
|
||||
already_loaded[lora.name] = lora
|
||||
|
||||
loaded_loras.clear()
|
||||
|
||||
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||
if any(x is None for x in loras_on_disk):
|
||||
list_available_loras()
|
||||
|
||||
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||
|
||||
failed_to_load_loras = []
|
||||
|
||||
for i, name in enumerate(names):
|
||||
lora = already_loaded.get(name, None)
|
||||
|
||||
lora_on_disk = loras_on_disk[i]
|
||||
|
||||
if lora_on_disk is not None:
|
||||
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
||||
try:
|
||||
lora = load_lora(name, lora_on_disk)
|
||||
except Exception as e:
|
||||
errors.display(e, f"loading Lora {lora_on_disk.filename}")
|
||||
continue
|
||||
|
||||
lora.mentioned_name = name
|
||||
|
||||
lora_on_disk.read_hash()
|
||||
|
||||
if lora is None:
|
||||
failed_to_load_loras.append(name)
|
||||
print(f"Couldn't find Lora with name {name}")
|
||||
continue
|
||||
|
||||
lora.multiplier = multipliers[i] if multipliers else 1.0
|
||||
loaded_loras.append(lora)
|
||||
|
||||
if failed_to_load_loras:
|
||||
sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
|
||||
|
||||
|
||||
def lora_calc_updown(lora, module, target):
|
||||
with torch.no_grad():
|
||||
up = module.up.weight.to(target.device, dtype=target.dtype)
|
||||
down = module.down.weight.to(target.device, dtype=target.dtype)
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
|
||||
return updown
|
||||
|
||||
|
||||
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||
|
||||
if weights_backup is None:
|
||||
return
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
self.in_proj_weight.copy_(weights_backup[0])
|
||||
self.out_proj.weight.copy_(weights_backup[1])
|
||||
else:
|
||||
self.weight.copy_(weights_backup)
|
||||
|
||||
|
||||
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
"""
|
||||
Applies the currently selected set of Loras to the weights of torch layer self.
|
||||
If weights already have this particular set of loras applied, does nothing.
|
||||
If not, restores orginal weights from backup and alters weights according to loras.
|
||||
"""
|
||||
|
||||
lora_layer_name = getattr(self, 'lora_layer_name', None)
|
||||
if lora_layer_name is None:
|
||||
return
|
||||
|
||||
current_names = getattr(self, "lora_current_names", ())
|
||||
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
|
||||
|
||||
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||
if weights_backup is None:
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
||||
else:
|
||||
weights_backup = self.weight.to(devices.cpu, copy=True)
|
||||
|
||||
self.lora_weights_backup = weights_backup
|
||||
|
||||
if current_names != wanted_names:
|
||||
lora_restore_weights_from_backup(self)
|
||||
|
||||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is not None and hasattr(self, 'weight'):
|
||||
self.weight += lora_calc_updown(lora, module, self.weight)
|
||||
continue
|
||||
|
||||
module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
|
||||
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
|
||||
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
|
||||
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
|
||||
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
|
||||
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
|
||||
self.in_proj_weight += updown_qkv
|
||||
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
|
||||
continue
|
||||
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
||||
|
||||
self.lora_current_names = wanted_names
|
||||
|
||||
|
||||
def lora_forward(module, input, original_forward):
|
||||
"""
|
||||
Old way of applying Lora by executing operations during layer's forward.
|
||||
Stacking many loras this way results in big performance degradation.
|
||||
"""
|
||||
|
||||
if len(loaded_loras) == 0:
|
||||
return original_forward(module, input)
|
||||
|
||||
input = devices.cond_cast_unet(input)
|
||||
|
||||
lora_restore_weights_from_backup(module)
|
||||
lora_reset_cached_weight(module)
|
||||
|
||||
res = original_forward(module, input)
|
||||
|
||||
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
||||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
module.up.to(device=devices.device)
|
||||
module.down.to(device=devices.device)
|
||||
|
||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||
self.lora_current_names = ()
|
||||
self.lora_weights_backup = None
|
||||
|
||||
|
||||
def lora_Linear_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
|
||||
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.Linear_forward_before_lora(self, input)
|
||||
|
||||
|
||||
def lora_Linear_load_state_dict(self, *args, **kwargs):
|
||||
lora_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def lora_Conv2d_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
|
||||
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||
|
||||
|
||||
def lora_Conv2d_load_state_dict(self, *args, **kwargs):
|
||||
lora_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def lora_MultiheadAttention_forward(self, *args, **kwargs):
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
||||
lora_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def list_available_loras():
|
||||
available_loras.clear()
|
||||
available_lora_aliases.clear()
|
||||
forbidden_lora_aliases.clear()
|
||||
available_lora_hash_lookup.clear()
|
||||
forbidden_lora_aliases.update({"none": 1, "Addams": 1})
|
||||
|
||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||
|
||||
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||
for filename in candidates:
|
||||
if os.path.isdir(filename):
|
||||
continue
|
||||
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
try:
|
||||
entry = LoraOnDisk(name, filename)
|
||||
except OSError: # should catch FileNotFoundError and PermissionError etc.
|
||||
errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
|
||||
continue
|
||||
|
||||
available_loras[name] = entry
|
||||
|
||||
if entry.alias in available_lora_aliases:
|
||||
forbidden_lora_aliases[entry.alias.lower()] = 1
|
||||
|
||||
available_lora_aliases[name] = entry
|
||||
available_lora_aliases[entry.alias] = entry
|
||||
|
||||
|
||||
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||
|
||||
|
||||
def infotext_pasted(infotext, params):
|
||||
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
|
||||
return # if the other extension is active, it will handle those fields, no need to do anything
|
||||
|
||||
added = []
|
||||
|
||||
for k in params:
|
||||
if not k.startswith("AddNet Model "):
|
||||
continue
|
||||
|
||||
num = k[13:]
|
||||
|
||||
if params.get("AddNet Module " + num) != "LoRA":
|
||||
continue
|
||||
|
||||
name = params.get("AddNet Model " + num)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
m = re_lora_name.match(name)
|
||||
if m:
|
||||
name = m.group(1)
|
||||
|
||||
multiplier = params.get("AddNet Weight A " + num, "1.0")
|
||||
|
||||
added.append(f"<lora:{name}:{multiplier}>")
|
||||
|
||||
if added:
|
||||
params["Prompt"] += "\n" + "".join(added)
|
||||
|
||||
|
||||
available_loras = {}
|
||||
available_lora_aliases = {}
|
||||
available_lora_hash_lookup = {}
|
||||
forbidden_lora_aliases = {}
|
||||
loaded_loras = []
|
||||
|
||||
list_available_loras()
|
15
extensions-builtin/Lora/lyco_helpers.py
Normal file
15
extensions-builtin/Lora/lyco_helpers.py
Normal file
@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
def make_weight_cp(t, wa, wb):
|
||||
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
|
||||
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
|
||||
|
||||
|
||||
def rebuild_conventional(up, down, shape, dyn_dim=None):
|
||||
up = up.reshape(up.size(0), -1)
|
||||
down = down.reshape(down.size(0), -1)
|
||||
if dyn_dim is not None:
|
||||
up = up[:, :dyn_dim]
|
||||
down = down[:dyn_dim, :]
|
||||
return (up @ down).reshape(shape)
|
98
extensions-builtin/Lora/network.py
Normal file
98
extensions-builtin/Lora/network.py
Normal file
@ -0,0 +1,98 @@
|
||||
import os
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from modules import devices, sd_models, cache, errors, hashes, shared
|
||||
|
||||
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
|
||||
|
||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||
|
||||
|
||||
class NetworkOnDisk:
|
||||
def __init__(self, name, filename):
|
||||
self.name = name
|
||||
self.filename = filename
|
||||
self.metadata = {}
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
|
||||
def read_metadata():
|
||||
metadata = sd_models.read_metadata_from_safetensors(filename)
|
||||
metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
|
||||
|
||||
return metadata
|
||||
|
||||
if self.is_safetensors:
|
||||
try:
|
||||
self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading lora {filename}")
|
||||
|
||||
if self.metadata:
|
||||
m = {}
|
||||
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
|
||||
m[k] = v
|
||||
|
||||
self.metadata = m
|
||||
|
||||
self.alias = self.metadata.get('ss_output_name', self.name)
|
||||
|
||||
self.hash = None
|
||||
self.shorthash = None
|
||||
self.set_hash(
|
||||
self.metadata.get('sshs_model_hash') or
|
||||
hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
|
||||
''
|
||||
)
|
||||
|
||||
def set_hash(self, v):
|
||||
self.hash = v
|
||||
self.shorthash = self.hash[0:12]
|
||||
|
||||
if self.shorthash:
|
||||
import networks
|
||||
networks.available_network_hash_lookup[self.shorthash] = self
|
||||
|
||||
def read_hash(self):
|
||||
if not self.hash:
|
||||
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
|
||||
|
||||
def get_alias(self):
|
||||
import networks
|
||||
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
|
||||
return self.name
|
||||
else:
|
||||
return self.alias
|
||||
|
||||
|
||||
class Network: # LoraModule
|
||||
def __init__(self, name, network_on_disk: NetworkOnDisk):
|
||||
self.name = name
|
||||
self.network_on_disk = network_on_disk
|
||||
self.multiplier = 1.0
|
||||
self.modules = {}
|
||||
self.mtime = None
|
||||
|
||||
self.mentioned_name = None
|
||||
"""the text that was used to add the network to prompt - can be either name or an alias"""
|
||||
|
||||
|
||||
class ModuleType:
|
||||
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
|
||||
return None
|
||||
|
||||
|
||||
class NetworkModule:
|
||||
def __init__(self, net: Network, weights: NetworkWeights):
|
||||
self.network = net
|
||||
self.network_key = weights.network_key
|
||||
self.sd_key = weights.sd_key
|
||||
self.sd_module = weights.sd_module
|
||||
|
||||
def calc_updown(self, target):
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, y):
|
||||
raise NotImplementedError()
|
||||
|
59
extensions-builtin/Lora/network_hada.py
Normal file
59
extensions-builtin/Lora/network_hada.py
Normal file
@ -0,0 +1,59 @@
|
||||
import lyco_helpers
|
||||
import network
|
||||
import network_lyco
|
||||
|
||||
|
||||
class ModuleTypeHada(network.ModuleType):
|
||||
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||
if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
|
||||
return NetworkModuleHada(net, weights)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class NetworkModuleHada(network_lyco.NetworkModuleLyco):
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
super().__init__(net, weights)
|
||||
|
||||
if hasattr(self.sd_module, 'weight'):
|
||||
self.shape = self.sd_module.weight.shape
|
||||
|
||||
self.w1a = weights.w["hada_w1_a"]
|
||||
self.w1b = weights.w["hada_w1_b"]
|
||||
self.dim = self.w1b.shape[0]
|
||||
self.w2a = weights.w["hada_w2_a"]
|
||||
self.w2b = weights.w["hada_w2_b"]
|
||||
|
||||
self.t1 = weights.w.get("hada_t1")
|
||||
self.t2 = weights.w.get("hada_t2")
|
||||
|
||||
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
|
||||
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
|
||||
output_shape = [w1a.size(0), w1b.size(1)]
|
||||
|
||||
if self.t1 is not None:
|
||||
output_shape = [w1a.size(1), w1b.size(1)]
|
||||
t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
|
||||
output_shape += t1.shape[2:]
|
||||
else:
|
||||
if len(w1b.shape) == 4:
|
||||
output_shape += w1b.shape[2:]
|
||||
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
|
||||
|
||||
if self.t2 is not None:
|
||||
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
|
||||
else:
|
||||
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
|
||||
|
||||
updown = updown1 * updown2
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
70
extensions-builtin/Lora/network_lora.py
Normal file
70
extensions-builtin/Lora/network_lora.py
Normal file
@ -0,0 +1,70 @@
|
||||
import torch
|
||||
|
||||
import network
|
||||
from modules import devices
|
||||
|
||||
|
||||
class ModuleTypeLora(network.ModuleType):
|
||||
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||
if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
|
||||
return NetworkModuleLora(net, weights)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class NetworkModuleLora(network.NetworkModule):
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
super().__init__(net, weights)
|
||||
|
||||
self.up = self.create_module(weights.w["lora_up.weight"])
|
||||
self.down = self.create_module(weights.w["lora_down.weight"])
|
||||
self.alpha = weights.w["alpha"] if "alpha" in weights.w else None
|
||||
|
||||
def create_module(self, weight, none_ok=False):
|
||||
if weight is None and none_ok:
|
||||
return None
|
||||
|
||||
if type(self.sd_module) == torch.nn.Linear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(self.sd_module) == torch.nn.MultiheadAttention:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||
elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||
else:
|
||||
print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
|
||||
return None
|
||||
|
||||
with torch.no_grad():
|
||||
module.weight.copy_(weight)
|
||||
|
||||
module.to(device=devices.cpu, dtype=devices.dtype)
|
||||
module.weight.requires_grad_(False)
|
||||
|
||||
return module
|
||||
|
||||
def calc_updown(self, target):
|
||||
up = self.up.weight.to(target.device, dtype=target.dtype)
|
||||
down = self.down.weight.to(target.device, dtype=target.dtype)
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
|
||||
|
||||
return updown
|
||||
|
||||
def forward(self, x, y):
|
||||
self.up.to(device=devices.device)
|
||||
self.down.to(device=devices.device)
|
||||
|
||||
return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
|
||||
|
||||
|
39
extensions-builtin/Lora/network_lyco.py
Normal file
39
extensions-builtin/Lora/network_lyco.py
Normal file
@ -0,0 +1,39 @@
|
||||
import torch
|
||||
|
||||
import lyco_helpers
|
||||
import network
|
||||
from modules import devices
|
||||
|
||||
|
||||
class NetworkModuleLyco(network.NetworkModule):
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
super().__init__(net, weights)
|
||||
|
||||
if hasattr(self.sd_module, 'weight'):
|
||||
self.shape = self.sd_module.weight.shape
|
||||
|
||||
self.dim = None
|
||||
self.bias = weights.w.get("bias")
|
||||
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
|
||||
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
|
||||
|
||||
def finalize_updown(self, updown, orig_weight, output_shape):
|
||||
if self.bias is not None:
|
||||
updown = updown.reshape(self.bias.shape)
|
||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if len(output_shape) == 4:
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if orig_weight.size().numel() == updown.size().numel():
|
||||
updown = updown.reshape(orig_weight.shape)
|
||||
|
||||
scale = (
|
||||
self.scale if self.scale is not None
|
||||
else self.alpha / self.dim if self.dim is not None and self.alpha is not None
|
||||
else 1.0
|
||||
)
|
||||
|
||||
return updown * scale * self.network.multiplier
|
||||
|
443
extensions-builtin/Lora/networks.py
Normal file
443
extensions-builtin/Lora/networks.py
Normal file
@ -0,0 +1,443 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import network
|
||||
import network_lora
|
||||
import network_hada
|
||||
|
||||
import torch
|
||||
from typing import Union
|
||||
|
||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
||||
|
||||
module_types = [
|
||||
network_lora.ModuleTypeLora(),
|
||||
network_hada.ModuleTypeHada(),
|
||||
]
|
||||
|
||||
|
||||
re_digits = re.compile(r"\d+")
|
||||
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
||||
re_compiled = {}
|
||||
|
||||
suffix_conversion = {
|
||||
"attentions": {},
|
||||
"resnets": {
|
||||
"conv1": "in_layers_2",
|
||||
"conv2": "out_layers_3",
|
||||
"time_emb_proj": "emb_layers_1",
|
||||
"conv_shortcut": "skip_connection",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||
def match(match_list, regex_text):
|
||||
regex = re_compiled.get(regex_text)
|
||||
if regex is None:
|
||||
regex = re.compile(regex_text)
|
||||
re_compiled[regex_text] = regex
|
||||
|
||||
r = re.match(regex, key)
|
||||
if not r:
|
||||
return False
|
||||
|
||||
match_list.clear()
|
||||
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
||||
return True
|
||||
|
||||
m = []
|
||||
|
||||
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
|
||||
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||
|
||||
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
|
||||
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
|
||||
|
||||
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
|
||||
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
|
||||
|
||||
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
if is_sd2:
|
||||
if 'mlp_fc1' in m[1]:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||
elif 'mlp_fc2' in m[1]:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||
else:
|
||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||
|
||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||
|
||||
if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
if 'mlp_fc1' in m[1]:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||
elif 'mlp_fc2' in m[1]:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||
else:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||
|
||||
return key
|
||||
|
||||
|
||||
def assign_network_names_to_compvis_modules(sd_model):
|
||||
network_layer_mapping = {}
|
||||
|
||||
if shared.sd_model.is_sdxl:
|
||||
for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
|
||||
if not hasattr(embedder, 'wrapped'):
|
||||
continue
|
||||
|
||||
for name, module in embedder.wrapped.named_modules():
|
||||
network_name = f'{i}_{name.replace(".", "_")}'
|
||||
network_layer_mapping[network_name] = module
|
||||
module.network_layer_name = network_name
|
||||
else:
|
||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||
network_name = name.replace(".", "_")
|
||||
network_layer_mapping[network_name] = module
|
||||
module.network_layer_name = network_name
|
||||
|
||||
for name, module in shared.sd_model.model.named_modules():
|
||||
network_name = name.replace(".", "_")
|
||||
network_layer_mapping[network_name] = module
|
||||
module.network_layer_name = network_name
|
||||
|
||||
sd_model.network_layer_mapping = network_layer_mapping
|
||||
|
||||
|
||||
def load_network(name, network_on_disk):
|
||||
net = network.Network(name, network_on_disk)
|
||||
net.mtime = os.path.getmtime(network_on_disk.filename)
|
||||
|
||||
sd = sd_models.read_state_dict(network_on_disk.filename)
|
||||
|
||||
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
|
||||
if not hasattr(shared.sd_model, 'network_layer_mapping'):
|
||||
assign_network_names_to_compvis_modules(shared.sd_model)
|
||||
|
||||
keys_failed_to_match = {}
|
||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
|
||||
|
||||
matched_networks = {}
|
||||
|
||||
for key_network, weight in sd.items():
|
||||
key_network_without_network_parts, network_part = key_network.split(".", 1)
|
||||
|
||||
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
m = re_x_proj.match(key)
|
||||
if m:
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
|
||||
|
||||
# SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
|
||||
if sd_module is None and "lora_unet" in key_network_without_network_parts:
|
||||
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||
elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
|
||||
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
keys_failed_to_match[key_network] = key
|
||||
continue
|
||||
|
||||
if key not in matched_networks:
|
||||
matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
|
||||
|
||||
matched_networks[key].w[network_part] = weight
|
||||
|
||||
for key, weights in matched_networks.items():
|
||||
net_module = None
|
||||
for nettype in module_types:
|
||||
net_module = nettype.create_module(net, weights)
|
||||
if net_module is not None:
|
||||
break
|
||||
|
||||
if net_module is None:
|
||||
raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
|
||||
|
||||
net.modules[key] = net_module
|
||||
|
||||
if keys_failed_to_match:
|
||||
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
|
||||
|
||||
return net
|
||||
|
||||
|
||||
def load_networks(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
|
||||
for net in loaded_networks:
|
||||
if net.name in names:
|
||||
already_loaded[net.name] = net
|
||||
|
||||
loaded_networks.clear()
|
||||
|
||||
networks_on_disk = [available_network_aliases.get(name, None) for name in names]
|
||||
if any(x is None for x in networks_on_disk):
|
||||
list_available_networks()
|
||||
|
||||
networks_on_disk = [available_network_aliases.get(name, None) for name in names]
|
||||
|
||||
failed_to_load_networks = []
|
||||
|
||||
for i, name in enumerate(names):
|
||||
net = already_loaded.get(name, None)
|
||||
|
||||
network_on_disk = networks_on_disk[i]
|
||||
|
||||
if network_on_disk is not None:
|
||||
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
||||
try:
|
||||
net = load_network(name, network_on_disk)
|
||||
except Exception as e:
|
||||
errors.display(e, f"loading network {network_on_disk.filename}")
|
||||
continue
|
||||
|
||||
net.mentioned_name = name
|
||||
|
||||
network_on_disk.read_hash()
|
||||
|
||||
if net is None:
|
||||
failed_to_load_networks.append(name)
|
||||
print(f"Couldn't find network with name {name}")
|
||||
continue
|
||||
|
||||
net.multiplier = multipliers[i] if multipliers else 1.0
|
||||
loaded_networks.append(net)
|
||||
|
||||
if failed_to_load_networks:
|
||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
||||
|
||||
|
||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
weights_backup = getattr(self, "network_weights_backup", None)
|
||||
|
||||
if weights_backup is None:
|
||||
return
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
self.in_proj_weight.copy_(weights_backup[0])
|
||||
self.out_proj.weight.copy_(weights_backup[1])
|
||||
else:
|
||||
self.weight.copy_(weights_backup)
|
||||
|
||||
|
||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
"""
|
||||
Applies the currently selected set of networks to the weights of torch layer self.
|
||||
If weights already have this particular set of networks applied, does nothing.
|
||||
If not, restores orginal weights from backup and alters weights according to networks.
|
||||
"""
|
||||
|
||||
network_layer_name = getattr(self, 'network_layer_name', None)
|
||||
if network_layer_name is None:
|
||||
return
|
||||
|
||||
current_names = getattr(self, "network_current_names", ())
|
||||
wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks)
|
||||
|
||||
weights_backup = getattr(self, "network_weights_backup", None)
|
||||
if weights_backup is None:
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
||||
else:
|
||||
weights_backup = self.weight.to(devices.cpu, copy=True)
|
||||
|
||||
self.network_weights_backup = weights_backup
|
||||
|
||||
if current_names != wanted_names:
|
||||
network_restore_weights_from_backup(self)
|
||||
|
||||
for net in loaded_networks:
|
||||
module = net.modules.get(network_layer_name, None)
|
||||
if module is not None and hasattr(self, 'weight'):
|
||||
with torch.no_grad():
|
||||
updown = module.calc_updown(self.weight)
|
||||
|
||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||
|
||||
self.weight += updown
|
||||
|
||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
||||
module_v = net.modules.get(network_layer_name + "_v_proj", None)
|
||||
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||
with torch.no_grad():
|
||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
|
||||
self.in_proj_weight += updown_qkv
|
||||
self.out_proj.weight += module_out.calc_updown(self.out_proj.weight)
|
||||
continue
|
||||
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
print(f'failed to calculate network weights for layer {network_layer_name}')
|
||||
|
||||
self.network_current_names = wanted_names
|
||||
|
||||
|
||||
def network_forward(module, input, original_forward):
|
||||
"""
|
||||
Old way of applying Lora by executing operations during layer's forward.
|
||||
Stacking many loras this way results in big performance degradation.
|
||||
"""
|
||||
|
||||
if len(loaded_networks) == 0:
|
||||
return original_forward(module, input)
|
||||
|
||||
input = devices.cond_cast_unet(input)
|
||||
|
||||
network_restore_weights_from_backup(module)
|
||||
network_reset_cached_weight(module)
|
||||
|
||||
y = original_forward(module, input)
|
||||
|
||||
network_layer_name = getattr(module, 'network_layer_name', None)
|
||||
for lora in loaded_networks:
|
||||
module = lora.modules.get(network_layer_name, None)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
y = module.forward(y, input)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||
self.network_current_names = ()
|
||||
self.network_weights_backup = None
|
||||
|
||||
|
||||
def network_Linear_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return network_forward(self, input, torch.nn.Linear_forward_before_network)
|
||||
|
||||
network_apply_weights(self)
|
||||
|
||||
return torch.nn.Linear_forward_before_network(self, input)
|
||||
|
||||
|
||||
def network_Linear_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_Conv2d_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
|
||||
|
||||
network_apply_weights(self)
|
||||
|
||||
return torch.nn.Conv2d_forward_before_network(self, input)
|
||||
|
||||
|
||||
def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
||||
network_apply_weights(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
|
||||
|
||||
|
||||
def list_available_networks():
|
||||
available_networks.clear()
|
||||
available_network_aliases.clear()
|
||||
forbidden_network_aliases.clear()
|
||||
available_network_hash_lookup.clear()
|
||||
forbidden_network_aliases.update({"none": 1, "Addams": 1})
|
||||
|
||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||
|
||||
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||
for filename in candidates:
|
||||
if os.path.isdir(filename):
|
||||
continue
|
||||
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
try:
|
||||
entry = network.NetworkOnDisk(name, filename)
|
||||
except OSError: # should catch FileNotFoundError and PermissionError etc.
|
||||
errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
|
||||
continue
|
||||
|
||||
available_networks[name] = entry
|
||||
|
||||
if entry.alias in available_network_aliases:
|
||||
forbidden_network_aliases[entry.alias.lower()] = 1
|
||||
|
||||
available_network_aliases[name] = entry
|
||||
available_network_aliases[entry.alias] = entry
|
||||
|
||||
|
||||
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||
|
||||
|
||||
def infotext_pasted(infotext, params):
|
||||
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
|
||||
return # if the other extension is active, it will handle those fields, no need to do anything
|
||||
|
||||
added = []
|
||||
|
||||
for k in params:
|
||||
if not k.startswith("AddNet Model "):
|
||||
continue
|
||||
|
||||
num = k[13:]
|
||||
|
||||
if params.get("AddNet Module " + num) != "LoRA":
|
||||
continue
|
||||
|
||||
name = params.get("AddNet Model " + num)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
m = re_network_name.match(name)
|
||||
if m:
|
||||
name = m.group(1)
|
||||
|
||||
multiplier = params.get("AddNet Weight A " + num, "1.0")
|
||||
|
||||
added.append(f"<lora:{name}:{multiplier}>")
|
||||
|
||||
if added:
|
||||
params["Prompt"] += "\n" + "".join(added)
|
||||
|
||||
|
||||
available_networks = {}
|
||||
available_network_aliases = {}
|
||||
loaded_networks = []
|
||||
available_network_hash_lookup = {}
|
||||
forbidden_network_aliases = {}
|
||||
|
||||
list_available_networks()
|
@ -4,18 +4,19 @@ import torch
|
||||
import gradio as gr
|
||||
from fastapi import FastAPI
|
||||
|
||||
import lora
|
||||
import network
|
||||
import networks
|
||||
import extra_networks_lora
|
||||
import ui_extra_networks_lora
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||
|
||||
def unload():
|
||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
|
||||
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
|
||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
|
||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
|
||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
|
||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
|
||||
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
|
||||
|
||||
|
||||
def before_ui():
|
||||
@ -23,50 +24,50 @@ def before_ui():
|
||||
extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
|
||||
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
||||
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
|
||||
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
||||
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
|
||||
torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
|
||||
if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
|
||||
torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
||||
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
||||
if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
|
||||
torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
|
||||
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
|
||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
||||
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
||||
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
|
||||
torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
||||
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
||||
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
|
||||
torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
|
||||
torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
|
||||
|
||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
|
||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
|
||||
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
|
||||
torch.nn.Linear.forward = networks.network_Linear_forward
|
||||
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
||||
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
||||
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
||||
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
||||
|
||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
|
||||
script_callbacks.on_script_unloaded(unload)
|
||||
script_callbacks.on_before_ui(before_ui)
|
||||
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
||||
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
|
||||
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
|
||||
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
|
||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||
}))
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
|
||||
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
||||
"lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
||||
}))
|
||||
|
||||
|
||||
def create_lora_json(obj: lora.LoraOnDisk):
|
||||
def create_lora_json(obj: network.NetworkOnDisk):
|
||||
return {
|
||||
"name": obj.name,
|
||||
"alias": obj.alias,
|
||||
@ -75,17 +76,17 @@ def create_lora_json(obj: lora.LoraOnDisk):
|
||||
}
|
||||
|
||||
|
||||
def api_loras(_: gr.Blocks, app: FastAPI):
|
||||
def api_networks(_: gr.Blocks, app: FastAPI):
|
||||
@app.get("/sdapi/v1/loras")
|
||||
async def get_loras():
|
||||
return [create_lora_json(obj) for obj in lora.available_loras.values()]
|
||||
return [create_lora_json(obj) for obj in networks.available_networks.values()]
|
||||
|
||||
@app.post("/sdapi/v1/refresh-loras")
|
||||
async def refresh_loras():
|
||||
return lora.list_available_loras()
|
||||
return networks.list_available_networks()
|
||||
|
||||
|
||||
script_callbacks.on_app_started(api_loras)
|
||||
script_callbacks.on_app_started(api_networks)
|
||||
|
||||
re_lora = re.compile("<lora:([^:]+):")
|
||||
|
||||
@ -98,19 +99,19 @@ def infotext_pasted(infotext, d):
|
||||
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
|
||||
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
|
||||
|
||||
def lora_replacement(m):
|
||||
def network_replacement(m):
|
||||
alias = m.group(1)
|
||||
shorthash = hashes.get(alias)
|
||||
if shorthash is None:
|
||||
return m.group(0)
|
||||
|
||||
lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
|
||||
if lora_on_disk is None:
|
||||
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
|
||||
if network_on_disk is None:
|
||||
return m.group(0)
|
||||
|
||||
return f'<lora:{lora_on_disk.get_alias()}:'
|
||||
return f'<lora:{network_on_disk.get_alias()}:'
|
||||
|
||||
d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
|
||||
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
|
||||
|
||||
|
||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import os
|
||||
import lora
|
||||
import networks
|
||||
|
||||
from modules import shared, ui_extra_networks
|
||||
from modules.ui_extra_networks import quote_js
|
||||
@ -11,10 +11,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
super().__init__('Lora')
|
||||
|
||||
def refresh(self):
|
||||
lora.list_available_loras()
|
||||
networks.list_available_networks()
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
lora_on_disk = lora.available_loras.get(name)
|
||||
lora_on_disk = networks.available_networks.get(name)
|
||||
|
||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||
|
||||
@ -43,7 +43,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
return item
|
||||
|
||||
def list_items(self):
|
||||
for index, name in enumerate(lora.available_loras):
|
||||
for index, name in enumerate(networks.available_networks):
|
||||
item = self.create_item(name, index)
|
||||
yield item
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user