Add support for lr_scheduler_num_cycles, lr_scheduler_power

This commit is contained in:
bmaltais 2023-01-30 08:26:15 -05:00
parent 2ec7432440
commit c8f4c9d6e8
3 changed files with 17 additions and 2 deletions

View File

@ -143,7 +143,7 @@ Then redo the installation instruction within the kohya_ss venv.
## Change history ## Change history
* 2023/01/29 (v20.5.2): * 2023/01/30 (v20.5.2):
- Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev! - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
- Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py`` - Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
* 2023/01/27 (v20.5.1): * 2023/01/27 (v20.5.1):

View File

@ -520,7 +520,7 @@ def gradio_advanced_training():
label='Shuffle caption', value=False label='Shuffle caption', value=False
) )
keep_tokens = gr.Slider( keep_tokens = gr.Slider(
label='Keen n tokens', value='0', minimum=0, maximum=32, step=1 label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
) )
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True) xformers = gr.Checkbox(label='Use xformers', value=True)

View File

@ -91,6 +91,7 @@ def save_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -180,6 +181,7 @@ def open_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -253,6 +255,7 @@ def train_model(
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -395,6 +398,10 @@ def train_model(
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
if not output_name == '': if not output_name == '':
run_cmd += f' --output_name="{output_name}"' run_cmd += f' --output_name="{output_name}"'
if not lr_scheduler_num_cycles == '':
run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
if not lr_scheduler_power == '':
run_cmd += f' --output_name="{lr_scheduler_power}"'
run_cmd += run_cmd_training( run_cmd += run_cmd_training(
learning_rate=learning_rate, learning_rate=learning_rate,
@ -646,6 +653,13 @@ def lora_tab(
prior_loss_weight = gr.Number( prior_loss_weight = gr.Number(
label='Prior loss weight', value=1.0 label='Prior loss weight', value=1.0
) )
lr_scheduler_num_cycles = gr.Textbox(
label='LR number of cycles', placeholder='(Optional) For Cosine with restart and polynomial only'
)
lr_scheduler_power = gr.Textbox(
label='LR power', placeholder='(Optional) For Cosine with restart and polynomial only'
)
( (
use_8bit_adam, use_8bit_adam,
xformers, xformers,
@ -736,6 +750,7 @@ def lora_tab(
network_alpha, network_alpha,
training_comment, training_comment,
keep_tokens, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
] ]
button_open_config.click( button_open_config.click(