Add support for lr_scheduler_num_cycles, lr_scheduler_power
parent 2ec7432440
commit c8f4c9d6e8
@@ -143,7 +143,7 @@ Then redo the installation instruction within the kohya_ss venv.
 
 ## Change history
 
-* 2023/01/29 (v20.5.2):
+* 2023/01/30 (v20.5.2):
     - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
     - Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
 * 2023/01/27 (v20.5.1):

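The two new options correspond to the `num_cycles` argument of the cosine-with-restarts schedule and the `power` argument of the polynomial decay schedule. The sketch below only illustrates what those two values control, using the standard formulas with warmup omitted and an end LR of zero; it is not the actual code path inside `train_network.py`:

```python
import math

def cosine_with_restarts(progress: float, num_cycles: int = 1) -> float:
    """LR multiplier at `progress` = step / total_steps: `num_cycles` full cosine decays."""
    if progress >= 1.0:
        return 0.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0))))

def polynomial_decay(progress: float, power: float = 1.0) -> float:
    """LR multiplier at `progress`: decays from 1.0 to 0.0; power=1.0 is a linear schedule."""
    return (1.0 - min(progress, 1.0)) ** power

# Example: 40% of the way through training.
print(cosine_with_restarts(0.4, num_cycles=3))  # ~0.90: 20% into the 2nd of 3 cosine cycles
print(polynomial_decay(0.4, power=2.0))         # (1 - 0.4)**2 = 0.36
```
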
@@ -520,7 +520,7 @@ def gradio_advanced_training():
             label='Shuffle caption', value=False
         )
         keep_tokens = gr.Slider(
-            label='Keen n tokens', value='0', minimum=0, maximum=32, step=1
+            label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
         )
         use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
         xformers = gr.Checkbox(label='Use xformers', value=True)

lora_gui.py
@@ -91,6 +91,7 @@ def save_configuration(
     max_data_loader_n_workers,
     network_alpha,
     training_comment, keep_tokens,
+    lr_scheduler_num_cycles, lr_scheduler_power,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())

@@ -180,6 +181,7 @@ def open_configuration(
     max_data_loader_n_workers,
     network_alpha,
     training_comment, keep_tokens,
+    lr_scheduler_num_cycles, lr_scheduler_power,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())

@@ -253,6 +255,7 @@ def train_model(
     max_data_loader_n_workers,
     network_alpha,
     training_comment, keep_tokens,
+    lr_scheduler_num_cycles, lr_scheduler_power,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')

@@ -395,6 +398,10 @@ def train_model(
     run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
     if not output_name == '':
         run_cmd += f' --output_name="{output_name}"'
+    if not lr_scheduler_num_cycles == '':
+        run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
+    if not lr_scheduler_power == '':
+        run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
 
     run_cmd += run_cmd_training(
         learning_rate=learning_rate,

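Taken on its own, the flag-building logic added above amounts to the following sketch; the two values are hypothetical examples typed into the new GUI fields, and a field left empty contributes no flag at all:

```python
# Hypothetical values entered into the two new GUI fields.
lr_scheduler_num_cycles = '3'
lr_scheduler_power = '2'

run_cmd = ''
if not lr_scheduler_num_cycles == '':
    run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
if not lr_scheduler_power == '':
    run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'

print(run_cmd)  # --lr_scheduler_num_cycles="3" --lr_scheduler_power="2"
```
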
@@ -646,6 +653,13 @@ def lora_tab(
             prior_loss_weight = gr.Number(
                 label='Prior loss weight', value=1.0
             )
+            lr_scheduler_num_cycles = gr.Textbox(
+                label='LR number of cycles', placeholder='(Optional) For Cosine with restart and polynomial only'
+            )
+
+            lr_scheduler_power = gr.Textbox(
+                label='LR power', placeholder='(Optional) For Cosine with restart and polynomial only'
+            )
         (
             use_8bit_adam,
             xformers,

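As a usage illustration only, here is a minimal standalone Gradio app wiring two textboxes with the same labels and placeholders to a preview of the flags they produce; the real `lora_tab` instead passes the values into `train_model` along with the rest of the settings, as the final hunk below shows:

```python
import gradio as gr

def preview_flags(num_cycles, power):
    # Mirrors the (simplified) conditional flag building in train_model.
    cmd = ''
    if num_cycles != '':
        cmd += f' --lr_scheduler_num_cycles="{num_cycles}"'
    if power != '':
        cmd += f' --lr_scheduler_power="{power}"'
    return cmd or '(no scheduler flags)'

with gr.Blocks() as demo:
    lr_scheduler_num_cycles = gr.Textbox(
        label='LR number of cycles',
        placeholder='(Optional) For Cosine with restart and polynomial only',
    )
    lr_scheduler_power = gr.Textbox(
        label='LR power',
        placeholder='(Optional) For Cosine with restart and polynomial only',
    )
    out = gr.Textbox(label='Flags preview')
    gr.Button('Preview').click(
        preview_flags,
        inputs=[lr_scheduler_num_cycles, lr_scheduler_power],
        outputs=out,
    )

if __name__ == '__main__':
    demo.launch()
```
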
@@ -736,6 +750,7 @@ def lora_tab(
         network_alpha,
         training_comment,
         keep_tokens,
+        lr_scheduler_num_cycles, lr_scheduler_power,
     ]
 
     button_open_config.click(