Merge latest sd-script updates

This commit is contained in:
bmaltais 2023-04-01 07:14:25 -04:00
parent d37aa6efad
commit 2eddd64b90
10 changed files with 2607 additions and 1408 deletions

View File

@ -192,8 +192,20 @@ This will store your a backup file with your current locally installed pip packa
## Change History ## Change History
* 2023/04/01 (v21.4.0)
- Fix an issue that `merge_lora.py` does not work with the latest version.
- Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights.
- Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`.
- Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
- Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
- See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
- Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported.
- Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option.
- Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec!
- Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option.
* 2023/04/01 (v21.3.9) * 2023/04/01 (v21.3.9)
- Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496 - Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496
- Fix issue with WD14 caption script by applying a custom fix to kohya_ss code.
* 2023/03/30 (v21.3.8) * 2023/03/30 (v21.3.8)
- Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481 - Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481
* 2023/03/29 (v21.3.7) * 2023/03/29 (v21.3.7)

209
XTI_hijack.py Normal file
View File

@ -0,0 +1,209 @@
import torch
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
def unet_forward_XTI(self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
down_i = 0
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
)
down_i += 2
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
# 5. up
up_i = 7
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
upsample_size=upsample_size,
)
up_i += 3
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)
def downblock_forward_XTI(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
output_states = ()
i = 0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
output_states += (hidden_states,)
i += 1
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
def upblock_forward_XTI(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
):
i = 0
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
i += 1
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states

View File

@ -95,6 +95,8 @@ import library.train_util as train_util
import tools.original_control_net as original_control_net import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo from tools.original_control_net import ControlNetInfo
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14" TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
@ -491,6 +493,9 @@ class PipelineLike:
# Textual Inversion # Textual Inversion
self.token_replacements = {} self.token_replacements = {}
# XTI
self.token_replacements_XTI = {}
# CLIP guidance # CLIP guidance
self.clip_guidance_scale = clip_guidance_scale self.clip_guidance_scale = clip_guidance_scale
self.clip_image_guidance_scale = clip_image_guidance_scale self.clip_image_guidance_scale = clip_image_guidance_scale
@ -514,15 +519,26 @@ class PipelineLike:
def add_token_replacement(self, target_token_id, rep_token_ids): def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids self.token_replacements[target_token_id] = rep_token_ids
def replace_token(self, tokens): def replace_token(self, tokens, layer=None):
new_tokens = [] new_tokens = []
for token in tokens: for token in tokens:
if token in self.token_replacements: if token in self.token_replacements:
new_tokens.extend(self.token_replacements[token]) replacer_ = self.token_replacements[token]
if layer:
replacer = []
for r in replacer_:
if r in self.token_replacements_XTI:
replacer.append(self.token_replacements_XTI[r][layer])
else:
replacer = replacer_
new_tokens.extend(replacer)
else: else:
new_tokens.append(token) new_tokens.append(token)
return new_tokens return new_tokens
def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
self.token_replacements_XTI[target_token_id] = rep_token_ids
def set_control_nets(self, ctrl_nets): def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets self.control_nets = ctrl_nets
@ -744,6 +760,7 @@ class PipelineLike:
" the batch size of `prompt`." " the batch size of `prompt`."
) )
if not self.token_replacements_XTI:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self, pipe=self,
prompt=prompt, prompt=prompt,
@ -763,6 +780,42 @@ class PipelineLike:
**kwargs, **kwargs,
) )
if self.token_replacements_XTI:
text_embeddings_concat = []
for layer in [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
layer=layer,
**kwargs,
)
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
else:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))
text_embeddings = torch.stack(text_embeddings_concat)
else:
if do_classifier_free_guidance: if do_classifier_free_guidance:
if negative_scale is None: if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
@ -1675,7 +1728,7 @@ def parse_prompt_attention(text):
return res return res
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None):
r""" r"""
Tokenize a list of prompts and return its tokens with weights of each token. Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included. No padding, starting or ending token is included.
@ -1691,7 +1744,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
# tokenize and discard the starting and the ending token # tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1] token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token) token = pipe.replace_token(token, layer=layer)
text_token += token text_token += token
# copy the weight by length of token # copy the weight by length of token
@ -1807,6 +1860,7 @@ def get_weighted_text_embeddings(
skip_parsing: Optional[bool] = False, skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False, skip_weighting: Optional[bool] = False,
clip_skip=None, clip_skip=None,
layer=None,
**kwargs, **kwargs,
): ):
r""" r"""
@ -1837,11 +1891,11 @@ def get_weighted_text_embeddings(
prompt = [prompt] prompt = [prompt]
if not skip_parsing: if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
if uncond_prompt is not None: if uncond_prompt is not None:
if isinstance(uncond_prompt, str): if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt] uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer)
else: else:
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens] prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
@ -2229,6 +2283,7 @@ def main(args):
if network is None: if network is None:
return return
if not args.network_merge:
network.apply_to(text_encoder, unet) network.apply_to(text_encoder, unet)
if args.opt_channels_last: if args.opt_channels_last:
@ -2236,6 +2291,9 @@ def main(args):
network.to(dtype).to(device) network.to(dtype).to(device)
networks.append(network) networks.append(network)
else:
network.merge_to(text_encoder, unet, dtype, device)
else: else:
networks = [] networks = []
@ -2289,6 +2347,11 @@ def main(args):
if args.diffusers_xformers: if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# Textual Inversionを処理する # Textual Inversionを処理する
if args.textual_inversion_embeddings: if args.textual_inversion_embeddings:
token_ids_embeds = [] token_ids_embeds = []
@ -2335,6 +2398,71 @@ def main(args):
for token_id, embed in zip(token_ids, embeds): for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed token_embeds[token_id] = embed
if args.XTI_embeddings:
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
token_ids_embeds_XTI = []
for embeds_file in args.XTI_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
if set(data.keys()) != set(XTI_layers):
raise ValueError("NOT XTI")
embeds = torch.concat(list(data.values()))
num_vectors_per_token = data["MID"].size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == num_vectors_per_token
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
# if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
token_strings_XTI = []
for layer_name in XTI_layers:
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
token_ids_embeds_XTI.append((token_ids_XTI, embeds))
for t in token_ids:
t_XTI_dic = {}
for i, layer_name in enumerate(XTI_layers):
t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens
pipe.add_token_replacement_XTI(t, t_XTI_dic)
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds_XTI:
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
# promptを取得する # promptを取得する
if args.from_file is not None: if args.from_file is not None:
print(f"reading prompts from {args.from_file}") print(f"reading prompts from {args.from_file}")
@ -2983,6 +3111,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
) )
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument( parser.add_argument(
"--textual_inversion_embeddings", "--textual_inversion_embeddings",
type=str, type=str,
@ -2990,6 +3119,13 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*", nargs="*",
help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
) )
parser.add_argument(
"--XTI_embeddings",
type=str,
default=None,
nargs="*",
help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
)
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
parser.add_argument( parser.add_argument(
"--max_embeddings_multiples", "--max_embeddings_multiples",

View File

@ -247,53 +247,42 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
# Retrieves the keys for the input blocks only # Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = { input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
for layer_id in range(num_input_blocks)
} }
# Retrieves the keys for the middle blocks only # Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = { middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
for layer_id in range(num_middle_blocks)
} }
# Retrieves the keys for the output blocks only # Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = { output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
for layer_id in range(num_output_blocks)
} }
for i in range(1, num_input_blocks): for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1) block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict: if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight" f"input_blocks.{i}.0.op.weight"
) )
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions): if len(attentions):
paths = renew_attention_paths(attentions) paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0] resnet_0 = middle_blocks[0]
attentions = middle_blocks[1] attentions = middle_blocks[1]
@ -307,9 +296,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
attentions_paths = renew_attention_paths(attentions) attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint( assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks): for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1) block_id = i // (config["layers_per_block"] + 1)
@ -332,9 +319,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
# オリジナル: # オリジナル:
# if ["conv.weight", "conv.bias"] in output_block_list.values(): # if ["conv.weight", "conv.bias"] in output_block_list.values():
@ -363,9 +348,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
"old": f"output_blocks.{i}.1", "old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
} }
assign_to_checkpoint( assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else: else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths: for path in resnet_0_paths:
@ -416,15 +399,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
# Retrieves the keys for the encoder down blocks only # Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = { down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only # Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = { up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks): for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
@ -458,9 +437,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
for i in range(num_up_blocks): for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i block_id = num_up_blocks - 1 - i
resnets = [ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
@ -556,7 +533,7 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
text_model_dict = {} text_model_dict = {}
for key in keys: for key in keys:
if key.startswith("cond_stage_model.transformer"): if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
return text_model_dict return text_model_dict
@ -578,21 +555,21 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
elif ".mlp." in key: elif ".mlp." in key:
key = key.replace(".c_fc.", ".fc1.") key = key.replace(".c_fc.", ".fc1.")
key = key.replace(".c_proj.", ".fc2.") key = key.replace(".c_proj.", ".fc2.")
elif '.attn.out_proj' in key: elif ".attn.out_proj" in key:
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
elif '.attn.in_proj' in key: elif ".attn.in_proj" in key:
key = None # 特殊なので後で処理する key = None # 特殊なので後で処理する
else: else:
raise ValueError(f"unexpected key in SD: {key}") raise ValueError(f"unexpected key in SD: {key}")
elif '.positional_embedding' in key: elif ".positional_embedding" in key:
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
elif '.text_projection' in key: elif ".text_projection" in key:
key = None # 使われない??? key = None # 使われない???
elif '.logit_scale' in key: elif ".logit_scale" in key:
key = None # 使われない??? key = None # 使われない???
elif '.token_embedding' in key: elif ".token_embedding" in key:
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
elif '.ln_final' in key: elif ".ln_final" in key:
key = key.replace(".ln_final", ".final_layer_norm") key = key.replace(".ln_final", ".final_layer_norm")
return key return key
@ -600,7 +577,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
new_sd = {} new_sd = {}
for key in keys: for key in keys:
# remove resblocks 23 # remove resblocks 23
if '.resblocks.23.' in key: if ".resblocks.23." in key:
continue continue
new_key = convert_key(key) new_key = convert_key(key)
if new_key is None: if new_key is None:
@ -609,9 +586,9 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
# attnの変換 # attnの変換
for key in keys: for key in keys:
if '.resblocks.23.' in key: if ".resblocks.23." in key:
continue continue
if '.resblocks' in key and '.attn.in_proj_' in key: if ".resblocks" in key and ".attn.in_proj_" in key:
# 三つに分割 # 三つに分割
values = torch.chunk(checkpoint[key], 3) values = torch.chunk(checkpoint[key], 3)
@ -636,12 +613,14 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
new_sd["text_model.embeddings.position_ids"] = position_ids new_sd["text_model.embeddings.position_ids"] = position_ids
return new_sd return new_sd
# endregion # endregion
# region Diffusers->StableDiffusion の変換コード # region Diffusers->StableDiffusion の変換コード
# convert_diffusers_to_original_stable_diffusion をコピーして修正しているASL 2.0 # convert_diffusers_to_original_stable_diffusion をコピーして修正しているASL 2.0
def conv_transformer_to_linear(checkpoint): def conv_transformer_to_linear(checkpoint):
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
tf_keys = ["proj_in.weight", "proj_out.weight"] tf_keys = ["proj_in.weight", "proj_out.weight"]
@ -751,6 +730,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
# VAE Conversion # # VAE Conversion #
# ================# # ================#
def reshape_weight_for_sd(w): def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights # convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1) return w.reshape(*w.shape, 1, 1)
@ -827,23 +807,24 @@ def convert_vae_state_dict(vae_state_dict):
# region 自作のモデル読み書きなど # region 自作のモデル読み書きなど
def is_safetensors(path): def is_safetensors(path):
return os.path.splitext(path)[1].lower() == '.safetensors' return os.path.splitext(path)[1].lower() == ".safetensors"
def load_checkpoint_with_text_encoder_conversion(ckpt_path): def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# text encoderの格納形式が違うモデルに対応する ('text_model'がない) # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
TEXT_ENCODER_KEY_REPLACEMENTS = [ TEXT_ENCODER_KEY_REPLACEMENTS = [
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
] ]
if is_safetensors(ckpt_path): if is_safetensors(ckpt_path):
checkpoint = None checkpoint = None
state_dict = load_file(ckpt_path, "cpu") state_dict = load_file(ckpt_path) # , device) # may causes error
else: else:
checkpoint = torch.load(ckpt_path, map_location="cpu") checkpoint = torch.load(ckpt_path, map_location=device)
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"] state_dict = checkpoint["state_dict"]
else: else:
@ -854,7 +835,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
for key in state_dict.keys(): for key in state_dict.keys():
if key.startswith(rep_from): if key.startswith(rep_from):
new_key = rep_to + key[len(rep_from):] new_key = rep_to + key[len(rep_from) :]
key_reps.append((key, new_key)) key_reps.append((key, new_key))
for key, new_key in key_reps: for key, new_key in key_reps:
@ -865,18 +846,14 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
if dtype is not None:
for k, v in state_dict.items():
if type(v) is torch.Tensor:
state_dict[k] = v.to(dtype)
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(v2) unet_config = create_unet_diffusers_config(v2)
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
unet = UNet2DConditionModel(**unet_config) unet = UNet2DConditionModel(**unet_config).to(device)
info = unet.load_state_dict(converted_unet_checkpoint) info = unet.load_state_dict(converted_unet_checkpoint)
print("loading u-net:", info) print("loading u-net:", info)
@ -884,7 +861,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
vae_config = create_vae_diffusers_config() vae_config = create_vae_diffusers_config()
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config).to(device)
info = vae.load_state_dict(converted_vae_checkpoint) info = vae.load_state_dict(converted_vae_checkpoint)
print("loading vae:", info) print("loading vae:", info)
@ -918,7 +895,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
logging.set_verbosity_error() # don't show annoying warning logging.set_verbosity_error() # don't show annoying warning
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
logging.set_verbosity_warning() logging.set_verbosity_warning()
info = text_model.load_state_dict(converted_text_encoder_checkpoint) info = text_model.load_state_dict(converted_text_encoder_checkpoint)
@ -944,17 +921,17 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
elif ".mlp." in key: elif ".mlp." in key:
key = key.replace(".fc1.", ".c_fc.") key = key.replace(".fc1.", ".c_fc.")
key = key.replace(".fc2.", ".c_proj.") key = key.replace(".fc2.", ".c_proj.")
elif '.self_attn.out_proj' in key: elif ".self_attn.out_proj" in key:
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
elif '.self_attn.' in key: elif ".self_attn." in key:
key = None # 特殊なので後で処理する key = None # 特殊なので後で処理する
else: else:
raise ValueError(f"unexpected key in DiffUsers model: {key}") raise ValueError(f"unexpected key in DiffUsers model: {key}")
elif '.position_embedding' in key: elif ".position_embedding" in key:
key = key.replace("embeddings.position_embedding.weight", "positional_embedding") key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif '.token_embedding' in key: elif ".token_embedding" in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif 'final_layer_norm' in key: elif "final_layer_norm" in key:
key = key.replace("final_layer_norm", "ln_final") key = key.replace("final_layer_norm", "ln_final")
return key return key
@ -968,7 +945,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
# attnの変換 # attnの変換
for key in keys: for key in keys:
if 'layers' in key and 'q_proj' in key: if "layers" in key and "q_proj" in key:
# 三つを結合 # 三つを結合
key_q = key key_q = key
key_k = key.replace("q_proj", "k_proj") key_k = key.replace("q_proj", "k_proj")
@ -992,8 +969,8 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
# Diffusersに含まれない重みを作っておく # Diffusersに含まれない重みを作っておく
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
new_sd['logit_scale'] = torch.tensor(1) new_sd["logit_scale"] = torch.tensor(1)
return new_sd return new_sd
@ -1044,19 +1021,19 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
# Put together new checkpoint # Put together new checkpoint
key_count = len(state_dict.keys()) key_count = len(state_dict.keys())
new_ckpt = {'state_dict': state_dict} new_ckpt = {"state_dict": state_dict}
# epoch and global_step are sometimes not int # epoch and global_step are sometimes not int
try: try:
if 'epoch' in checkpoint: if "epoch" in checkpoint:
epochs += checkpoint['epoch'] epochs += checkpoint["epoch"]
if 'global_step' in checkpoint: if "global_step" in checkpoint:
steps += checkpoint['global_step'] steps += checkpoint["global_step"]
except: except:
pass pass
new_ckpt['epoch'] = epochs new_ckpt["epoch"] = epochs
new_ckpt['global_step'] = steps new_ckpt["global_step"] = steps
if is_safetensors(output_file): if is_safetensors(output_file):
# TODO Tensor以外のdictの値を削除したほうがいいか # TODO Tensor以外のdictの値を削除したほうがいいか
@ -1116,9 +1093,8 @@ def load_vae(vae_id, dtype):
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
else: else:
# StableDiffusion # StableDiffusion
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
else torch.load(vae_id, map_location="cpu")) vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
# vae only or full model # vae only or full model
full_model = False full_model = False
@ -1140,6 +1116,7 @@ def load_vae(vae_id, dtype):
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
return vae return vae
# endregion # endregion
@ -1174,7 +1151,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
return resos return resos
if __name__ == '__main__': if __name__ == "__main__":
resos = make_bucket_resolutions((512, 768)) resos = make_bucket_resolutions((512, 768))
print(len(resos)) print(len(resos))
print(resos) print(resos)

View File

@ -404,6 +404,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.token_padding_disabled = False self.token_padding_disabled = False
self.tag_frequency = {} self.tag_frequency = {}
self.XTI_layers = None
self.token_strings = None
self.enable_bucket = False self.enable_bucket = False
self.bucket_manager: BucketManager = None # not initialized self.bucket_manager: BucketManager = None # not initialized
@ -464,6 +466,10 @@ class BaseDataset(torch.utils.data.Dataset):
def disable_token_padding(self): def disable_token_padding(self):
self.token_padding_disabled = True self.token_padding_disabled = True
def enable_XTI(self, layers=None, token_strings=None):
self.XTI_layers = layers
self.token_strings = token_strings
def add_replacement(self, str_from, str_to): def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to self.replacements[str_from] = str_to
@ -909,9 +915,22 @@ class BaseDataset(torch.utils.data.Dataset):
latents_list.append(latents) latents_list.append(latents)
caption = self.process_caption(subset, image_info.caption) caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption) captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption)) if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
else:
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)
example = {} example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights) example["loss_weights"] = torch.FloatTensor(loss_weights)
@ -1314,6 +1333,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets: # for dataset in self.datasets:
# dataset.make_buckets() # dataset.make_buckets()
def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1): def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
@ -2617,14 +2640,15 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype): def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
name_or_path = args.pretrained_model_name_or_path name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format: if load_stable_diffusion_format:
print("load StableDiffusion checkpoint") print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
else: else:
# Diffusers model is loaded to CPU
print("load Diffusers pretrained models") print("load Diffusers pretrained models")
try: try:
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)

View File

@ -18,11 +18,11 @@ class LoRAModule(torch.nn.Module):
""" """
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
""" if alpha == 0 or None, alpha is rank (no scaling). """ """if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__() super().__init__()
self.lora_name = lora_name self.lora_name = lora_name
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels in_dim = org_module.in_channels
out_dim = org_module.out_channels out_dim = org_module.out_channels
else: else:
@ -36,7 +36,7 @@ class LoRAModule(torch.nn.Module):
# else: # else:
self.lora_dim = lora_dim self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size kernel_size = org_module.kernel_size
stride = org_module.stride stride = org_module.stride
padding = org_module.padding padding = org_module.padding
@ -50,7 +50,7 @@ class LoRAModule(torch.nn.Module):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
# same as microsoft's # same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
@ -66,6 +66,37 @@ class LoRAModule(torch.nn.Module):
self.org_module.forward = self.forward self.org_module.forward = self.forward
del self.org_module del self.org_module
def merge_to(self, sd, dtype, device):
# get up/down weight
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"].to(torch.float)
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
def set_region(self, region): def set_region(self, region):
self.region = region self.region = region
self.region_mask = None self.region_mask = None
@ -87,7 +118,7 @@ class LoRAModule(torch.nn.Module):
else: else:
seq_len = x.size()[1] seq_len = x.size()[1]
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len) ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
h = int(self.region.size()[0] / ratio + .5) h = int(self.region.size()[0] / ratio + 0.5)
w = seq_len // h w = seq_len // h
r = self.region.to(x.device) r = self.region.to(x.device)
@ -95,7 +126,7 @@ class LoRAModule(torch.nn.Module):
r = r.to(torch.float) r = r.to(torch.float)
r = r.unsqueeze(0).unsqueeze(1) r = r.unsqueeze(0).unsqueeze(1)
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w) # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear') r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
r = r.to(x.dtype) r = r.to(x.dtype)
if len(x.size()) == 3: if len(x.size()) == 3:
@ -111,8 +142,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
network_dim = 4 # default network_dim = 4 # default
# extract dim/alpha for conv2d, and block dim # extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get('conv_dim', None) conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get('conv_alpha', None) conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None: if conv_dim is not None:
conv_dim = int(conv_dim) conv_dim = int(conv_dim)
if conv_alpha is None: if conv_alpha is None:
@ -148,30 +179,38 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
""" """
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, network = LoRANetwork(
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha) text_encoder,
unet,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
)
return network return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
if weights_sd is None: if weights_sd is None:
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open from safetensors.torch import load_file, safe_open
weights_sd = load_file(file) weights_sd = load_file(file)
else: else:
weights_sd = torch.load(file, map_location='cpu') weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping # get dim/alpha mapping
modules_dim = {} modules_dim = {}
modules_alpha = {} modules_alpha = {}
for key, value in weights_sd.items(): for key, value in weights_sd.items():
if '.' not in key: if "." not in key:
continue continue
lora_name = key.split('.')[0] lora_name = key.split(".")[0]
if 'alpha' in key: if "alpha" in key:
modules_alpha[lora_name] = value modules_alpha[lora_name] = value
elif 'lora_down' in key: elif "lora_down" in key:
dim = value.size()[0] dim = value.size()[0]
modules_dim[lora_name] = dim modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim) # print(lora_name, value.size(), dim)
@ -191,10 +230,21 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = "lora_te"
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None: def __init__(
self,
text_encoder,
unet,
multiplier=1.0,
lora_dim=4,
alpha=1,
conv_lora_dim=None,
conv_alpha=None,
modules_dim=None,
modules_alpha=None,
) -> None:
super().__init__() super().__init__()
self.multiplier = multiplier self.multiplier = multiplier
@ -225,8 +275,8 @@ class LoRANetwork(torch.nn.Module):
is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d: if is_linear or is_conv2d:
lora_name = prefix + '.' + name + '.' + child_name lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace('.', '_') lora_name = lora_name.replace(".", "_")
if modules_dim is not None: if modules_dim is not None:
if lora_name not in modules_dim: if lora_name not in modules_dim:
@ -247,8 +297,9 @@ class LoRANetwork(torch.nn.Module):
loras.append(lora) loras.append(lora)
return loras return loras
self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, self.text_encoder_loras = create_modules(
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@ -273,11 +324,12 @@ class LoRANetwork(torch.nn.Module):
lora.multiplier = self.multiplier lora.multiplier = self.multiplier
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open from safetensors.torch import load_file, safe_open
self.weights_sd = load_file(file) self.weights_sd = load_file(file)
else: else:
self.weights_sd = torch.load(file, map_location='cpu') self.weights_sd = torch.load(file, map_location="cpu")
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
if self.weights_sd: if self.weights_sd:
@ -291,12 +343,16 @@ class LoRANetwork(torch.nn.Module):
if apply_text_encoder is None: if apply_text_encoder is None:
apply_text_encoder = weights_has_text_encoder apply_text_encoder = weights_has_text_encoder
else: else:
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" assert (
apply_text_encoder == weights_has_text_encoder
), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
if apply_unet is None: if apply_unet is None:
apply_unet = weights_has_unet apply_unet = weights_has_unet
else: else:
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" assert (
apply_unet == weights_has_unet
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
else: else:
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
@ -319,6 +375,35 @@ class LoRANetwork(torch.nn.Module):
info = self.load_state_dict(self.weights_sd, False) info = self.load_state_dict(self.weights_sd, False)
print(f"weights are loaded: {info}") print(f"weights are loaded: {info}")
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, dtype, device):
assert self.weights_sd is not None, "weights are not loaded"
apply_text_encoder = apply_unet = False
for key in self.weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in self.weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
def enable_gradient_checkpointing(self): def enable_gradient_checkpointing(self):
# not supported # not supported
pass pass
@ -334,15 +419,15 @@ class LoRANetwork(torch.nn.Module):
all_params = [] all_params = []
if self.text_encoder_loras: if self.text_encoder_loras:
param_data = {'params': enumerate_params(self.text_encoder_loras)} param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None: if text_encoder_lr is not None:
param_data['lr'] = text_encoder_lr param_data["lr"] = text_encoder_lr
all_params.append(param_data) all_params.append(param_data)
if self.unet_loras: if self.unet_loras:
param_data = {'params': enumerate_params(self.unet_loras)} param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None: if unet_lr is not None:
param_data['lr'] = unet_lr param_data["lr"] = unet_lr
all_params.append(param_data) all_params.append(param_data)
return all_params return all_params
@ -368,7 +453,7 @@ class LoRANetwork(torch.nn.Module):
v = v.detach().clone().to("cpu").to(dtype) v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v state_dict[key] = v
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file from safetensors.torch import save_file
# Precalculate model hashes to save time on indexing # Precalculate model hashes to save time on indexing
@ -382,7 +467,7 @@ class LoRANetwork(torch.nn.Module):
else: else:
torch.save(state_dict, file) torch.save(state_dict, file)
@ staticmethod @staticmethod
def set_regions(networks, image): def set_regions(networks, image):
image = image.astype(np.float32) / 255.0 image = image.astype(np.float32) / 255.0
for i, network in enumerate(networks[:3]): for i, network in enumerate(networks[:3]):

View File

@ -1,4 +1,3 @@
import math import math
import argparse import argparse
import os import os
@ -9,10 +8,10 @@ import lora
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name) sd = load_file(file_name)
else: else:
sd = torch.load(file_name, map_location='cpu') sd = torch.load(file_name, map_location="cpu")
for key in list(sd.keys()): for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor: if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype) sd[key] = sd[key].to(dtype)
@ -25,7 +24,7 @@ def save_to_file(file_name, model, state_dict, dtype):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name) save_file(model, file_name)
else: else:
torch.save(model, file_name) torch.save(model, file_name)
@ -43,14 +42,16 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else: else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
for name, module in root_module.named_modules(): for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules: if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules(): for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + '.' + name + '.' + child_name lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace('.', '_') lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
@ -61,10 +62,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
for key in lora_sd.keys(): for key in lora_sd.keys():
if "lora_down" in key: if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up") up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha' alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora # find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module: if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}") print(f"no module found for LoRA weight: {key}")
continue continue
@ -86,8 +87,12 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
weight = weight + ratio * (up_weight @ down_weight) * scale weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1): elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1 # conv2d 1x1
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) weight = (
).unsqueeze(2).unsqueeze(3) * scale weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else: else:
# conv2d 3x3 # conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
@ -110,14 +115,14 @@ def merge_lora_models(models, ratios, merge_dtype):
alphas = {} # alpha for current model alphas = {} # alpha for current model
dims = {} # dims for current model dims = {} # dims for current model
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if "alpha" in key:
lora_module_name = key[:key.rfind(".alpha")] lora_module_name = key[: key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy()) alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas: if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha base_alphas[lora_module_name] = alpha
elif "lora_down" in key: elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")] lora_module_name = key[: key.rfind(".lora_down")]
dim = lora_sd[key].size()[0] dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim dims[lora_module_name] = dim
if lora_module_name not in base_dims: if lora_module_name not in base_dims:
@ -135,10 +140,10 @@ def merge_lora_models(models, ratios, merge_dtype):
# merge # merge
print(f"merging...") print(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if "alpha" in key:
continue continue
lora_module_name = key[:key.rfind(".lora_")] lora_module_name = key[: key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name] base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name] alpha = alphas[lora_module_name]
@ -146,7 +151,8 @@ def merge_lora_models(models, ratios, merge_dtype):
scale = math.sqrt(alpha / base_alpha) * ratio scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd: if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size( assert (
merged_sd[key].size() == lora_sd[key].size()
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else: else:
@ -167,11 +173,11 @@ def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p): def str_to_dtype(p):
if p == 'float': if p == "float":
return torch.float return torch.float
if p == 'fp16': if p == "fp16":
return torch.float16 return torch.float16
if p == 'bf16': if p == "bf16":
return torch.bfloat16 return torch.bfloat16
return None return None
@ -188,8 +194,7 @@ def merge(args):
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}") print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
args.sd_model, 0, 0, save_dtype, vae)
else: else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
@ -199,25 +204,39 @@ def merge(args):
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') parser.add_argument(
parser.add_argument("--save_precision", type=str, default=None, "--save_precision",
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") type=str,
parser.add_argument("--precision", type=str, default="float", default=None,
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨") choices=[None, "float", "fp16", "bf16"],
parser.add_argument("--sd_model", type=str, default=None, help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") )
parser.add_argument("--save_to", type=str, default=None, parser.add_argument(
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") "--precision",
parser.add_argument("--models", type=str, nargs='*', type=str,
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") default="float",
parser.add_argument("--ratios", type=float, nargs='*', choices=["float", "fp16", "bf16"],
help="ratios for each model / それぞれのLoRAモデルの比率") help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--sd_model",
type=str,
default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
return parser return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()

80
tools/merge_lycoris.py Normal file
View File

@ -0,0 +1,80 @@
import os
import sys
import argparse
import torch
from lycoris.utils import merge_loha, merge_locon
from lycoris.kohya_model_utils import (
load_models_from_stable_diffusion_checkpoint,
save_stable_diffusion_checkpoint,
load_file
)
import gradio as gr
def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight):
base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
if lycoris_model.rsplit('.', 1)[-1] == 'safetensors':
lyco = load_file(lycoris_model)
else:
lyco = torch.load(lycoris_model)
algo = None
for key in lyco:
if 'hada' in key:
algo = 'loha'
break
elif 'lora_up' in key:
algo = 'lora'
break
else:
raise NotImplementedError('Cannot find the algo for this lycoris model file.')
dtype_str = dtype.replace('fp', 'float').replace('bf', 'bfloat')
dtype = {
'float': torch.float,
'float16': torch.float16,
'float32': torch.float32,
'float64': torch.float64,
'bfloat': torch.bfloat16,
'bfloat16': torch.bfloat16,
}.get(dtype_str, None)
if dtype is None:
raise ValueError(f'Cannot Find the dtype "{dtype}"')
if algo == 'loha':
merge_loha(base, lyco, weight, device)
elif algo == 'lora':
merge_locon(base, lyco, weight, device)
save_stable_diffusion_checkpoint(
is_v2, output_name,
base[0], base[2],
None, 0, 0, dtype,
base[1]
)
return output_name
def main():
iface = gr.Interface(
fn=merge_models,
inputs=[
gr.inputs.Textbox(label="Base Model Path"),
gr.inputs.Textbox(label="Lycoris Model Path"),
gr.inputs.Textbox(label="Output Model Path", default='./out.pt'),
gr.inputs.Checkbox(label="Is base model SD V2?", default=False),
gr.inputs.Textbox(label="Device", default='cpu'),
gr.inputs.Dropdown(choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], label="Dtype", default='float'),
gr.inputs.Number(label="Weight", default=1.0)
],
outputs=gr.outputs.Textbox(label="Merged Model Path"),
title="Model Merger",
description="Merge Lycoris and Stable Diffusion models",
)
iface.launch()
if __name__ == '__main__':
main()

View File

@ -127,12 +127,25 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む # モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) for pi in range(accelerator.state.num_processes):
# TODO: modify other training scripts as well
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
)
# work on low-ram device # work on low-ram device
if args.lowram: if args.lowram:
text_encoder.to("cuda") text_encoder.to(accelerator.device)
unet.to("cuda") unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

View File

@ -0,0 +1,644 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
print(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# Convert the init_word to token_id
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == args.num_vectors_per_token
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
token_strings_XTI = []
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
for layer_name in XTI_layers:
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
print(f"tokens are added (XTI): {token_ids_XTI}")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_ids is not None:
for i, token_id in enumerate(token_ids_XTI):
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids_XTI, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print("use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
for tmpl in templates:
captions.append(tmpl.format(replace_to))
train_dataset_group.add_replacement("", captions)
if args.num_vectors_per_token > 1:
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
else:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
else:
unet.eval()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = torch.stack(
[
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
for s in torch.split(input_ids, 1, dim=1)
]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )
# end of epoch
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
print("model saved.")
def save_weights(file, updated_embs, save_dtype):
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
updated_embs = updated_embs.chunk(16)
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
state_dict = {}
for i, layer_name in enumerate(XTI_layers):
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
# if save_dtype is not None:
# for key in list(state_dict.keys()):
# v = state_dict[key]
# v = v.detach().clone().to("cpu").to(save_dtype)
# state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file) # can be loaded in Web UI
def load_weights(file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
raise ValueError(f"NOT XTI: {file}")
if len(data.values()) != 16:
raise ValueError(f"NOT XTI: {file}")
emb = torch.concat([x for x in data.values()])
return emb
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="pt",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
parser.add_argument(
"--token_string",
type=str,
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",
action="store_true",
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
)
parser.add_argument(
"--use_style_template",
action="store_true",
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)