From 2ca17f69dd0fbb8aea03b02c9b3ff9929025aba8 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 22 Jan 2023 10:18:00 -0500 Subject: [PATCH] v20.4.0: Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab. --- README.md | 18 ++++++++- gen_img_diffusers.py | 19 ++++----- library/train_util.py | 40 +++++++++++++++++-- lora_gui.py | 28 ++++++++++++- networks/check_lora_weights.py | 7 ++-- networks/extract_lora_from_models.py | 24 ++++++----- networks/lora.py | 53 ++++++++++++++++++++----- networks/merge_lora.py | 40 ++++++++++++++----- train_network.py | 59 ++++++++++++++++++++++------ train_network_README-ja.md | 2 +- 10 files changed, 227 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index e6f9d37..c074b2f 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,23 @@ Once you have created the LoRA network you can generate images via auto1111 by i ## Change history -* 2023/01/16 (v20.3.0) +* 2023/01/22 (v20.4.0): + - Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab. + - Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe! + - Details of the issue are described in https://github.com/kohya-ss/sd-webui-additional-networks/issues/49 . + - The default value is ``1``, scale ``1 / rank (or dimension)``. Set same value as ``network_dim`` for same behavior to old version. + - LoRA with a large dimension (rank) seems to require a higher learning rate with ``alpha=1`` (e.g. 1e-3 for 128-dim, still investigating).  + - For generating images in Web UI, __the latest version of the extension ``sd-webui-additional-networks`` (v0.3.0 or later) is required for the models trained with this release or later.__ + - Add logging for the learning rate for U-Net and Text Encoder independently, and for running average epoch loss. Thanks to mgz-dev! + - Add more metadata such as dataset/reg image dirs, session ID, output name etc... See https://github.com/kohya-ss/sd-scripts/pull/77 for details. Thanks to space-nuko! + - __Now the metadata includes the folder name (the basename of the folder contains image files, not fullpath).__ If you do not want it, disable metadata storing with ``--no_metadata`` option. + - Add ``--training_comment`` option. You can specify an arbitrary string and refer to it by the extension. + +It seems that the Stable Diffusion web UI now supports image generation using the LoRA model learned in this repository. + +Note: At this time, it appears that models learned with version 0.4.0 are not supported. If you want to use the generation function of the web UI, please continue to use version 0.3.2. Also, it seems that LoRA models for SD2.x are not supported. + +* 2023/01/16 (v20.3.0): - Fix a part of LoRA modules are not trained when ``gradient_checkpointing`` is enabled. - Add ``--save_last_n_epochs_state`` option. You can specify how many state folders to keep, apart from how many models to keep. Thanks to shirayu! - Fix Text Encoder training stops at ``max_train_steps`` even if ``max_train_epochs`` is set in `train_db.py``. diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 4edfe0b..19c63ac 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1981,7 +1981,6 @@ def main(args): imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i] net_kwargs = {} if args.network_args and i < len(args.network_args): @@ -1992,22 +1991,22 @@ def main(args): key, value = net_arg.split("=") net_kwargs[key] = value - network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs) - if network is None: - return - if args.network_weights and i < len(args.network_weights): network_weight = args.network_weights[i] print("load network weights from:", network_weight) - if os.path.splitext(network_weight)[1] == '.safetensors': + if model_util.is_safetensors(network_weight): from safetensors.torch import safe_open with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: print(f"metadata for: {network_weight}: {metadata}") - network.load_weights(network_weight) + network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return network.apply_to(text_encoder, unet) @@ -2518,16 +2517,14 @@ if __name__ == '__main__': parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') + help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') parser.add_argument("--opt_channels_last", action='store_true', - help='set channels last option to model / モデルにchannles lastを指定し最適化する') + help='set channels last option to model / モデルにchannels lastを指定し最適化する') parser.add_argument("--network_module", type=str, default=None, nargs='*', help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') parser.add_argument("--network_weights", type=str, default=None, nargs='*', help='Hypernetwork weights to load / Hypernetworkの重み') parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') - parser.add_argument("--network_dim", type=int, default=None, nargs='*', - help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') parser.add_argument("--network_args", type=str, default=None, nargs='*', help='additional argmuments for network (key=value) / ネットワークへの追加の引数') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') diff --git a/library/train_util.py b/library/train_util.py index aa65dc3..0fdbadc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -11,6 +11,7 @@ import glob import math import os import random +import hashlib from tqdm import tqdm import torch @@ -79,6 +80,11 @@ class BaseDataset(torch.utils.data.Dataset): self.debug_dataset = debug_dataset self.random_crop = random_crop self.token_padding_disabled = False + self.dataset_dirs_info = {} + self.reg_dataset_dirs_info = {} + self.enable_bucket = False + self.min_bucket_reso = None + self.max_bucket_reso = None self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -463,6 +469,8 @@ class DreamBoothDataset(BaseDataset): assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( (self.width, self.height), min_bucket_reso, max_bucket_reso) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso else: self.bucket_resos = [(self.width, self.height)] self.bucket_aspect_ratios = [self.width / self.height] @@ -523,6 +531,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, False, img_path) self.register_image(info) + self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images @@ -539,6 +548,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, True, img_path) reg_infos.append(info) + self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: @@ -611,6 +621,8 @@ class FineTuningDataset(BaseDataset): self.num_train_images = len(metadata) * dataset_repeats self.num_reg_images = 0 + self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} + # check existence of all npz files if not self.color_aug: npz_any = False @@ -653,6 +665,8 @@ class FineTuningDataset(BaseDataset): assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( (self.width, self.height), min_bucket_reso, max_bucket_reso) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso else: self.bucket_resos = [(self.width, self.height)] self.bucket_aspect_ratios = [self.width / self.height] @@ -665,6 +679,9 @@ class FineTuningDataset(BaseDataset): self.bucket_resos.sort() self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos] + self.min_bucket_reso = min([min(reso) for reso in resos]) + self.max_bucket_reso = max([max(reso) for reso in resos]) + def image_key_to_npz_file(self, image_key): base_name = os.path.splitext(image_key)[0] npz_file_norm = base_name + '.npz' @@ -749,9 +766,9 @@ def default(val, d): def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" try: with open(filename, "rb") as file: - import hashlib m = hashlib.sha256() file.seek(0x100000) @@ -761,6 +778,18 @@ def model_hash(filename): return 'NOFILE' +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 @@ -1029,7 +1058,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") + parser.add_argument("--save_last_n_epochs_state", type=int, default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -1048,8 +1078,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") - parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") + parser.add_argument("--max_train_epochs", type=int, default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") + parser.add_argument("--max_data_loader_n_workers", type=int, default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする") diff --git a/lora_gui.py b/lora_gui.py index e35e193..2a7421a 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -88,6 +88,8 @@ def save_configuration( max_token_length, max_train_epochs, max_data_loader_n_workers, + network_alpha, + training_comment, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -175,6 +177,8 @@ def open_configuration( max_token_length, max_train_epochs, max_data_loader_n_workers, + network_alpha, + training_comment, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -246,6 +250,8 @@ def train_model( max_token_length, max_train_epochs, max_data_loader_n_workers, + network_alpha, + training_comment, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -358,6 +364,9 @@ def train_model( run_cmd += f' --resolution={max_resolution}' run_cmd += f' --output_dir="{output_dir}"' run_cmd += f' --logging_dir="{logging_dir}"' + run_cmd += f' --network_alpha="{network_alpha}"' + if not training_comment == '': + run_cmd += f' --training_comment="{training_comment}"' if not stop_text_encoder_training == 0: run_cmd += ( f' --stop_text_encoder_training={stop_text_encoder_training}' @@ -518,10 +527,15 @@ def lora_tab( with gr.Row(): output_name = gr.Textbox( label='Model output name', - placeholder='Name of the model to output', + placeholder='(Name of the model to output)', value='last', interactive=True, ) + training_comment = gr.Textbox( + label='Training comment', + placeholder='(Optional) Add training comment to be included in metadata', + interactive=True, + ) train_data_dir.change( remove_doublequote, inputs=[train_data_dir], @@ -588,11 +602,19 @@ def lora_tab( network_dim = gr.Slider( minimum=1, maximum=128, - label='Network Dimension', + label='Network Rank (Dimension)', value=8, step=1, interactive=True, ) + network_alpha = gr.Slider( + minimum=1, + maximum=128, + label='Network Alpha', + value=1, + step=1, + interactive=True, + ) with gr.Row(): max_resolution = gr.Textbox( label='Max resolution', @@ -703,6 +725,8 @@ def lora_tab( max_token_length, max_train_epochs, max_data_loader_n_workers, + network_alpha, + training_comment, ] button_open_config.click( diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 1140e3b..4ee3f57 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -15,12 +15,13 @@ def main(file): keys = list(sd.keys()) for key in keys: - if 'lora_up' in key: + if 'lora_up' in key or 'lora_down' in key: values.append((key, sd[key])) - print(f"number of LoRA-up modules: {len(values)}") + print(f"number of LoRA modules: {len(values)}") for key, value in values: - print(f"{key},{torch.mean(torch.abs(value))}") + value = value.to(torch.float32) + print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") if __name__ == '__main__': diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 0a4c3a0..84d705c 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -44,9 +44,9 @@ def svd(args): print(f"loading SD model : {args.model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) - # create LoRA network to extract weights - lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o) - lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t) + # create LoRA network to extract weights: Use dim (rank) as alpha + lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o) + lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -77,10 +77,10 @@ def svd(args): module_t = lora_t.org_module diff = module_t.weight - module_o.weight diff = diff.float() - + if args.device: diff = diff.to(args.device) - + diffs[lora_name] = diff # make LoRA with svd @@ -116,6 +116,9 @@ def svd(args): print(f"LoRA has {len(lora_sd)} weights.") for key in list(lora_sd.keys()): + if "alpha" in key: + continue + lora_name = key.split('.')[0] i = 0 if "lora_up" in key else 1 @@ -124,7 +127,7 @@ def svd(args): if len(lora_sd[key].size()) == 4: weights = weights.unsqueeze(2).unsqueeze(3) - assert weights.size() == lora_sd[key].size() + assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" lora_sd[key] = weights # load state dict to LoRA and save it @@ -135,7 +138,10 @@ def svd(args): if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) - lora_network_o.save_weights(args.save_to, save_dtype, {}) + # minimum metadata + metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} + + lora_network_o.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") @@ -151,8 +157,8 @@ if __name__ == '__main__': help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") parser.add_argument("--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)") - parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う") + parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") args = parser.parse_args() svd(args) diff --git a/networks/lora.py b/networks/lora.py index 3f8244e..9243f1e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -13,9 +13,11 @@ class LoRAModule(torch.nn.Module): replaces forward method of the original Linear, instead of replacing the original Linear module. """ - def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4): + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): + """ if alpha == 0 or None, alpha is rank (no scaling). """ super().__init__() self.lora_name = lora_name + self.lora_dim = lora_dim if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels @@ -28,6 +30,12 @@ class LoRAModule(torch.nn.Module): self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える + # same as microsoft's torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) torch.nn.init.zeros_(self.lora_up.weight) @@ -41,13 +49,37 @@ class LoRAModule(torch.nn.Module): del self.org_module def forward(self, x): - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale -def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs): +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): if network_dim is None: network_dim = 4 # default - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim) + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) + return network + + +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file, safe_open + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location='cpu') + + # get dim (rank) + network_alpha = None + network_dim = None + for key, value in weights_sd.items(): + if network_alpha is None and 'alpha' in key: + network_alpha = value + if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + network_dim = value.size()[0] + + if network_alpha is None: + network_alpha = network_dim + + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) + network.weights_sd = weights_sd return network @@ -57,10 +89,11 @@ class LoRANetwork(torch.nn.Module): LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' - def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: + def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None: super().__init__() self.multiplier = multiplier self.lora_dim = lora_dim + self.alpha = alpha # create module instances def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: @@ -71,7 +104,7 @@ class LoRANetwork(torch.nn.Module): if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): lora_name = prefix + '.' + name + '.' + child_name lora_name = lora_name.replace('.', '_') - lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim) + lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha) loras.append(lora) return loras @@ -149,21 +182,21 @@ class LoRANetwork(torch.nn.Module): return params self.requires_grad_(True) - params = [] + all_params = [] if self.text_encoder_loras: param_data = {'params': enumerate_params(self.text_encoder_loras)} if text_encoder_lr is not None: param_data['lr'] = text_encoder_lr - params.append(param_data) + all_params.append(param_data) if self.unet_loras: param_data = {'params': enumerate_params(self.unet_loras)} if unet_lr is not None: param_data['lr'] = unet_lr - params.append(param_data) + all_params.append(param_data) - return params + return all_params def prepare_grad_etc(self, text_encoder, unet): self.requires_grad_(True) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index d873a8e..1d4cb3b 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") + alpha_key = key[:key.index("lora_down")] + 'alpha' # find original module for this lora module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" @@ -73,14 +74,18 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): down_weight = lora_sd[key] up_weight = lora_sd[up_key] + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + # W <- W + U * D weight = module.weight if len(weight.size()) == 2: # linear - weight = weight + ratio * (up_weight @ down_weight) + weight = weight + ratio * (up_weight @ down_weight) * scale else: # conv2d - weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale module.weight = torch.nn.Parameter(weight) @@ -88,20 +93,35 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): def merge_lora_models(models, ratios, merge_dtype): merged_sd = {} + alpha = None + dim = None for model, ratio in zip(models, ratios): print(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) print(f"merging...") for key in lora_sd.keys(): - if key in merged_sd: - assert merged_sd[key].size() == lora_sd[key].size( - ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + if 'alpha' in key: + if key in merged_sd: + assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" + else: + alpha = lora_sd[key].detach().numpy() + merged_sd[key] = lora_sd[key] else: - merged_sd[key] = lora_sd[key] * ratio + if key in merged_sd: + assert merged_sd[key].size() == lora_sd[key].size( + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + else: + if "lora_down" in key: + dim = lora_sd[key].size()[0] + merged_sd[key] = lora_sd[key] * ratio - return merged_sd + print(f"dim (rank): {dim}, alpha: {alpha}") + if alpha is None: + alpha = dim + + return merged_sd, dim, alpha def merge(args): @@ -132,7 +152,7 @@ def merge(args): model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: - state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) print(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) @@ -145,7 +165,7 @@ if __name__ == '__main__': parser.add_argument("--save_precision", type=str, default=None, choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") parser.add_argument("--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度") + choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") parser.add_argument("--sd_model", type=str, default=None, help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") parser.add_argument("--save_to", type=str, default=None, diff --git a/train_network.py b/train_network.py index b2c7b57..d60ae9a 100644 --- a/train_network.py +++ b/train_network.py @@ -3,6 +3,9 @@ import argparse import gc import math import os +import random +import time +import json from tqdm import tqdm import torch @@ -18,7 +21,23 @@ def collate_fn(examples): return examples[0] +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = {"loss/current": current_loss, "loss/average": avr_loss} + + if args.network_train_unet_only: + logs["lr/unet"] = lr_scheduler.get_last_lr()[0] + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0] + else: + logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0] + logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder + + return logs + + def train(args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -88,7 +107,8 @@ def train(args): key, value = net_arg.split('=') net_kwargs[key] = value - network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs) + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return @@ -206,21 +226,26 @@ def train(args): print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, "ss_text_encoder_lr": args.text_encoder_lr, "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data + "ss_num_train_images": train_dataset.num_train_images, # includes repeating "ss_num_reg_images": train_dataset.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, "ss_batch_size_per_device": args.train_batch_size, "ss_total_batch_size": total_batch_size, + "ss_gradient_checkpointing": args.gradient_checkpointing, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_max_train_steps": args.max_train_steps, "ss_lr_warmup_steps": args.lr_warmup_steps, "ss_lr_scheduler": args.lr_scheduler, "ss_network_module": args.network_module, - "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not use this value "ss_mixed_precision": args.mixed_precision, "ss_full_fp16": bool(args.full_fp16), "ss_v2": bool(args.v2), @@ -232,10 +257,14 @@ def train(args): "ss_random_crop": bool(args.random_crop), "ss_shuffle_caption": bool(args.shuffle_caption), "ss_cache_latents": bool(args.cache_latents), - "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT - "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset - "ss_max_bucket_reso": args.max_bucket_reso, - "ss_seed": args.seed + "ss_enable_bucket": bool(train_dataset.enable_bucket), + "ss_min_bucket_reso": train_dataset.min_bucket_reso, + "ss_max_bucket_reso": train_dataset.max_bucket_reso, + "ss_seed": args.seed, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), + "ss_training_comment": args.training_comment # will not be updated after training } # uncomment if another network is added @@ -246,6 +275,7 @@ def train(args): sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) sd_model_name = os.path.basename(sd_model_name) metadata["ss_sd_model_name"] = sd_model_name @@ -253,6 +283,7 @@ def train(args): vae_name = args.vae if os.path.exists(vae_name): metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) vae_name = os.path.basename(vae_name) metadata["ss_vae_name"] = vae_name @@ -333,20 +364,20 @@ def train(args): global_step += 1 current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} - accelerator.log(logs, step=global_step) - loss_total += current_loss avr_loss = loss_total / (step+1) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(train_dataloader)} + logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() @@ -417,11 +448,15 @@ if __name__ == '__main__': parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') parser.add_argument("--network_dim", type=int, default=None, help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') + parser.add_argument("--network_alpha", type=float, default=1, + help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)') parser.add_argument("--network_args", type=str, default=None, nargs='*', help='additional argmuments for network (key=value) / ネットワークへの追加の引数') parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument("--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する") + parser.add_argument("--training_comment", type=str, default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() train(args) diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 77ef4c1..8e329e9 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -138,7 +138,7 @@ v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマ ## 当リポジトリ内の画像生成スクリプトで生成する -gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim(省略可)の各オプションを追加してください。意味は学習時と同様です。 +gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。 --network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。