add support for SDXL loras with te1/te2 modules
This commit is contained in:
parent
ff73841c60
commit
6c5f83b19b
@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
|
|||||||
|
|
||||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||||
|
|
||||||
|
if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
|
||||||
|
if 'mlp_fc1' in m[1]:
|
||||||
|
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||||
|
elif 'mlp_fc2' in m[1]:
|
||||||
|
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||||
|
else:
|
||||||
|
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||||
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
@ -142,10 +150,20 @@ class LoraUpDownModule:
|
|||||||
def assign_lora_names_to_compvis_modules(sd_model):
|
def assign_lora_names_to_compvis_modules(sd_model):
|
||||||
lora_layer_mapping = {}
|
lora_layer_mapping = {}
|
||||||
|
|
||||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
if shared.sd_model.is_sdxl:
|
||||||
lora_name = name.replace(".", "_")
|
for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
|
||||||
lora_layer_mapping[lora_name] = module
|
if not hasattr(embedder, 'wrapped'):
|
||||||
module.lora_layer_name = lora_name
|
continue
|
||||||
|
|
||||||
|
for name, module in embedder.wrapped.named_modules():
|
||||||
|
lora_name = f'{i}_{name.replace(".", "_")}'
|
||||||
|
lora_layer_mapping[lora_name] = module
|
||||||
|
module.lora_layer_name = lora_name
|
||||||
|
else:
|
||||||
|
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||||
|
lora_name = name.replace(".", "_")
|
||||||
|
lora_layer_mapping[lora_name] = module
|
||||||
|
module.lora_layer_name = lora_name
|
||||||
|
|
||||||
for name, module in shared.sd_model.model.named_modules():
|
for name, module in shared.sd_model.model.named_modules():
|
||||||
lora_name = name.replace(".", "_")
|
lora_name = name.replace(".", "_")
|
||||||
@ -168,10 +186,10 @@ def load_lora(name, lora_on_disk):
|
|||||||
keys_failed_to_match = {}
|
keys_failed_to_match = {}
|
||||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||||
|
|
||||||
for key_diffusers, weight in sd.items():
|
for key_lora, weight in sd.items():
|
||||||
key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
|
key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
|
||||||
key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
|
|
||||||
|
|
||||||
|
key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
|
||||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||||
|
|
||||||
if sd_module is None:
|
if sd_module is None:
|
||||||
@ -180,12 +198,15 @@ def load_lora(name, lora_on_disk):
|
|||||||
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
|
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
|
||||||
|
|
||||||
# SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
|
# SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
|
||||||
if sd_module is None and "lora_unet" in key_diffusers_without_lora_parts:
|
if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
|
||||||
key = key_diffusers_without_lora_parts.replace("lora_unet", "diffusion_model")
|
key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
|
||||||
|
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||||
|
elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
|
||||||
|
key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||||
|
|
||||||
if sd_module is None:
|
if sd_module is None:
|
||||||
keys_failed_to_match[key_diffusers] = key
|
keys_failed_to_match[key_lora] = key
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora_module = lora.modules.get(key, None)
|
lora_module = lora.modules.get(key, None)
|
||||||
|
@ -289,7 +289,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
if hasattr(model, 'conditioner'):
|
model.is_sdxl = hasattr(model, 'conditioner')
|
||||||
|
if model.is_sdxl:
|
||||||
sd_models_xl.extend_sdxl(model)
|
sd_models_xl.extend_sdxl(model)
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
@ -48,7 +48,6 @@ def extend_sdxl(model):
|
|||||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
||||||
|
|
||||||
model.is_sdxl = True
|
|
||||||
|
|
||||||
|
|
||||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||||
|
Loading…
Reference in New Issue
Block a user