From c8f4c9d6e8289797873cd00ecdd661aebe7120fd Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Mon, 30 Jan 2023 08:26:15 -0500
Subject: [PATCH] Add support for lr_scheduler_num_cycles, lr_scheduler_power

---
 README.md             |  2 +-
 library/common_gui.py |  2 +-
 lora_gui.py           | 15 +++++++++++++++
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 4342fe9..98018c9 100644
--- a/README.md
+++ b/README.md
@@ -143,7 +143,7 @@ Then redo the installation instruction within the kohya_ss venv.
 
 ## Change history
 
-* 2023/01/29 (v20.5.2):
+* 2023/01/30 (v20.5.2):
     - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
     - Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
 * 2023/01/27 (v20.5.1):
diff --git a/library/common_gui.py b/library/common_gui.py
index 816fd9d..2595540 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -520,7 +520,7 @@ def gradio_advanced_training():
             label='Shuffle caption', value=False
         )
         keep_tokens = gr.Slider(
-            label='Keen n tokens', value='0', minimum=0, maximum=32, step=1
+            label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
         )
         use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
         xformers = gr.Checkbox(label='Use xformers', value=True)
diff --git a/lora_gui.py b/lora_gui.py
index 49359e9..70dbfd0 100644
--- a/lora_gui.py
+++ b/lora_gui.py
@@ -91,6 +91,7 @@ def save_configuration(
     max_data_loader_n_workers,
     network_alpha,
     training_comment, keep_tokens,
+    lr_scheduler_num_cycles, lr_scheduler_power,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -180,6 +181,7 @@ def open_configuration(
     max_data_loader_n_workers,
     network_alpha,
     training_comment, keep_tokens,
+    lr_scheduler_num_cycles, lr_scheduler_power,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -253,6 +255,7 @@ def train_model(
     max_data_loader_n_workers,
     network_alpha,
     training_comment, keep_tokens,
+    lr_scheduler_num_cycles, lr_scheduler_power,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -395,6 +398,10 @@
         run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
     if not output_name == '':
         run_cmd += f' --output_name="{output_name}"'
+    if not lr_scheduler_num_cycles == '':
+        run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
+    if not lr_scheduler_power == '':
+        run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
 
     run_cmd += run_cmd_training(
         learning_rate=learning_rate,
@@ -646,6 +653,13 @@ def lora_tab(
                 prior_loss_weight = gr.Number(
                     label='Prior loss weight', value=1.0
                 )
+                lr_scheduler_num_cycles = gr.Textbox(
+                    label='LR number of cycles', placeholder='(Optional) For Cosine with restart and polynomial only'
+                )
+
+                lr_scheduler_power = gr.Textbox(
+                    label='LR power', placeholder='(Optional) For Cosine with restart and polynomial only'
+                )
                 (
                     use_8bit_adam,
                     xformers,
@@ -736,6 +750,7 @@
         network_alpha,
         training_comment,
         keep_tokens,
+        lr_scheduler_num_cycles, lr_scheduler_power,
     ]
 
     button_open_config.click(