Merge latest sd-script updates
This commit is contained in:
parent
d37aa6efad
commit
2eddd64b90
12
README.md
12
README.md
@ -192,8 +192,20 @@ This will store your a backup file with your current locally installed pip packa
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
|
* 2023/04/01 (v21.4.0)
|
||||||
|
- Fix an issue that `merge_lora.py` does not work with the latest version.
|
||||||
|
- Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights.
|
||||||
|
- Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`.
|
||||||
|
- Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
|
||||||
|
- Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
|
||||||
|
- See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
|
||||||
|
- Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported.
|
||||||
|
- Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option.
|
||||||
|
- Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec!
|
||||||
|
- Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option.
|
||||||
* 2023/04/01 (v21.3.9)
|
* 2023/04/01 (v21.3.9)
|
||||||
- Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496
|
- Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496
|
||||||
|
- Fix issue with WD14 caption script by applying a custom fix to kohya_ss code.
|
||||||
* 2023/03/30 (v21.3.8)
|
* 2023/03/30 (v21.3.8)
|
||||||
- Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481
|
- Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481
|
||||||
* 2023/03/29 (v21.3.7)
|
* 2023/03/29 (v21.3.7)
|
||||||
|
209
XTI_hijack.py
Normal file
209
XTI_hijack.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||||
|
|
||||||
|
def unet_forward_XTI(self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||||
|
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||||
|
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||||
|
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||||
|
returning a tuple, the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||||
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
|
# on the fly if necessary.
|
||||||
|
default_overall_up_factor = 2**self.num_upsamplers
|
||||||
|
|
||||||
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
forward_upsample_size = False
|
||||||
|
upsample_size = None
|
||||||
|
|
||||||
|
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||||
|
logger.info("Forward upsample size to force interpolation output size.")
|
||||||
|
forward_upsample_size = True
|
||||||
|
|
||||||
|
# 0. center input if necessary
|
||||||
|
if self.config.center_input_sample:
|
||||||
|
sample = 2 * sample - 1.0
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
# This would be a good case for the `match` statement (Python 3.10+)
|
||||||
|
is_mps = sample.device.type == "mps"
|
||||||
|
if isinstance(timestep, float):
|
||||||
|
dtype = torch.float32 if is_mps else torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||||
|
elif len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
t_emb = t_emb.to(dtype=self.dtype)
|
||||||
|
emb = self.time_embedding(t_emb)
|
||||||
|
|
||||||
|
if self.config.num_class_embeds is not None:
|
||||||
|
if class_labels is None:
|
||||||
|
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||||
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
down_i = 0
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
|
||||||
|
)
|
||||||
|
down_i += 2
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
|
||||||
|
|
||||||
|
# 5. up
|
||||||
|
up_i = 7
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the
|
||||||
|
# upsample size, we do it here
|
||||||
|
if not is_final_block and forward_upsample_size:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
)
|
||||||
|
up_i += 3
|
||||||
|
else:
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||||
|
)
|
||||||
|
# 6. post-process
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
|
||||||
|
return UNet2DConditionOutput(sample=sample)
|
||||||
|
|
||||||
|
def downblock_forward_XTI(
|
||||||
|
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||||
|
):
|
||||||
|
output_states = ()
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
|
||||||
|
|
||||||
|
output_states += (hidden_states,)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states)
|
||||||
|
|
||||||
|
output_states += (hidden_states,)
|
||||||
|
|
||||||
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
def upblock_forward_XTI(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states_tuple,
|
||||||
|
temb=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
upsample_size=None,
|
||||||
|
):
|
||||||
|
i = 0
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
@ -95,6 +95,8 @@ import library.train_util as train_util
|
|||||||
import tools.original_control_net as original_control_net
|
import tools.original_control_net as original_control_net
|
||||||
from tools.original_control_net import ControlNetInfo
|
from tools.original_control_net import ControlNetInfo
|
||||||
|
|
||||||
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||||
@ -491,6 +493,9 @@ class PipelineLike:
|
|||||||
# Textual Inversion
|
# Textual Inversion
|
||||||
self.token_replacements = {}
|
self.token_replacements = {}
|
||||||
|
|
||||||
|
# XTI
|
||||||
|
self.token_replacements_XTI = {}
|
||||||
|
|
||||||
# CLIP guidance
|
# CLIP guidance
|
||||||
self.clip_guidance_scale = clip_guidance_scale
|
self.clip_guidance_scale = clip_guidance_scale
|
||||||
self.clip_image_guidance_scale = clip_image_guidance_scale
|
self.clip_image_guidance_scale = clip_image_guidance_scale
|
||||||
@ -514,15 +519,26 @@ class PipelineLike:
|
|||||||
def add_token_replacement(self, target_token_id, rep_token_ids):
|
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||||
self.token_replacements[target_token_id] = rep_token_ids
|
self.token_replacements[target_token_id] = rep_token_ids
|
||||||
|
|
||||||
def replace_token(self, tokens):
|
def replace_token(self, tokens, layer=None):
|
||||||
new_tokens = []
|
new_tokens = []
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if token in self.token_replacements:
|
if token in self.token_replacements:
|
||||||
new_tokens.extend(self.token_replacements[token])
|
replacer_ = self.token_replacements[token]
|
||||||
|
if layer:
|
||||||
|
replacer = []
|
||||||
|
for r in replacer_:
|
||||||
|
if r in self.token_replacements_XTI:
|
||||||
|
replacer.append(self.token_replacements_XTI[r][layer])
|
||||||
|
else:
|
||||||
|
replacer = replacer_
|
||||||
|
new_tokens.extend(replacer)
|
||||||
else:
|
else:
|
||||||
new_tokens.append(token)
|
new_tokens.append(token)
|
||||||
return new_tokens
|
return new_tokens
|
||||||
|
|
||||||
|
def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
|
||||||
|
self.token_replacements_XTI[target_token_id] = rep_token_ids
|
||||||
|
|
||||||
def set_control_nets(self, ctrl_nets):
|
def set_control_nets(self, ctrl_nets):
|
||||||
self.control_nets = ctrl_nets
|
self.control_nets = ctrl_nets
|
||||||
|
|
||||||
@ -744,14 +760,15 @@ class PipelineLike:
|
|||||||
" the batch size of `prompt`."
|
" the batch size of `prompt`."
|
||||||
)
|
)
|
||||||
|
|
||||||
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
|
if not self.token_replacements_XTI:
|
||||||
pipe=self,
|
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
|
||||||
prompt=prompt,
|
pipe=self,
|
||||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
prompt=prompt,
|
||||||
max_embeddings_multiples=max_embeddings_multiples,
|
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||||
clip_skip=self.clip_skip,
|
max_embeddings_multiples=max_embeddings_multiples,
|
||||||
**kwargs,
|
clip_skip=self.clip_skip,
|
||||||
)
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if negative_scale is not None:
|
if negative_scale is not None:
|
||||||
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
|
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
|
||||||
@ -763,11 +780,47 @@ class PipelineLike:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
if self.token_replacements_XTI:
|
||||||
if negative_scale is None:
|
text_embeddings_concat = []
|
||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
for layer in [
|
||||||
else:
|
"IN01",
|
||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
|
"IN02",
|
||||||
|
"IN04",
|
||||||
|
"IN05",
|
||||||
|
"IN07",
|
||||||
|
"IN08",
|
||||||
|
"MID",
|
||||||
|
"OUT03",
|
||||||
|
"OUT04",
|
||||||
|
"OUT05",
|
||||||
|
"OUT06",
|
||||||
|
"OUT07",
|
||||||
|
"OUT08",
|
||||||
|
"OUT09",
|
||||||
|
"OUT10",
|
||||||
|
"OUT11",
|
||||||
|
]:
|
||||||
|
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
|
||||||
|
pipe=self,
|
||||||
|
prompt=prompt,
|
||||||
|
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||||
|
max_embeddings_multiples=max_embeddings_multiples,
|
||||||
|
clip_skip=self.clip_skip,
|
||||||
|
layer=layer,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
if negative_scale is None:
|
||||||
|
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
|
||||||
|
else:
|
||||||
|
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))
|
||||||
|
text_embeddings = torch.stack(text_embeddings_concat)
|
||||||
|
else:
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
if negative_scale is None:
|
||||||
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||||
|
else:
|
||||||
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
|
||||||
|
|
||||||
# CLIP guidanceで使用するembeddingsを取得する
|
# CLIP guidanceで使用するembeddingsを取得する
|
||||||
if self.clip_guidance_scale > 0:
|
if self.clip_guidance_scale > 0:
|
||||||
@ -1675,7 +1728,7 @@ def parse_prompt_attention(text):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int):
|
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None):
|
||||||
r"""
|
r"""
|
||||||
Tokenize a list of prompts and return its tokens with weights of each token.
|
Tokenize a list of prompts and return its tokens with weights of each token.
|
||||||
No padding, starting or ending token is included.
|
No padding, starting or ending token is included.
|
||||||
@ -1691,7 +1744,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
|||||||
# tokenize and discard the starting and the ending token
|
# tokenize and discard the starting and the ending token
|
||||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||||
|
|
||||||
token = pipe.replace_token(token)
|
token = pipe.replace_token(token, layer=layer)
|
||||||
|
|
||||||
text_token += token
|
text_token += token
|
||||||
# copy the weight by length of token
|
# copy the weight by length of token
|
||||||
@ -1807,6 +1860,7 @@ def get_weighted_text_embeddings(
|
|||||||
skip_parsing: Optional[bool] = False,
|
skip_parsing: Optional[bool] = False,
|
||||||
skip_weighting: Optional[bool] = False,
|
skip_weighting: Optional[bool] = False,
|
||||||
clip_skip=None,
|
clip_skip=None,
|
||||||
|
layer=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@ -1837,11 +1891,11 @@ def get_weighted_text_embeddings(
|
|||||||
prompt = [prompt]
|
prompt = [prompt]
|
||||||
|
|
||||||
if not skip_parsing:
|
if not skip_parsing:
|
||||||
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
|
||||||
if uncond_prompt is not None:
|
if uncond_prompt is not None:
|
||||||
if isinstance(uncond_prompt, str):
|
if isinstance(uncond_prompt, str):
|
||||||
uncond_prompt = [uncond_prompt]
|
uncond_prompt = [uncond_prompt]
|
||||||
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer)
|
||||||
else:
|
else:
|
||||||
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
||||||
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
||||||
@ -2229,13 +2283,17 @@ def main(args):
|
|||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
network.apply_to(text_encoder, unet)
|
if not args.network_merge:
|
||||||
|
network.apply_to(text_encoder, unet)
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
network.to(memory_format=torch.channels_last)
|
network.to(memory_format=torch.channels_last)
|
||||||
network.to(dtype).to(device)
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
|
networks.append(network)
|
||||||
|
else:
|
||||||
|
network.merge_to(text_encoder, unet, dtype, device)
|
||||||
|
|
||||||
networks.append(network)
|
|
||||||
else:
|
else:
|
||||||
networks = []
|
networks = []
|
||||||
|
|
||||||
@ -2289,6 +2347,11 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
if args.XTI_embeddings:
|
||||||
|
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||||
|
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||||
|
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||||
|
|
||||||
# Textual Inversionを処理する
|
# Textual Inversionを処理する
|
||||||
if args.textual_inversion_embeddings:
|
if args.textual_inversion_embeddings:
|
||||||
token_ids_embeds = []
|
token_ids_embeds = []
|
||||||
@ -2335,6 +2398,71 @@ def main(args):
|
|||||||
for token_id, embed in zip(token_ids, embeds):
|
for token_id, embed in zip(token_ids, embeds):
|
||||||
token_embeds[token_id] = embed
|
token_embeds[token_id] = embed
|
||||||
|
|
||||||
|
if args.XTI_embeddings:
|
||||||
|
XTI_layers = [
|
||||||
|
"IN01",
|
||||||
|
"IN02",
|
||||||
|
"IN04",
|
||||||
|
"IN05",
|
||||||
|
"IN07",
|
||||||
|
"IN08",
|
||||||
|
"MID",
|
||||||
|
"OUT03",
|
||||||
|
"OUT04",
|
||||||
|
"OUT05",
|
||||||
|
"OUT06",
|
||||||
|
"OUT07",
|
||||||
|
"OUT08",
|
||||||
|
"OUT09",
|
||||||
|
"OUT10",
|
||||||
|
"OUT11",
|
||||||
|
]
|
||||||
|
token_ids_embeds_XTI = []
|
||||||
|
for embeds_file in args.XTI_embeddings:
|
||||||
|
if model_util.is_safetensors(embeds_file):
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
data = load_file(embeds_file)
|
||||||
|
else:
|
||||||
|
data = torch.load(embeds_file, map_location="cpu")
|
||||||
|
if set(data.keys()) != set(XTI_layers):
|
||||||
|
raise ValueError("NOT XTI")
|
||||||
|
embeds = torch.concat(list(data.values()))
|
||||||
|
num_vectors_per_token = data["MID"].size()[0]
|
||||||
|
|
||||||
|
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||||
|
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||||
|
|
||||||
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
|
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||||
|
assert (
|
||||||
|
num_added_tokens == num_vectors_per_token
|
||||||
|
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||||
|
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
|
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||||
|
|
||||||
|
# if num_vectors_per_token > 1:
|
||||||
|
pipe.add_token_replacement(token_ids[0], token_ids)
|
||||||
|
|
||||||
|
token_strings_XTI = []
|
||||||
|
for layer_name in XTI_layers:
|
||||||
|
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
|
||||||
|
tokenizer.add_tokens(token_strings_XTI)
|
||||||
|
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
|
||||||
|
token_ids_embeds_XTI.append((token_ids_XTI, embeds))
|
||||||
|
for t in token_ids:
|
||||||
|
t_XTI_dic = {}
|
||||||
|
for i, layer_name in enumerate(XTI_layers):
|
||||||
|
t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens
|
||||||
|
pipe.add_token_replacement_XTI(t, t_XTI_dic)
|
||||||
|
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
|
for token_ids, embeds in token_ids_embeds_XTI:
|
||||||
|
for token_id, embed in zip(token_ids, embeds):
|
||||||
|
token_embeds[token_id] = embed
|
||||||
|
|
||||||
# promptを取得する
|
# promptを取得する
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
print(f"reading prompts from {args.from_file}")
|
||||||
@ -2983,6 +3111,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
|
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
|
||||||
)
|
)
|
||||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||||
|
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--textual_inversion_embeddings",
|
"--textual_inversion_embeddings",
|
||||||
type=str,
|
type=str,
|
||||||
@ -2990,6 +3119,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
|
help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--XTI_embeddings",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
nargs="*",
|
||||||
|
help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
|
||||||
|
)
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
|
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_embeddings_multiples",
|
"--max_embeddings_multiples",
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -404,6 +404,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.token_padding_disabled = False
|
self.token_padding_disabled = False
|
||||||
self.tag_frequency = {}
|
self.tag_frequency = {}
|
||||||
|
self.XTI_layers = None
|
||||||
|
self.token_strings = None
|
||||||
|
|
||||||
self.enable_bucket = False
|
self.enable_bucket = False
|
||||||
self.bucket_manager: BucketManager = None # not initialized
|
self.bucket_manager: BucketManager = None # not initialized
|
||||||
@ -464,6 +466,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
def disable_token_padding(self):
|
def disable_token_padding(self):
|
||||||
self.token_padding_disabled = True
|
self.token_padding_disabled = True
|
||||||
|
|
||||||
|
def enable_XTI(self, layers=None, token_strings=None):
|
||||||
|
self.XTI_layers = layers
|
||||||
|
self.token_strings = token_strings
|
||||||
|
|
||||||
def add_replacement(self, str_from, str_to):
|
def add_replacement(self, str_from, str_to):
|
||||||
self.replacements[str_from] = str_to
|
self.replacements[str_from] = str_to
|
||||||
|
|
||||||
@ -909,9 +915,22 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
latents_list.append(latents)
|
latents_list.append(latents)
|
||||||
|
|
||||||
caption = self.process_caption(subset, image_info.caption)
|
caption = self.process_caption(subset, image_info.caption)
|
||||||
captions.append(caption)
|
if self.XTI_layers:
|
||||||
|
caption_layer = []
|
||||||
|
for layer in self.XTI_layers:
|
||||||
|
token_strings_from = " ".join(self.token_strings)
|
||||||
|
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||||
|
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||||
|
caption_layer.append(caption_)
|
||||||
|
captions.append(caption_layer)
|
||||||
|
else:
|
||||||
|
captions.append(caption)
|
||||||
if not self.token_padding_disabled: # this option might be omitted in future
|
if not self.token_padding_disabled: # this option might be omitted in future
|
||||||
input_ids_list.append(self.get_input_ids(caption))
|
if self.XTI_layers:
|
||||||
|
token_caption = self.get_input_ids(caption_layer)
|
||||||
|
else:
|
||||||
|
token_caption = self.get_input_ids(caption)
|
||||||
|
input_ids_list.append(token_caption)
|
||||||
|
|
||||||
example = {}
|
example = {}
|
||||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||||
@ -1314,6 +1333,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
# for dataset in self.datasets:
|
# for dataset in self.datasets:
|
||||||
# dataset.make_buckets()
|
# dataset.make_buckets()
|
||||||
|
|
||||||
|
def enable_XTI(self, *args, **kwargs):
|
||||||
|
for dataset in self.datasets:
|
||||||
|
dataset.enable_XTI(*args, **kwargs)
|
||||||
|
|
||||||
def cache_latents(self, vae, vae_batch_size=1):
|
def cache_latents(self, vae, vae_batch_size=1):
|
||||||
for i, dataset in enumerate(self.datasets):
|
for i, dataset in enumerate(self.datasets):
|
||||||
print(f"[Dataset {i}]")
|
print(f"[Dataset {i}]")
|
||||||
@ -2617,14 +2640,15 @@ def prepare_dtype(args: argparse.Namespace):
|
|||||||
return weight_dtype, save_dtype
|
return weight_dtype, save_dtype
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args: argparse.Namespace, weight_dtype):
|
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
|
||||||
name_or_path = args.pretrained_model_name_or_path
|
name_or_path = args.pretrained_model_name_or_path
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
print("load StableDiffusion checkpoint")
|
print("load StableDiffusion checkpoint")
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
|
||||||
else:
|
else:
|
||||||
|
# Diffusers model is loaded to CPU
|
||||||
print("load Diffusers pretrained models")
|
print("load Diffusers pretrained models")
|
||||||
try:
|
try:
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
||||||
|
679
networks/lora.py
679
networks/lora.py
@ -13,386 +13,471 @@ from library import train_util
|
|||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lora_name = lora_name
|
self.lora_name = lora_name
|
||||||
|
|
||||||
if org_module.__class__.__name__ == 'Conv2d':
|
if org_module.__class__.__name__ == "Conv2d":
|
||||||
in_dim = org_module.in_channels
|
in_dim = org_module.in_channels
|
||||||
out_dim = org_module.out_channels
|
out_dim = org_module.out_channels
|
||||||
else:
|
else:
|
||||||
in_dim = org_module.in_features
|
in_dim = org_module.in_features
|
||||||
out_dim = org_module.out_features
|
out_dim = org_module.out_features
|
||||||
|
|
||||||
# if limit_rank:
|
# if limit_rank:
|
||||||
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
||||||
# if self.lora_dim != lora_dim:
|
# if self.lora_dim != lora_dim:
|
||||||
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||||
# else:
|
# else:
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
|
|
||||||
if org_module.__class__.__name__ == 'Conv2d':
|
if org_module.__class__.__name__ == "Conv2d":
|
||||||
kernel_size = org_module.kernel_size
|
kernel_size = org_module.kernel_size
|
||||||
stride = org_module.stride
|
stride = org_module.stride
|
||||||
padding = org_module.padding
|
padding = org_module.padding
|
||||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||||
else:
|
else:
|
||||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||||
|
|
||||||
if type(alpha) == torch.Tensor:
|
if type(alpha) == torch.Tensor:
|
||||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||||
self.scale = alpha / self.lora_dim
|
self.scale = alpha / self.lora_dim
|
||||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||||
|
|
||||||
# same as microsoft's
|
# same as microsoft's
|
||||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||||
torch.nn.init.zeros_(self.lora_up.weight)
|
torch.nn.init.zeros_(self.lora_up.weight)
|
||||||
|
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.org_module = org_module # remove in applying
|
self.org_module = org_module # remove in applying
|
||||||
self.region = None
|
self.region = None
|
||||||
self.region_mask = None
|
self.region_mask = None
|
||||||
|
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module.forward
|
self.org_forward = self.org_module.forward
|
||||||
self.org_module.forward = self.forward
|
self.org_module.forward = self.forward
|
||||||
del self.org_module
|
del self.org_module
|
||||||
|
|
||||||
def set_region(self, region):
|
def merge_to(self, sd, dtype, device):
|
||||||
self.region = region
|
# get up/down weight
|
||||||
self.region_mask = None
|
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
||||||
|
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
||||||
|
|
||||||
def forward(self, x):
|
# extract weight from org_module
|
||||||
if self.region is None:
|
org_sd = self.org_module.state_dict()
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
weight = org_sd["weight"].to(torch.float)
|
||||||
|
|
||||||
# regional LoRA FIXME same as additional-network extension
|
# merge weight
|
||||||
if x.size()[1] % 77 == 0:
|
if len(weight.size()) == 2:
|
||||||
# print(f"LoRA for context: {self.lora_name}")
|
# linear
|
||||||
self.region = None
|
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
elif down_weight.size()[2:4] == (1, 1):
|
||||||
|
# conv2d 1x1
|
||||||
|
weight = (
|
||||||
|
weight
|
||||||
|
+ self.multiplier
|
||||||
|
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
* self.scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# conv2d 3x3
|
||||||
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
|
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
|
weight = weight + self.multiplier * conved * self.scale
|
||||||
|
|
||||||
# calculate region mask first time
|
# set weight to org_module
|
||||||
if self.region_mask is None:
|
org_sd["weight"] = weight.to(dtype)
|
||||||
if len(x.size()) == 4:
|
self.org_module.load_state_dict(org_sd)
|
||||||
h, w = x.size()[2:4]
|
|
||||||
else:
|
|
||||||
seq_len = x.size()[1]
|
|
||||||
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
|
||||||
h = int(self.region.size()[0] / ratio + .5)
|
|
||||||
w = seq_len // h
|
|
||||||
|
|
||||||
r = self.region.to(x.device)
|
def set_region(self, region):
|
||||||
if r.dtype == torch.bfloat16:
|
self.region = region
|
||||||
r = r.to(torch.float)
|
self.region_mask = None
|
||||||
r = r.unsqueeze(0).unsqueeze(1)
|
|
||||||
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
|
||||||
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
|
|
||||||
r = r.to(x.dtype)
|
|
||||||
|
|
||||||
if len(x.size()) == 3:
|
def forward(self, x):
|
||||||
r = torch.reshape(r, (1, x.size()[1], -1))
|
if self.region is None:
|
||||||
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
self.region_mask = r
|
# regional LoRA FIXME same as additional-network extension
|
||||||
|
if x.size()[1] % 77 == 0:
|
||||||
|
# print(f"LoRA for context: {self.lora_name}")
|
||||||
|
self.region = None
|
||||||
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
# calculate region mask first time
|
||||||
|
if self.region_mask is None:
|
||||||
|
if len(x.size()) == 4:
|
||||||
|
h, w = x.size()[2:4]
|
||||||
|
else:
|
||||||
|
seq_len = x.size()[1]
|
||||||
|
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
||||||
|
h = int(self.region.size()[0] / ratio + 0.5)
|
||||||
|
w = seq_len // h
|
||||||
|
|
||||||
|
r = self.region.to(x.device)
|
||||||
|
if r.dtype == torch.bfloat16:
|
||||||
|
r = r.to(torch.float)
|
||||||
|
r = r.unsqueeze(0).unsqueeze(1)
|
||||||
|
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
||||||
|
r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
|
||||||
|
r = r.to(x.dtype)
|
||||||
|
|
||||||
|
if len(x.size()) == 3:
|
||||||
|
r = torch.reshape(r, (1, x.size()[1], -1))
|
||||||
|
|
||||||
|
self.region_mask = r
|
||||||
|
|
||||||
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
||||||
|
|
||||||
|
|
||||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||||
if network_dim is None:
|
if network_dim is None:
|
||||||
network_dim = 4 # default
|
network_dim = 4 # default
|
||||||
|
|
||||||
# extract dim/alpha for conv2d, and block dim
|
# extract dim/alpha for conv2d, and block dim
|
||||||
conv_dim = kwargs.get('conv_dim', None)
|
conv_dim = kwargs.get("conv_dim", None)
|
||||||
conv_alpha = kwargs.get('conv_alpha', None)
|
conv_alpha = kwargs.get("conv_alpha", None)
|
||||||
if conv_dim is not None:
|
if conv_dim is not None:
|
||||||
conv_dim = int(conv_dim)
|
conv_dim = int(conv_dim)
|
||||||
if conv_alpha is None:
|
if conv_alpha is None:
|
||||||
conv_alpha = 1.0
|
conv_alpha = 1.0
|
||||||
else:
|
else:
|
||||||
conv_alpha = float(conv_alpha)
|
conv_alpha = float(conv_alpha)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
block_dims = kwargs.get("block_dims")
|
block_dims = kwargs.get("block_dims")
|
||||||
block_alphas = None
|
block_alphas = None
|
||||||
|
|
||||||
if block_dims is not None:
|
if block_dims is not None:
|
||||||
block_dims = [int(d) for d in block_dims.split(',')]
|
block_dims = [int(d) for d in block_dims.split(',')]
|
||||||
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
||||||
block_alphas = kwargs.get("block_alphas")
|
block_alphas = kwargs.get("block_alphas")
|
||||||
if block_alphas is None:
|
if block_alphas is None:
|
||||||
block_alphas = [1] * len(block_dims)
|
block_alphas = [1] * len(block_dims)
|
||||||
else:
|
else:
|
||||||
block_alphas = [int(a) for a in block_alphas(',')]
|
block_alphas = [int(a) for a in block_alphas(',')]
|
||||||
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
||||||
|
|
||||||
conv_block_dims = kwargs.get("conv_block_dims")
|
conv_block_dims = kwargs.get("conv_block_dims")
|
||||||
conv_block_alphas = None
|
conv_block_alphas = None
|
||||||
|
|
||||||
if conv_block_dims is not None:
|
if conv_block_dims is not None:
|
||||||
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
||||||
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
||||||
conv_block_alphas = kwargs.get("conv_block_alphas")
|
conv_block_alphas = kwargs.get("conv_block_alphas")
|
||||||
if conv_block_alphas is None:
|
if conv_block_alphas is None:
|
||||||
conv_block_alphas = [1] * len(conv_block_dims)
|
conv_block_alphas = [1] * len(conv_block_dims)
|
||||||
else:
|
else:
|
||||||
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
|
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
|
||||||
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
|
network = LoRANetwork(
|
||||||
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
|
text_encoder,
|
||||||
return network
|
unet,
|
||||||
|
multiplier=multiplier,
|
||||||
|
lora_dim=network_dim,
|
||||||
|
alpha=network_alpha,
|
||||||
|
conv_lora_dim=conv_dim,
|
||||||
|
conv_alpha=conv_alpha,
|
||||||
|
)
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
||||||
if weights_sd is None:
|
if weights_sd is None:
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file, safe_open
|
from safetensors.torch import load_file, safe_open
|
||||||
weights_sd = load_file(file)
|
|
||||||
else:
|
|
||||||
weights_sd = torch.load(file, map_location='cpu')
|
|
||||||
|
|
||||||
# get dim/alpha mapping
|
weights_sd = load_file(file)
|
||||||
modules_dim = {}
|
else:
|
||||||
modules_alpha = {}
|
weights_sd = torch.load(file, map_location="cpu")
|
||||||
for key, value in weights_sd.items():
|
|
||||||
if '.' not in key:
|
|
||||||
continue
|
|
||||||
|
|
||||||
lora_name = key.split('.')[0]
|
# get dim/alpha mapping
|
||||||
if 'alpha' in key:
|
modules_dim = {}
|
||||||
modules_alpha[lora_name] = value
|
modules_alpha = {}
|
||||||
elif 'lora_down' in key:
|
for key, value in weights_sd.items():
|
||||||
dim = value.size()[0]
|
if "." not in key:
|
||||||
modules_dim[lora_name] = dim
|
continue
|
||||||
# print(lora_name, value.size(), dim)
|
|
||||||
|
|
||||||
# support old LoRA without alpha
|
lora_name = key.split(".")[0]
|
||||||
for key in modules_dim.keys():
|
if "alpha" in key:
|
||||||
if key not in modules_alpha:
|
modules_alpha[lora_name] = value
|
||||||
modules_alpha = modules_dim[key]
|
elif "lora_down" in key:
|
||||||
|
dim = value.size()[0]
|
||||||
|
modules_dim[lora_name] = dim
|
||||||
|
# print(lora_name, value.size(), dim)
|
||||||
|
|
||||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
# support old LoRA without alpha
|
||||||
network.weights_sd = weights_sd
|
for key in modules_dim.keys():
|
||||||
return network
|
if key not in modules_alpha:
|
||||||
|
modules_alpha = modules_dim[key]
|
||||||
|
|
||||||
|
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
||||||
|
network.weights_sd = weights_sd
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
class LoRANetwork(torch.nn.Module):
|
class LoRANetwork(torch.nn.Module):
|
||||||
# is it possible to apply conv_in and conv_out?
|
# is it possible to apply conv_in and conv_out?
|
||||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
||||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
LORA_PREFIX_UNET = 'lora_unet'
|
LORA_PREFIX_UNET = "lora_unet"
|
||||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||||
|
|
||||||
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
|
def __init__(
|
||||||
super().__init__()
|
self,
|
||||||
self.multiplier = multiplier
|
text_encoder,
|
||||||
|
unet,
|
||||||
|
multiplier=1.0,
|
||||||
|
lora_dim=4,
|
||||||
|
alpha=1,
|
||||||
|
conv_lora_dim=None,
|
||||||
|
conv_alpha=None,
|
||||||
|
modules_dim=None,
|
||||||
|
modules_alpha=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.multiplier = multiplier
|
||||||
|
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.conv_lora_dim = conv_lora_dim
|
self.conv_lora_dim = conv_lora_dim
|
||||||
self.conv_alpha = conv_alpha
|
self.conv_alpha = conv_alpha
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
print(f"create LoRA network from weights")
|
print(f"create LoRA network from weights")
|
||||||
else:
|
else:
|
||||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||||
|
|
||||||
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
||||||
if self.apply_to_conv2d_3x3:
|
if self.apply_to_conv2d_3x3:
|
||||||
if self.conv_alpha is None:
|
if self.conv_alpha is None:
|
||||||
self.conv_alpha = self.alpha
|
self.conv_alpha = self.alpha
|
||||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
||||||
loras = []
|
loras = []
|
||||||
for name, module in root_module.named_modules():
|
for name, module in root_module.named_modules():
|
||||||
if module.__class__.__name__ in target_replace_modules:
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
# TODO get block index here
|
# TODO get block index here
|
||||||
for child_name, child_module in module.named_modules():
|
for child_name, child_module in module.named_modules():
|
||||||
is_linear = child_module.__class__.__name__ == "Linear"
|
is_linear = child_module.__class__.__name__ == "Linear"
|
||||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||||
if is_linear or is_conv2d:
|
if is_linear or is_conv2d:
|
||||||
lora_name = prefix + '.' + name + '.' + child_name
|
lora_name = prefix + "." + name + "." + child_name
|
||||||
lora_name = lora_name.replace('.', '_')
|
lora_name = lora_name.replace(".", "_")
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
if lora_name not in modules_dim:
|
if lora_name not in modules_dim:
|
||||||
continue # no LoRA module in this weights file
|
continue # no LoRA module in this weights file
|
||||||
dim = modules_dim[lora_name]
|
dim = modules_dim[lora_name]
|
||||||
alpha = modules_alpha[lora_name]
|
alpha = modules_alpha[lora_name]
|
||||||
else:
|
else:
|
||||||
if is_linear or is_conv2d_1x1:
|
if is_linear or is_conv2d_1x1:
|
||||||
dim = self.lora_dim
|
dim = self.lora_dim
|
||||||
alpha = self.alpha
|
alpha = self.alpha
|
||||||
elif self.apply_to_conv2d_3x3:
|
elif self.apply_to_conv2d_3x3:
|
||||||
dim = self.conv_lora_dim
|
dim = self.conv_lora_dim
|
||||||
alpha = self.conv_alpha
|
alpha = self.conv_alpha
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
|
self.text_encoder_loras = create_modules(
|
||||||
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
)
|
||||||
|
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
|
|
||||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||||
if modules_dim is not None or self.conv_lora_dim is not None:
|
if modules_dim is not None or self.conv_lora_dim is not None:
|
||||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
||||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||||
|
|
||||||
self.weights_sd = None
|
self.weights_sd = None
|
||||||
|
|
||||||
# assertion
|
# assertion
|
||||||
names = set()
|
names = set()
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||||
names.add(lora.lora_name)
|
names.add(lora.lora_name)
|
||||||
|
|
||||||
def set_multiplier(self, multiplier):
|
def set_multiplier(self, multiplier):
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
lora.multiplier = self.multiplier
|
lora.multiplier = self.multiplier
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file, safe_open
|
from safetensors.torch import load_file, safe_open
|
||||||
self.weights_sd = load_file(file)
|
|
||||||
else:
|
|
||||||
self.weights_sd = torch.load(file, map_location='cpu')
|
|
||||||
|
|
||||||
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
self.weights_sd = load_file(file)
|
||||||
if self.weights_sd:
|
else:
|
||||||
weights_has_text_encoder = weights_has_unet = False
|
self.weights_sd = torch.load(file, map_location="cpu")
|
||||||
for key in self.weights_sd.keys():
|
|
||||||
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
|
||||||
weights_has_text_encoder = True
|
|
||||||
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
|
||||||
weights_has_unet = True
|
|
||||||
|
|
||||||
if apply_text_encoder is None:
|
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
||||||
apply_text_encoder = weights_has_text_encoder
|
if self.weights_sd:
|
||||||
else:
|
weights_has_text_encoder = weights_has_unet = False
|
||||||
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
for key in self.weights_sd.keys():
|
||||||
|
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||||
|
weights_has_text_encoder = True
|
||||||
|
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
||||||
|
weights_has_unet = True
|
||||||
|
|
||||||
if apply_unet is None:
|
if apply_text_encoder is None:
|
||||||
apply_unet = weights_has_unet
|
apply_text_encoder = weights_has_text_encoder
|
||||||
else:
|
else:
|
||||||
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
assert (
|
||||||
else:
|
apply_text_encoder == weights_has_text_encoder
|
||||||
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
||||||
|
|
||||||
if apply_text_encoder:
|
if apply_unet is None:
|
||||||
print("enable LoRA for text encoder")
|
apply_unet = weights_has_unet
|
||||||
else:
|
else:
|
||||||
self.text_encoder_loras = []
|
assert (
|
||||||
|
apply_unet == weights_has_unet
|
||||||
|
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
||||||
|
else:
|
||||||
|
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
||||||
|
|
||||||
if apply_unet:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for U-Net")
|
print("enable LoRA for text encoder")
|
||||||
else:
|
else:
|
||||||
self.unet_loras = []
|
self.text_encoder_loras = []
|
||||||
|
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
if apply_unet:
|
||||||
lora.apply_to()
|
print("enable LoRA for U-Net")
|
||||||
self.add_module(lora.lora_name, lora)
|
else:
|
||||||
|
self.unet_loras = []
|
||||||
|
|
||||||
if self.weights_sd:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
lora.apply_to()
|
||||||
info = self.load_state_dict(self.weights_sd, False)
|
self.add_module(lora.lora_name, lora)
|
||||||
print(f"weights are loaded: {info}")
|
|
||||||
|
|
||||||
def enable_gradient_checkpointing(self):
|
if self.weights_sd:
|
||||||
# not supported
|
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
||||||
pass
|
info = self.load_state_dict(self.weights_sd, False)
|
||||||
|
print(f"weights are loaded: {info}")
|
||||||
|
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
# TODO refactor to common function with apply_to
|
||||||
def enumerate_params(loras):
|
def merge_to(self, text_encoder, unet, dtype, device):
|
||||||
params = []
|
assert self.weights_sd is not None, "weights are not loaded"
|
||||||
for lora in loras:
|
|
||||||
params.extend(lora.parameters())
|
|
||||||
return params
|
|
||||||
|
|
||||||
self.requires_grad_(True)
|
apply_text_encoder = apply_unet = False
|
||||||
all_params = []
|
for key in self.weights_sd.keys():
|
||||||
|
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||||
|
apply_text_encoder = True
|
||||||
|
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
||||||
|
apply_unet = True
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if apply_text_encoder:
|
||||||
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
print("enable LoRA for text encoder")
|
||||||
if text_encoder_lr is not None:
|
else:
|
||||||
param_data['lr'] = text_encoder_lr
|
self.text_encoder_loras = []
|
||||||
all_params.append(param_data)
|
|
||||||
|
|
||||||
if self.unet_loras:
|
if apply_unet:
|
||||||
param_data = {'params': enumerate_params(self.unet_loras)}
|
print("enable LoRA for U-Net")
|
||||||
if unet_lr is not None:
|
else:
|
||||||
param_data['lr'] = unet_lr
|
self.unet_loras = []
|
||||||
all_params.append(param_data)
|
|
||||||
|
|
||||||
return all_params
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
sd_for_lora = {}
|
||||||
|
for key in self.weights_sd.keys():
|
||||||
|
if key.startswith(lora.lora_name):
|
||||||
|
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
|
||||||
|
lora.merge_to(sd_for_lora, dtype, device)
|
||||||
|
print(f"weights are merged")
|
||||||
|
|
||||||
def prepare_grad_etc(self, text_encoder, unet):
|
def enable_gradient_checkpointing(self):
|
||||||
self.requires_grad_(True)
|
# not supported
|
||||||
|
pass
|
||||||
|
|
||||||
def on_epoch_start(self, text_encoder, unet):
|
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
||||||
self.train()
|
def enumerate_params(loras):
|
||||||
|
params = []
|
||||||
|
for lora in loras:
|
||||||
|
params.extend(lora.parameters())
|
||||||
|
return params
|
||||||
|
|
||||||
def get_trainable_params(self):
|
self.requires_grad_(True)
|
||||||
return self.parameters()
|
all_params = []
|
||||||
|
|
||||||
def save_weights(self, file, dtype, metadata):
|
if self.text_encoder_loras:
|
||||||
if metadata is not None and len(metadata) == 0:
|
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||||
metadata = None
|
if text_encoder_lr is not None:
|
||||||
|
param_data["lr"] = text_encoder_lr
|
||||||
|
all_params.append(param_data)
|
||||||
|
|
||||||
state_dict = self.state_dict()
|
if self.unet_loras:
|
||||||
|
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||||
|
if unet_lr is not None:
|
||||||
|
param_data["lr"] = unet_lr
|
||||||
|
all_params.append(param_data)
|
||||||
|
|
||||||
if dtype is not None:
|
return all_params
|
||||||
for key in list(state_dict.keys()):
|
|
||||||
v = state_dict[key]
|
|
||||||
v = v.detach().clone().to("cpu").to(dtype)
|
|
||||||
state_dict[key] = v
|
|
||||||
|
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
def prepare_grad_etc(self, text_encoder, unet):
|
||||||
from safetensors.torch import save_file
|
self.requires_grad_(True)
|
||||||
|
|
||||||
# Precalculate model hashes to save time on indexing
|
def on_epoch_start(self, text_encoder, unet):
|
||||||
if metadata is None:
|
self.train()
|
||||||
metadata = {}
|
|
||||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
|
||||||
metadata["sshs_model_hash"] = model_hash
|
|
||||||
metadata["sshs_legacy_hash"] = legacy_hash
|
|
||||||
|
|
||||||
save_file(state_dict, file, metadata)
|
def get_trainable_params(self):
|
||||||
else:
|
return self.parameters()
|
||||||
torch.save(state_dict, file)
|
|
||||||
|
|
||||||
@ staticmethod
|
def save_weights(self, file, dtype, metadata):
|
||||||
def set_regions(networks, image):
|
if metadata is not None and len(metadata) == 0:
|
||||||
image = image.astype(np.float32) / 255.0
|
metadata = None
|
||||||
for i, network in enumerate(networks[:3]):
|
|
||||||
# NOTE: consider averaging overwrapping area
|
|
||||||
region = image[:, :, i]
|
|
||||||
if region.max() == 0:
|
|
||||||
continue
|
|
||||||
region = torch.tensor(region)
|
|
||||||
network.set_region(region)
|
|
||||||
|
|
||||||
def set_region(self, region):
|
state_dict = self.state_dict()
|
||||||
for lora in self.unet_loras:
|
|
||||||
lora.set_region(region)
|
if dtype is not None:
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
v = state_dict[key]
|
||||||
|
v = v.detach().clone().to("cpu").to(dtype)
|
||||||
|
state_dict[key] = v
|
||||||
|
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
# Precalculate model hashes to save time on indexing
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
|
metadata["sshs_model_hash"] = model_hash
|
||||||
|
metadata["sshs_legacy_hash"] = legacy_hash
|
||||||
|
|
||||||
|
save_file(state_dict, file, metadata)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, file)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_regions(networks, image):
|
||||||
|
image = image.astype(np.float32) / 255.0
|
||||||
|
for i, network in enumerate(networks[:3]):
|
||||||
|
# NOTE: consider averaging overwrapping area
|
||||||
|
region = image[:, :, i]
|
||||||
|
if region.max() == 0:
|
||||||
|
continue
|
||||||
|
region = torch.tensor(region)
|
||||||
|
network.set_region(region)
|
||||||
|
|
||||||
|
def set_region(self, region):
|
||||||
|
for lora in self.unet_loras:
|
||||||
|
lora.set_region(region)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
@ -9,216 +8,236 @@ import lora
|
|||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
sd = load_file(file_name)
|
sd = load_file(file_name)
|
||||||
else:
|
else:
|
||||||
sd = torch.load(file_name, map_location='cpu')
|
sd = torch.load(file_name, map_location="cpu")
|
||||||
for key in list(sd.keys()):
|
for key in list(sd.keys()):
|
||||||
if type(sd[key]) == torch.Tensor:
|
if type(sd[key]) == torch.Tensor:
|
||||||
sd[key] = sd[key].to(dtype)
|
sd[key] = sd[key].to(dtype)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, model, state_dict, dtype):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if type(state_dict[key]) == torch.Tensor:
|
if type(state_dict[key]) == torch.Tensor:
|
||||||
state_dict[key] = state_dict[key].to(dtype)
|
state_dict[key] = state_dict[key].to(dtype)
|
||||||
|
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
save_file(model, file_name)
|
save_file(model, file_name)
|
||||||
else:
|
else:
|
||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||||
text_encoder.to(merge_dtype)
|
text_encoder.to(merge_dtype)
|
||||||
unet.to(merge_dtype)
|
unet.to(merge_dtype)
|
||||||
|
|
||||||
# create module map
|
# create module map
|
||||||
name_to_module = {}
|
name_to_module = {}
|
||||||
for i, root_module in enumerate([text_encoder, unet]):
|
for i, root_module in enumerate([text_encoder, unet]):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||||
else:
|
|
||||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
|
||||||
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
|
||||||
|
|
||||||
for name, module in root_module.named_modules():
|
|
||||||
if module.__class__.__name__ in target_replace_modules:
|
|
||||||
for child_name, child_module in module.named_modules():
|
|
||||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
|
||||||
lora_name = prefix + '.' + name + '.' + child_name
|
|
||||||
lora_name = lora_name.replace('.', '_')
|
|
||||||
name_to_module[lora_name] = child_module
|
|
||||||
|
|
||||||
for model, ratio in zip(models, ratios):
|
|
||||||
print(f"loading: {model}")
|
|
||||||
lora_sd = load_state_dict(model, merge_dtype)
|
|
||||||
|
|
||||||
print(f"merging...")
|
|
||||||
for key in lora_sd.keys():
|
|
||||||
if "lora_down" in key:
|
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
|
||||||
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
|
||||||
|
|
||||||
# find original module for this lora
|
|
||||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
|
||||||
if module_name not in name_to_module:
|
|
||||||
print(f"no module found for LoRA weight: {key}")
|
|
||||||
continue
|
|
||||||
module = name_to_module[module_name]
|
|
||||||
# print(f"apply {key} to {module}")
|
|
||||||
|
|
||||||
down_weight = lora_sd[key]
|
|
||||||
up_weight = lora_sd[up_key]
|
|
||||||
|
|
||||||
dim = down_weight.size()[0]
|
|
||||||
alpha = lora_sd.get(alpha_key, dim)
|
|
||||||
scale = alpha / dim
|
|
||||||
|
|
||||||
# W <- W + U * D
|
|
||||||
weight = module.weight
|
|
||||||
# print(module_name, down_weight.size(), up_weight.size())
|
|
||||||
if len(weight.size()) == 2:
|
|
||||||
# linear
|
|
||||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
|
||||||
elif down_weight.size()[2:4] == (1, 1):
|
|
||||||
# conv2d 1x1
|
|
||||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
|
||||||
).unsqueeze(2).unsqueeze(3) * scale
|
|
||||||
else:
|
else:
|
||||||
# conv2d 3x3
|
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
target_replace_modules = (
|
||||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
weight = weight + ratio * conved * scale
|
)
|
||||||
|
|
||||||
module.weight = torch.nn.Parameter(weight)
|
for name, module in root_module.named_modules():
|
||||||
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
|
for child_name, child_module in module.named_modules():
|
||||||
|
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||||
|
lora_name = prefix + "." + name + "." + child_name
|
||||||
|
lora_name = lora_name.replace(".", "_")
|
||||||
|
name_to_module[lora_name] = child_module
|
||||||
|
|
||||||
|
for model, ratio in zip(models, ratios):
|
||||||
|
print(f"loading: {model}")
|
||||||
|
lora_sd = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
|
print(f"merging...")
|
||||||
|
for key in lora_sd.keys():
|
||||||
|
if "lora_down" in key:
|
||||||
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
|
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||||
|
|
||||||
|
# find original module for this lora
|
||||||
|
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||||
|
if module_name not in name_to_module:
|
||||||
|
print(f"no module found for LoRA weight: {key}")
|
||||||
|
continue
|
||||||
|
module = name_to_module[module_name]
|
||||||
|
# print(f"apply {key} to {module}")
|
||||||
|
|
||||||
|
down_weight = lora_sd[key]
|
||||||
|
up_weight = lora_sd[up_key]
|
||||||
|
|
||||||
|
dim = down_weight.size()[0]
|
||||||
|
alpha = lora_sd.get(alpha_key, dim)
|
||||||
|
scale = alpha / dim
|
||||||
|
|
||||||
|
# W <- W + U * D
|
||||||
|
weight = module.weight
|
||||||
|
# print(module_name, down_weight.size(), up_weight.size())
|
||||||
|
if len(weight.size()) == 2:
|
||||||
|
# linear
|
||||||
|
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||||
|
elif down_weight.size()[2:4] == (1, 1):
|
||||||
|
# conv2d 1x1
|
||||||
|
weight = (
|
||||||
|
weight
|
||||||
|
+ ratio
|
||||||
|
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
* scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# conv2d 3x3
|
||||||
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
|
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
|
weight = weight + ratio * conved * scale
|
||||||
|
|
||||||
|
module.weight = torch.nn.Parameter(weight)
|
||||||
|
|
||||||
|
|
||||||
def merge_lora_models(models, ratios, merge_dtype):
|
def merge_lora_models(models, ratios, merge_dtype):
|
||||||
base_alphas = {} # alpha for merged model
|
base_alphas = {} # alpha for merged model
|
||||||
base_dims = {}
|
base_dims = {}
|
||||||
|
|
||||||
merged_sd = {}
|
merged_sd = {}
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
print(f"loading: {model}")
|
||||||
lora_sd = load_state_dict(model, merge_dtype)
|
lora_sd = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
# get alpha and dim
|
# get alpha and dim
|
||||||
alphas = {} # alpha for current model
|
alphas = {} # alpha for current model
|
||||||
dims = {} # dims for current model
|
dims = {} # dims for current model
|
||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if 'alpha' in key:
|
if "alpha" in key:
|
||||||
lora_module_name = key[:key.rfind(".alpha")]
|
lora_module_name = key[: key.rfind(".alpha")]
|
||||||
alpha = float(lora_sd[key].detach().numpy())
|
alpha = float(lora_sd[key].detach().numpy())
|
||||||
alphas[lora_module_name] = alpha
|
alphas[lora_module_name] = alpha
|
||||||
if lora_module_name not in base_alphas:
|
if lora_module_name not in base_alphas:
|
||||||
base_alphas[lora_module_name] = alpha
|
base_alphas[lora_module_name] = alpha
|
||||||
elif "lora_down" in key:
|
elif "lora_down" in key:
|
||||||
lora_module_name = key[:key.rfind(".lora_down")]
|
lora_module_name = key[: key.rfind(".lora_down")]
|
||||||
dim = lora_sd[key].size()[0]
|
dim = lora_sd[key].size()[0]
|
||||||
dims[lora_module_name] = dim
|
dims[lora_module_name] = dim
|
||||||
if lora_module_name not in base_dims:
|
if lora_module_name not in base_dims:
|
||||||
base_dims[lora_module_name] = dim
|
base_dims[lora_module_name] = dim
|
||||||
|
|
||||||
for lora_module_name in dims.keys():
|
for lora_module_name in dims.keys():
|
||||||
if lora_module_name not in alphas:
|
if lora_module_name not in alphas:
|
||||||
alpha = dims[lora_module_name]
|
alpha = dims[lora_module_name]
|
||||||
alphas[lora_module_name] = alpha
|
alphas[lora_module_name] = alpha
|
||||||
if lora_module_name not in base_alphas:
|
if lora_module_name not in base_alphas:
|
||||||
base_alphas[lora_module_name] = alpha
|
base_alphas[lora_module_name] = alpha
|
||||||
|
|
||||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||||
|
|
||||||
# merge
|
# merge
|
||||||
print(f"merging...")
|
print(f"merging...")
|
||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if 'alpha' in key:
|
if "alpha" in key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora_module_name = key[:key.rfind(".lora_")]
|
lora_module_name = key[: key.rfind(".lora_")]
|
||||||
|
|
||||||
base_alpha = base_alphas[lora_module_name]
|
base_alpha = base_alphas[lora_module_name]
|
||||||
alpha = alphas[lora_module_name]
|
alpha = alphas[lora_module_name]
|
||||||
|
|
||||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||||
|
|
||||||
if key in merged_sd:
|
if key in merged_sd:
|
||||||
assert merged_sd[key].size() == lora_sd[key].size(
|
assert (
|
||||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
merged_sd[key].size() == lora_sd[key].size()
|
||||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||||
else:
|
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||||
merged_sd[key] = lora_sd[key] * scale
|
else:
|
||||||
|
merged_sd[key] = lora_sd[key] * scale
|
||||||
|
|
||||||
# set alpha to sd
|
# set alpha to sd
|
||||||
for lora_module_name, alpha in base_alphas.items():
|
for lora_module_name, alpha in base_alphas.items():
|
||||||
key = lora_module_name + ".alpha"
|
key = lora_module_name + ".alpha"
|
||||||
merged_sd[key] = torch.tensor(alpha)
|
merged_sd[key] = torch.tensor(alpha)
|
||||||
|
|
||||||
print("merged model")
|
print("merged model")
|
||||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||||
|
|
||||||
return merged_sd
|
return merged_sd
|
||||||
|
|
||||||
|
|
||||||
def merge(args):
|
def merge(args):
|
||||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == 'float':
|
if p == "float":
|
||||||
return torch.float
|
return torch.float
|
||||||
if p == 'fp16':
|
if p == "fp16":
|
||||||
return torch.float16
|
return torch.float16
|
||||||
if p == 'bf16':
|
if p == "bf16":
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
return None
|
return None
|
||||||
|
|
||||||
merge_dtype = str_to_dtype(args.precision)
|
merge_dtype = str_to_dtype(args.precision)
|
||||||
save_dtype = str_to_dtype(args.save_precision)
|
save_dtype = str_to_dtype(args.save_precision)
|
||||||
if save_dtype is None:
|
if save_dtype is None:
|
||||||
save_dtype = merge_dtype
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
if args.sd_model is not None:
|
if args.sd_model is not None:
|
||||||
print(f"loading SD model: {args.sd_model}")
|
print(f"loading SD model: {args.sd_model}")
|
||||||
|
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||||
|
|
||||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||||
|
|
||||||
print(f"saving SD model to: {args.save_to}")
|
print(f"saving SD model to: {args.save_to}")
|
||||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
|
||||||
args.sd_model, 0, 0, save_dtype, vae)
|
else:
|
||||||
else:
|
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
print(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--v2", action='store_true',
|
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
parser.add_argument(
|
||||||
parser.add_argument("--save_precision", type=str, default=None,
|
"--save_precision",
|
||||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
type=str,
|
||||||
parser.add_argument("--precision", type=str, default="float",
|
default=None,
|
||||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
choices=[None, "float", "fp16", "bf16"],
|
||||||
parser.add_argument("--sd_model", type=str, default=None,
|
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
)
|
||||||
parser.add_argument("--save_to", type=str, default=None,
|
parser.add_argument(
|
||||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
"--precision",
|
||||||
parser.add_argument("--models", type=str, nargs='*',
|
type=str,
|
||||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
default="float",
|
||||||
parser.add_argument("--ratios", type=float, nargs='*',
|
choices=["float", "fp16", "bf16"],
|
||||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sd_model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||||
|
)
|
||||||
|
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
merge(args)
|
merge(args)
|
||||||
|
80
tools/merge_lycoris.py
Normal file
80
tools/merge_lycoris.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from lycoris.utils import merge_loha, merge_locon
|
||||||
|
from lycoris.kohya_model_utils import (
|
||||||
|
load_models_from_stable_diffusion_checkpoint,
|
||||||
|
save_stable_diffusion_checkpoint,
|
||||||
|
load_file
|
||||||
|
)
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight):
|
||||||
|
base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
|
||||||
|
if lycoris_model.rsplit('.', 1)[-1] == 'safetensors':
|
||||||
|
lyco = load_file(lycoris_model)
|
||||||
|
else:
|
||||||
|
lyco = torch.load(lycoris_model)
|
||||||
|
|
||||||
|
algo = None
|
||||||
|
for key in lyco:
|
||||||
|
if 'hada' in key:
|
||||||
|
algo = 'loha'
|
||||||
|
break
|
||||||
|
elif 'lora_up' in key:
|
||||||
|
algo = 'lora'
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Cannot find the algo for this lycoris model file.')
|
||||||
|
|
||||||
|
dtype_str = dtype.replace('fp', 'float').replace('bf', 'bfloat')
|
||||||
|
dtype = {
|
||||||
|
'float': torch.float,
|
||||||
|
'float16': torch.float16,
|
||||||
|
'float32': torch.float32,
|
||||||
|
'float64': torch.float64,
|
||||||
|
'bfloat': torch.bfloat16,
|
||||||
|
'bfloat16': torch.bfloat16,
|
||||||
|
}.get(dtype_str, None)
|
||||||
|
if dtype is None:
|
||||||
|
raise ValueError(f'Cannot Find the dtype "{dtype}"')
|
||||||
|
|
||||||
|
if algo == 'loha':
|
||||||
|
merge_loha(base, lyco, weight, device)
|
||||||
|
elif algo == 'lora':
|
||||||
|
merge_locon(base, lyco, weight, device)
|
||||||
|
|
||||||
|
save_stable_diffusion_checkpoint(
|
||||||
|
is_v2, output_name,
|
||||||
|
base[0], base[2],
|
||||||
|
None, 0, 0, dtype,
|
||||||
|
base[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
return output_name
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
iface = gr.Interface(
|
||||||
|
fn=merge_models,
|
||||||
|
inputs=[
|
||||||
|
gr.inputs.Textbox(label="Base Model Path"),
|
||||||
|
gr.inputs.Textbox(label="Lycoris Model Path"),
|
||||||
|
gr.inputs.Textbox(label="Output Model Path", default='./out.pt'),
|
||||||
|
gr.inputs.Checkbox(label="Is base model SD V2?", default=False),
|
||||||
|
gr.inputs.Textbox(label="Device", default='cpu'),
|
||||||
|
gr.inputs.Dropdown(choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], label="Dtype", default='float'),
|
||||||
|
gr.inputs.Number(label="Weight", default=1.0)
|
||||||
|
],
|
||||||
|
outputs=gr.outputs.Textbox(label="Merged Model Path"),
|
||||||
|
title="Model Merger",
|
||||||
|
description="Merge Lycoris and Stable Diffusion models",
|
||||||
|
)
|
||||||
|
|
||||||
|
iface.launch()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -25,7 +25,7 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
@ -127,12 +127,25 @@ def train(args):
|
|||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
for pi in range(accelerator.state.num_processes):
|
||||||
|
# TODO: modify other training scripts as well
|
||||||
|
if pi == accelerator.state.local_process_index:
|
||||||
|
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
|
|
||||||
|
text_encoder, vae, unet, _ = train_util.load_target_model(
|
||||||
|
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# work on low-ram device
|
||||||
|
if args.lowram:
|
||||||
|
text_encoder.to(accelerator.device)
|
||||||
|
unet.to(accelerator.device)
|
||||||
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# work on low-ram device
|
|
||||||
if args.lowram:
|
|
||||||
text_encoder.to("cuda")
|
|
||||||
unet.to("cuda")
|
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
@ -189,7 +202,7 @@ def train(args):
|
|||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
@ -556,9 +569,9 @@ def train(args):
|
|||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
644
train_textual_inversion_XTI.py
Normal file
644
train_textual_inversion_XTI.py
Normal file
@ -0,0 +1,644 @@
|
|||||||
|
import importlib
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import toml
|
||||||
|
from multiprocessing import Value
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
import diffusers
|
||||||
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
|
import library.train_util as train_util
|
||||||
|
import library.config_util as config_util
|
||||||
|
from library.config_util import (
|
||||||
|
ConfigSanitizer,
|
||||||
|
BlueprintGenerator,
|
||||||
|
)
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
|
||||||
|
imagenet_templates_small = [
|
||||||
|
"a photo of a {}",
|
||||||
|
"a rendering of a {}",
|
||||||
|
"a cropped photo of the {}",
|
||||||
|
"the photo of a {}",
|
||||||
|
"a photo of a clean {}",
|
||||||
|
"a photo of a dirty {}",
|
||||||
|
"a dark photo of the {}",
|
||||||
|
"a photo of my {}",
|
||||||
|
"a photo of the cool {}",
|
||||||
|
"a close-up photo of a {}",
|
||||||
|
"a bright photo of the {}",
|
||||||
|
"a cropped photo of a {}",
|
||||||
|
"a photo of the {}",
|
||||||
|
"a good photo of the {}",
|
||||||
|
"a photo of one {}",
|
||||||
|
"a close-up photo of the {}",
|
||||||
|
"a rendition of the {}",
|
||||||
|
"a photo of the clean {}",
|
||||||
|
"a rendition of a {}",
|
||||||
|
"a photo of a nice {}",
|
||||||
|
"a good photo of a {}",
|
||||||
|
"a photo of the nice {}",
|
||||||
|
"a photo of the small {}",
|
||||||
|
"a photo of the weird {}",
|
||||||
|
"a photo of the large {}",
|
||||||
|
"a photo of a cool {}",
|
||||||
|
"a photo of a small {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_style_templates_small = [
|
||||||
|
"a painting in the style of {}",
|
||||||
|
"a rendering in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"the painting in the style of {}",
|
||||||
|
"a clean painting in the style of {}",
|
||||||
|
"a dirty painting in the style of {}",
|
||||||
|
"a dark painting in the style of {}",
|
||||||
|
"a picture in the style of {}",
|
||||||
|
"a cool painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a bright painting in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"a good painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a rendition in the style of {}",
|
||||||
|
"a nice painting in the style of {}",
|
||||||
|
"a small painting in the style of {}",
|
||||||
|
"a weird painting in the style of {}",
|
||||||
|
"a large painting in the style of {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
if args.output_name is None:
|
||||||
|
args.output_name = args.token_string
|
||||||
|
use_template = args.use_object_template or args.use_style_template
|
||||||
|
|
||||||
|
train_util.verify_training_args(args)
|
||||||
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
|
||||||
|
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||||
|
print(
|
||||||
|
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_latents = args.cache_latents
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
|
# acceleratorを準備する
|
||||||
|
print("prepare accelerator")
|
||||||
|
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
|
# モデルを読み込む
|
||||||
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
||||||
|
|
||||||
|
# Convert the init_word to token_id
|
||||||
|
if args.init_word is not None:
|
||||||
|
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||||
|
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||||
|
print(
|
||||||
|
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
init_token_ids = None
|
||||||
|
|
||||||
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
|
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||||
|
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||||
|
assert (
|
||||||
|
num_added_tokens == args.num_vectors_per_token
|
||||||
|
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||||
|
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
|
print(f"tokens are added: {token_ids}")
|
||||||
|
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||||
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
|
token_strings_XTI = []
|
||||||
|
XTI_layers = [
|
||||||
|
"IN01",
|
||||||
|
"IN02",
|
||||||
|
"IN04",
|
||||||
|
"IN05",
|
||||||
|
"IN07",
|
||||||
|
"IN08",
|
||||||
|
"MID",
|
||||||
|
"OUT03",
|
||||||
|
"OUT04",
|
||||||
|
"OUT05",
|
||||||
|
"OUT06",
|
||||||
|
"OUT07",
|
||||||
|
"OUT08",
|
||||||
|
"OUT09",
|
||||||
|
"OUT10",
|
||||||
|
"OUT11",
|
||||||
|
]
|
||||||
|
for layer_name in XTI_layers:
|
||||||
|
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
|
||||||
|
|
||||||
|
tokenizer.add_tokens(token_strings_XTI)
|
||||||
|
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
|
||||||
|
print(f"tokens are added (XTI): {token_ids_XTI}")
|
||||||
|
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||||
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
|
if init_token_ids is not None:
|
||||||
|
for i, token_id in enumerate(token_ids_XTI):
|
||||||
|
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
|
||||||
|
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
|
|
||||||
|
# load weights
|
||||||
|
if args.weights is not None:
|
||||||
|
embeddings = load_weights(args.weights)
|
||||||
|
assert len(token_ids) == len(
|
||||||
|
embeddings
|
||||||
|
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||||
|
# print(token_ids, embeddings.size())
|
||||||
|
for token_id, embedding in zip(token_ids_XTI, embeddings):
|
||||||
|
token_embeds[token_id] = embedding
|
||||||
|
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
|
print(f"weighs loaded")
|
||||||
|
|
||||||
|
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||||
|
|
||||||
|
# データセットを準備する
|
||||||
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
||||||
|
if args.dataset_config is not None:
|
||||||
|
print(f"Load dataset config from {args.dataset_config}")
|
||||||
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
|
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||||
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
|
print(
|
||||||
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
|
", ".join(ignored)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
use_dreambooth_method = args.in_json is None
|
||||||
|
if use_dreambooth_method:
|
||||||
|
print("Use DreamBooth method.")
|
||||||
|
user_config = {
|
||||||
|
"datasets": [
|
||||||
|
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print("Train with captions.")
|
||||||
|
user_config = {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"subsets": [
|
||||||
|
{
|
||||||
|
"image_dir": args.train_data_dir,
|
||||||
|
"metadata_file": args.in_json,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||||
|
current_epoch = Value("i", 0)
|
||||||
|
current_step = Value("i", 0)
|
||||||
|
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||||
|
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||||
|
|
||||||
|
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||||
|
if use_template:
|
||||||
|
print("use template for training captions. is object: {args.use_object_template}")
|
||||||
|
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
||||||
|
replace_to = " ".join(token_strings)
|
||||||
|
captions = []
|
||||||
|
for tmpl in templates:
|
||||||
|
captions.append(tmpl.format(replace_to))
|
||||||
|
train_dataset_group.add_replacement("", captions)
|
||||||
|
|
||||||
|
if args.num_vectors_per_token > 1:
|
||||||
|
prompt_replacement = (args.token_string, replace_to)
|
||||||
|
else:
|
||||||
|
prompt_replacement = None
|
||||||
|
else:
|
||||||
|
if args.num_vectors_per_token > 1:
|
||||||
|
replace_to = " ".join(token_strings)
|
||||||
|
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||||
|
prompt_replacement = (args.token_string, replace_to)
|
||||||
|
else:
|
||||||
|
prompt_replacement = None
|
||||||
|
|
||||||
|
if args.debug_dataset:
|
||||||
|
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
||||||
|
return
|
||||||
|
if len(train_dataset_group) == 0:
|
||||||
|
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||||
|
return
|
||||||
|
|
||||||
|
if cache_latents:
|
||||||
|
assert (
|
||||||
|
train_dataset_group.is_latent_cacheable()
|
||||||
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
|
|
||||||
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
|
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||||
|
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||||
|
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||||
|
|
||||||
|
# 学習を準備する
|
||||||
|
if cache_latents:
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||||
|
vae.to("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
# 学習に必要なクラスを準備する
|
||||||
|
print("prepare optimizer, data loader etc.")
|
||||||
|
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||||
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
|
# dataloaderを準備する
|
||||||
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
train_dataset_group,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=collater,
|
||||||
|
num_workers=n_workers,
|
||||||
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 学習ステップ数を計算する
|
||||||
|
if args.max_train_epochs is not None:
|
||||||
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
# データセット側にも学習ステップを送信
|
||||||
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
|
|
||||||
|
# lr schedulerを用意する
|
||||||
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
|
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||||
|
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
|
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|
||||||
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
|
text_encoder.requires_grad_(True)
|
||||||
|
text_encoder.text_model.encoder.requires_grad_(False)
|
||||||
|
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||||
|
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
||||||
|
|
||||||
|
unet.requires_grad_(False)
|
||||||
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||||
|
unet.train()
|
||||||
|
else:
|
||||||
|
unet.eval()
|
||||||
|
|
||||||
|
if not cache_latents:
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
|
if args.full_fp16:
|
||||||
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
|
# resumeする
|
||||||
|
if args.resume is not None:
|
||||||
|
print(f"resume training from state: {args.resume}")
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
|
||||||
|
# epoch数を計算する
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
|
# 学習する
|
||||||
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
print("running training / 学習開始")
|
||||||
|
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||||
|
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||||
|
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
|
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
|
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
noise_scheduler = DDPMScheduler(
|
||||||
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
accelerator.init_trackers("textual_inversion")
|
||||||
|
|
||||||
|
for epoch in range(num_train_epochs):
|
||||||
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
text_encoder.train()
|
||||||
|
|
||||||
|
loss_total = 0
|
||||||
|
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
|
with accelerator.accumulate(text_encoder):
|
||||||
|
with torch.no_grad():
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
else:
|
||||||
|
# latentに変換
|
||||||
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
|
latents = latents * 0.18215
|
||||||
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
input_ids = batch["input_ids"].to(accelerator.device)
|
||||||
|
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||||
|
encoder_hidden_states = torch.stack(
|
||||||
|
[
|
||||||
|
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
|
||||||
|
for s in torch.split(input_ids, 1, dim=1)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sample noise that we'll add to the latents
|
||||||
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
|
if args.noise_offset:
|
||||||
|
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
|
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
|
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
|
# (this is the forward diffusion process)
|
||||||
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
||||||
|
|
||||||
|
if args.v_parameterization:
|
||||||
|
# v-parameterization training
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
target = noise
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
if args.min_snr_gamma:
|
||||||
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
|
loss = loss * loss_weights
|
||||||
|
|
||||||
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
accelerator.backward(loss)
|
||||||
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
||||||
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
|
with torch.no_grad():
|
||||||
|
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||||
|
index_no_updates
|
||||||
|
]
|
||||||
|
|
||||||
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
global_step += 1
|
||||||
|
# TODO: fix sample_images
|
||||||
|
# train_util.sample_images(
|
||||||
|
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||||
|
# )
|
||||||
|
|
||||||
|
current_loss = loss.detach().item()
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||||
|
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
||||||
|
logs["lr/d*lr"] = (
|
||||||
|
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||||
|
)
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
loss_total += current_loss
|
||||||
|
avr_loss = loss_total / (step + 1)
|
||||||
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
|
if global_step >= args.max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||||
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||||
|
|
||||||
|
if args.save_every_n_epochs is not None:
|
||||||
|
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||||
|
|
||||||
|
def save_func():
|
||||||
|
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
|
||||||
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
|
||||||
|
def remove_old_func(old_epoch_no):
|
||||||
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
|
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||||
|
if os.path.exists(old_ckpt_file):
|
||||||
|
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||||
|
os.remove(old_ckpt_file)
|
||||||
|
|
||||||
|
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||||
|
if saving and args.save_state:
|
||||||
|
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||||
|
|
||||||
|
# TODO: fix sample_images
|
||||||
|
# train_util.sample_images(
|
||||||
|
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||||
|
# )
|
||||||
|
|
||||||
|
# end of epoch
|
||||||
|
|
||||||
|
is_main_process = accelerator.is_main_process
|
||||||
|
if is_main_process:
|
||||||
|
text_encoder = unwrap_model(text_encoder)
|
||||||
|
|
||||||
|
accelerator.end_training()
|
||||||
|
|
||||||
|
if args.save_state:
|
||||||
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
|
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||||
|
|
||||||
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
|
|
||||||
|
if is_main_process:
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||||
|
ckpt_name = model_name + "." + args.save_model_as
|
||||||
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
|
print(f"save trained model to {ckpt_file}")
|
||||||
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
def save_weights(file, updated_embs, save_dtype):
|
||||||
|
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
|
||||||
|
updated_embs = updated_embs.chunk(16)
|
||||||
|
XTI_layers = [
|
||||||
|
"IN01",
|
||||||
|
"IN02",
|
||||||
|
"IN04",
|
||||||
|
"IN05",
|
||||||
|
"IN07",
|
||||||
|
"IN08",
|
||||||
|
"MID",
|
||||||
|
"OUT03",
|
||||||
|
"OUT04",
|
||||||
|
"OUT05",
|
||||||
|
"OUT06",
|
||||||
|
"OUT07",
|
||||||
|
"OUT08",
|
||||||
|
"OUT09",
|
||||||
|
"OUT10",
|
||||||
|
"OUT11",
|
||||||
|
]
|
||||||
|
state_dict = {}
|
||||||
|
for i, layer_name in enumerate(XTI_layers):
|
||||||
|
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
|
||||||
|
|
||||||
|
# if save_dtype is not None:
|
||||||
|
# for key in list(state_dict.keys()):
|
||||||
|
# v = state_dict[key]
|
||||||
|
# v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
|
# state_dict[key] = v
|
||||||
|
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
save_file(state_dict, file)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, file) # can be loaded in Web UI
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights(file):
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
data = load_file(file)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"NOT XTI: {file}")
|
||||||
|
|
||||||
|
if len(data.values()) != 16:
|
||||||
|
raise ValueError(f"NOT XTI: {file}")
|
||||||
|
|
||||||
|
emb = torch.concat([x for x in data.values()])
|
||||||
|
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
train_util.add_sd_models_arguments(parser)
|
||||||
|
train_util.add_dataset_arguments(parser, True, True, False)
|
||||||
|
train_util.add_training_arguments(parser, True)
|
||||||
|
train_util.add_optimizer_arguments(parser)
|
||||||
|
config_util.add_config_arguments(parser)
|
||||||
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_model_as",
|
||||||
|
type=str,
|
||||||
|
default="pt",
|
||||||
|
choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
|
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token_string",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||||
|
)
|
||||||
|
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_object_template",
|
||||||
|
action="store_true",
|
||||||
|
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_style_template",
|
||||||
|
action="store_true",
|
||||||
|
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = setup_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
|
train(args)
|
Loading…
Reference in New Issue
Block a user