diff --git a/README.md b/README.md
index fb75f56..e6f9d37 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,36 @@ Once you have created the LoRA network you can generate images via auto1111 by i
 
 ## Change history
 
+* 2023/01/16 (v20.3.0)
+  - Fix the issue where some LoRA modules are not trained when ``gradient_checkpointing`` is enabled.
+  - Add the ``--save_last_n_epochs_state`` option. You can specify how many state folders to keep, separately from how many models to keep. Thanks to shirayu!
+  - Fix the issue where Text Encoder training stops at ``max_train_steps`` even if ``max_train_epochs`` is set in ``train_db.py``.
+  - Add a script to check LoRA weights. You can check the weights with ``python networks\check_lora_weights.py <model file>``. If some modules are not trained, their values are ``0.0``, as in the example below.
+    - ``lora_te_text_model_encoder_layers_11_*`` is not trained when ``clip_skip=2``, so ``0.0`` is expected for these modules.
+
+- Example output of ``check_lora_weights.py`` when the Text Encoder and part of the U-Net are not trained:
+```
+number of LoRA-up modules: 264
+lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight,0.0
+lora_te_text_model_encoder_layers_0_mlp_fc2.lora_up.weight,0.0
+lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_up.weight,0.0
+:
+lora_unet_down_blocks_2_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_up.weight,0.0
+lora_unet_down_blocks_2_attentions_1_transformer_blocks_0_ff_net_2.lora_up.weight,0.0
+lora_unet_mid_block_attentions_0_proj_in.lora_up.weight,0.003503334941342473
+lora_unet_mid_block_attentions_0_proj_out.lora_up.weight,0.004308608360588551
+:
+```
+
+- Example output when all modules are trained:
+```
+number of LoRA-up modules: 264
+lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight,0.0028684409335255623
+lora_te_text_model_encoder_layers_0_mlp_fc2.lora_up.weight,0.0029794853180646896
+lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_up.weight,0.002507600700482726
+lora_te_text_model_encoder_layers_0_self_attn_out_proj.lora_up.weight,0.002639499492943287
+:
+```
 * 2023/01/16 (v20.2.1):
     - Merging latest code update from kohya
     - Added `--max_train_epochs` and `--max_data_loader_n_workers` option for each training script.
diff --git a/library/train_util.py b/library/train_util.py
index 57ebf1b..aa65dc3 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1029,6 +1029,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--save_every_n_epochs", type=int, default=None,
                       help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
   parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
+  parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
   parser.add_argument("--save_state", action="store_true",
                       help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
   parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@@ -1298,7 +1299,6 @@ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
 
 
 def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
   saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
-  remove_epoch_no = None
   if saving:
     os.makedirs(args.output_dir, exist_ok=True)
     save_func()
@@ -1306,7 +1306,7 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
     if args.save_last_n_epochs is not None:
       remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
       remove_old_func(remove_epoch_no)
-  return saving, remove_epoch_no
+  return saving
 
 
 def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
@@ -1346,15 +1346,18 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
     save_func = save_du
     remove_old_func = remove_du
 
-  saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
+  saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
   if saving and args.save_state:
-    save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no)
+    save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
 
 
-def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no):
+def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
   print("saving state.")
   accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
-  if remove_epoch_no is not None:
+
+  last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
+  if last_n_epochs is not None:
+    remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
     state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
     if os.path.exists(state_dir_old):
       print(f"removing old state: {state_dir_old}")
diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py
new file mode 100644
index 0000000..1140e3b
--- /dev/null
+++ b/networks/check_lora_weights.py
@@ -0,0 +1,31 @@
+import argparse
+import os
+import torch
+from safetensors.torch import load_file
+
+
+def main(file):
+  print(f"loading: {file}")
+  if os.path.splitext(file)[1] == '.safetensors':
+    sd = load_file(file)
+  else:
+    sd = torch.load(file, map_location='cpu')
+
+  values = []
+
+  keys = list(sd.keys())
+  for key in keys:
+    if 'lora_up' in key:
+      values.append((key, sd[key]))
+  print(f"number of LoRA-up modules: {len(values)}")
+
+  for key, value in values:
+    print(f"{key},{torch.mean(torch.abs(value))}")
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
+  args = parser.parse_args()
+
+  main(args.file)
diff --git a/train_db.py b/train_db.py
index bbef3da..8ac503e 100644
--- a/train_db.py
+++ b/train_db.py
@@ -92,10 +92,7 @@ def train(args):
     gc.collect()
 
   # 学習を準備する:モデルを適切な状態にする
-  if args.stop_text_encoder_training is None:
-    args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
-
-  train_text_encoder = args.stop_text_encoder_training >= 0
+  train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
   unet.requires_grad_(True)  # 念のため追加
   text_encoder.requires_grad_(train_text_encoder)
   if not train_text_encoder:
@@ -143,6 +140,9 @@ def train(args):
     args.max_train_steps = args.max_train_epochs * len(train_dataloader)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+  if args.stop_text_encoder_training is None:
+    args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
+
   # lr schedulerを用意する
   lr_scheduler = diffusers.optimization.get_scheduler(
       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
diff --git a/train_network.py b/train_network.py
index c0a881a..b2c7b57 100644
--- a/train_network.py
+++ b/train_network.py
@@ -166,6 +166,9 @@ def train(args):
   if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
     unet.train()
     text_encoder.train()
+
+    # set requires_grad = True on the input embeddings so that gradient checkpointing works
+    text_encoder.text_model.embeddings.requires_grad_(True)
   else:
     unet.eval()
     text_encoder.eval()
@@ -364,9 +367,9 @@ def train(args):
         print(f"removing old checkpoint: {old_ckpt_file}")
         os.remove(old_ckpt_file)
 
-      saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+      saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
       if saving and args.save_state:
-        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
+        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
     # end of epoch
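Note on the `train_network.py` change: with PyTorch's reentrant gradient checkpointing, a checkpointed block whose inputs do not require gradients never re-enters autograd during backward, so LoRA modules injected into the otherwise frozen Text Encoder receive no gradients at all. Setting `requires_grad_(True)` on the embedding parameters keeps the autograd chain alive from the very first checkpointed block. The snippet below is a minimal standalone sketch of that failure mode, not code from this repository; the names `frozen`, `adapter`, and `block` are illustrative stand-ins for the frozen base weights and an injected LoRA module.

```python
import torch
import torch.utils.checkpoint as cp

frozen = torch.nn.Linear(8, 8).requires_grad_(False)  # stands in for frozen base weights
adapter = torch.nn.Linear(8, 8)                       # stands in for a trainable LoRA module

def block(x):
    # a "transformer block" that is run under gradient checkpointing
    return frozen(x) + adapter(x)

# Case 1: the input does not require grad (embeddings frozen) -> the graph is cut.
x = torch.randn(2, 8)
out = cp.checkpoint(block, x, use_reentrant=True)
print(out.requires_grad)  # False: backward never reaches the adapter, so it is never trained

# Case 2: the input requires grad (like embeddings.requires_grad_(True)) -> gradients flow again.
x = torch.randn(2, 8, requires_grad=True)
out = cp.checkpoint(block, x, use_reentrant=True)
out.sum().backward()
print(adapter.weight.grad is not None)  # True
print(frozen.weight.grad)               # None: the base weights stay frozen
```

This is also why `check_lora_weights.py` reports exactly `0.0` for untrained modules: the `lora_up` weights start at zero (as in standard LoRA), so a module that never receives a gradient keeps a mean absolute value of zero.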
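Note on the `--save_last_n_epochs_state` change: `save_state_on_epoch_end` now derives the state folder to delete on its own instead of receiving `remove_epoch_no` from the caller, and the state limit falls back to `--save_last_n_epochs` when the new option is not given. A rough worked example of the pruning arithmetic, with hypothetical values:

```python
# Hypothetical settings, mirroring the arithmetic in save_state_on_epoch_end:
save_every_n_epochs = 1
save_last_n_epochs_state = 3   # would fall back to save_last_n_epochs if not set
epoch_no = 10                  # the state for epoch 10 was just written

remove_epoch_no = epoch_no - save_every_n_epochs * save_last_n_epochs_state
print(remove_epoch_no)         # 7 -> the epoch-7 state folder is removed,
                               # leaving the states for epochs 8, 9 and 10 on disk
```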