Adding support for Lion optimizer in gui

bmaltais 2023-02-19 20:13:03 -05:00
parent bb57c1a36e
commit 758bfe85dc
7 changed files with 151 additions and 2 deletions

View File

@ -89,6 +89,7 @@ def save_configuration(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -179,6 +180,7 @@ def open_configuration(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -253,6 +255,7 @@ def train_model(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -397,6 +400,7 @@ def train_model(
seed=seed,
caption_extension=caption_extension,
cache_latents=cache_latents,
optimizer=optimizer
)
run_cmd += run_cmd_advanced_training(
@ -541,6 +545,7 @@ def dreambooth_tab(
seed,
caption_extension,
cache_latents,
optimizer,
) = gradio_training(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
@ -668,6 +673,7 @@ def dreambooth_tab(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
]
button_open_config.click(

View File

@ -85,6 +85,7 @@ def save_configuration(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -181,6 +182,7 @@ def open_config_file(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -262,6 +264,7 @@ def train_model(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# create caption json file
if generate_caption_database:
@ -386,6 +389,7 @@ def train_model(
seed=seed,
caption_extension=caption_extension,
cache_latents=cache_latents,
optimizer=optimizer,
)
run_cmd += run_cmd_advanced_training(
@ -564,6 +568,7 @@ def finetune_tab():
seed,
caption_extension,
cache_latents,
optimizer,
) = gradio_training(learning_rate_value='1e-5')
with gr.Row():
dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
@ -661,6 +666,7 @@ def finetune_tab():
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
]
button_run.click(train_model, inputs=settings_list)
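The pattern repeated in each tab relies on Gradio passing component values positionally: the new optimizer dropdown has to sit at the same index in the inputs list as the optimizer parameter in the callback's signature. A minimal sketch of that plumbing, with made-up component names rather than the project's full settings list:

import gradio as gr

def train_model(learning_rate, optimizer):
    # Gradio supplies the current component values in the order they
    # appear in the inputs list below.
    print(f'learning_rate={learning_rate}, optimizer={optimizer}')

with gr.Blocks() as demo:
    learning_rate = gr.Textbox(label='Learning rate', value='1e-5')
    optimizer = gr.Dropdown(
        label='Optimizer', choices=['AdamW', 'Lion'], value='AdamW'
    )
    button_run = gr.Button('Train model')
    button_run.click(train_model, inputs=[learning_rate, optimizer])

# demo.launch()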

View File

@ -445,6 +445,7 @@ def gradio_training(
value=2,
)
seed = gr.Textbox(label='Seed', value=1234)
cache_latents = gr.Checkbox(label='Cache latent', value=True)
with gr.Row():
learning_rate = gr.Textbox(
label='Learning rate', value=learning_rate_value
@ -464,7 +465,15 @@ def gradio_training(
lr_warmup = gr.Textbox(
label='LR warmup (% of steps)', value=lr_warmup_value
)
cache_latents = gr.Checkbox(label='Cache latent', value=True)
optimizer = gr.Dropdown(
label='Optimizer',
choices=[
'AdamW',
'Lion',
],
value="AdamW",
interactive=True,
)
return (
learning_rate,
lr_scheduler,
@ -478,6 +487,7 @@ def gradio_training(
seed,
caption_extension,
cache_latents,
optimizer,
)
@ -512,10 +522,34 @@ def run_cmd_training(**kwargs):
if kwargs.get('caption_extension')
else '',
' --cache_latents' if kwargs.get('cache_latents') else '',
' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
]
run_cmd = ''.join(options)
return run_cmd
# # This function takes a dictionary of keyword arguments and returns a string that can be used to run a command-line training script
# def run_cmd_training(**kwargs):
# arg_map = {
# 'learning_rate': ' --learning_rate="{}"',
# 'lr_scheduler': ' --lr_scheduler="{}"',
# 'lr_warmup_steps': ' --lr_warmup_steps="{}"',
# 'train_batch_size': ' --train_batch_size="{}"',
# 'max_train_steps': ' --max_train_steps="{}"',
# 'save_every_n_epochs': ' --save_every_n_epochs="{}"',
# 'mixed_precision': ' --mixed_precision="{}"',
# 'save_precision': ' --save_precision="{}"',
# 'seed': ' --seed="{}"',
# 'caption_extension': ' --caption_extension="{}"',
# 'cache_latents': ' --cache_latents',
# 'optimizer': ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
# }
# options = [arg_map[key].format(value) for key, value in kwargs.items() if key in arg_map and value]
# cmd = ''.join(options)
# return cmd
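Taken on its own, the new mapping is small: the dropdown's string value is translated into the --use_lion_optimizer flag passed through to the underlying training scripts, while 'AdamW' (the scripts' default) adds nothing. A standalone sketch of just that piece, using a hypothetical helper name:

def optimizer_flag(optimizer):
    # 'Lion' switches the training script over to the Lion optimizer;
    # any other value (including 'AdamW') leaves the default in place.
    return ' --use_lion_optimizer' if optimizer == 'Lion' else ''

assert optimizer_flag('Lion') == ' --use_lion_optimizer'
assert optimizer_flag('AdamW') == ''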
def gradio_advanced_training():
with gr.Row():
@ -664,3 +698,34 @@ def run_cmd_advanced_training(**kwargs):
]
run_cmd = ''.join(options)
return run_cmd
# def run_cmd_advanced_training(**kwargs):
# arg_map = {
# 'max_train_epochs': ' --max_train_epochs="{}"',
# 'max_data_loader_n_workers': ' --max_data_loader_n_workers="{}"',
# 'max_token_length': ' --max_token_length={}' if int(kwargs.get('max_token_length', 75)) > 75 else '',
# 'clip_skip': ' --clip_skip={}' if int(kwargs.get('clip_skip', 1)) > 1 else '',
# 'resume': ' --resume="{}"',
# 'keep_tokens': ' --keep_tokens="{}"' if int(kwargs.get('keep_tokens', 0)) > 0 else '',
# 'caption_dropout_every_n_epochs': ' --caption_dropout_every_n_epochs="{}"' if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 else '',
# 'caption_dropout_rate': ' --caption_dropout_rate="{}"' if float(kwargs.get('caption_dropout_rate', 0)) > 0 else '',
# 'bucket_reso_steps': ' --bucket_reso_steps={:d}' if int(kwargs.get('bucket_reso_steps', 64)) >= 1 else '',
# 'save_state': ' --save_state',
# 'mem_eff_attn': ' --mem_eff_attn',
# 'color_aug': ' --color_aug',
# 'flip_aug': ' --flip_aug',
# 'shuffle_caption': ' --shuffle_caption',
# 'gradient_checkpointing': ' --gradient_checkpointing',
# 'full_fp16': ' --full_fp16',
# 'xformers': ' --xformers',
# 'use_8bit_adam': ' --use_8bit_adam',
# 'persistent_data_loader_workers': ' --persistent_data_loader_workers',
# 'bucket_no_upscale': ' --bucket_no_upscale',
# 'random_crop': ' --random_crop',
# }
# options = [arg_map[key].format(value) for key, value in kwargs.items() if key in arg_map and value]
# cmd = ''.join(options)
# return cmd
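The commented-out blocks above sketch a table-driven rewrite of both command builders. A runnable distillation of that idea (hypothetical option names, not the project's full map): value-style options carry a "{}" placeholder, flag-style options have none, and falsy values are skipped entirely.

def build_cmd(**kwargs):
    arg_map = {
        'seed': ' --seed="{}"',               # value-style option
        'cache_latents': ' --cache_latents',  # flag-style option
    }
    # str.format() on an entry without a placeholder returns it unchanged,
    # so both styles go through the same code path.
    return ''.join(
        arg_map[key].format(value)
        for key, value in kwargs.items()
        if key in arg_map and value
    )

print(build_cmd(seed=1234, cache_latents=True))   # --seed="1234" --cache_latents
print(build_cmd(seed=1234, cache_latents=False))  # --seed="1234"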

View File

@ -100,6 +100,7 @@ def save_configuration(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -197,6 +198,7 @@ def open_configuration(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -278,6 +280,7 @@ def train_model(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -457,6 +460,7 @@ def train_model(
seed=seed,
caption_extension=caption_extension,
cache_latents=cache_latents,
optimizer=optimizer,
)
run_cmd += run_cmd_advanced_training(
@ -609,6 +613,7 @@ def lora_tab(
seed,
caption_extension,
cache_latents,
optimizer,
) = gradio_training(
learning_rate_value='0.0001',
lr_scheduler_value='cosine',
@ -778,6 +783,7 @@ def lora_tab(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
]
button_open_config.click(

View File

@ -0,0 +1,59 @@
{
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
"v2": false,
"v_parameterization": false,
"logging_dir": "D:\\dataset\\marty_mcfly\\1985\\lora/log",
"train_data_dir": "D:\\dataset\\marty_mcfly\\1985\\lora\\img_gan",
"reg_data_dir": "",
"output_dir": "D:/lora/sd1.5/marty_mcfly",
"max_resolution": "512,512",
"learning_rate": "0.00003333",
"lr_scheduler": "cosine",
"lr_warmup": "0",
"train_batch_size": 8,
"epoch": "1",
"save_every_n_epochs": "1",
"mixed_precision": "bf16",
"save_precision": "fp16",
"seed": "1234",
"num_cpu_threads_per_process": 2,
"cache_latents": false,
"caption_extension": "",
"enable_bucket": true,
"gradient_checkpointing": false,
"full_fp16": false,
"no_token_padding": false,
"stop_text_encoder_training": 0,
"use_8bit_adam": false,
"xformers": true,
"save_model_as": "safetensors",
"shuffle_caption": false,
"save_state": false,
"resume": "",
"prior_loss_weight": 1.0,
"text_encoder_lr": "0.000016666",
"unet_lr": "0.00003333",
"network_dim": 128,
"lora_network_weights": "",
"color_aug": false,
"flip_aug": false,
"clip_skip": "1",
"gradient_accumulation_steps": 1.0,
"mem_eff_attn": false,
"output_name": "mrtmcfl_v2.0",
"model_list": "runwayml/stable-diffusion-v1-5",
"max_token_length": "75",
"max_train_epochs": "",
"max_data_loader_n_workers": "0",
"network_alpha": 128,
"training_comment": "",
"keep_tokens": "0",
"lr_scheduler_num_cycles": "",
"lr_scheduler_power": "",
"persistent_data_loader_workers": false,
"bucket_no_upscale": true,
"random_crop": true,
"bucket_reso_steps": 64.0,
"caption_dropout_every_n_epochs": 0.0,
"caption_dropout_rate": 0.1
}
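The example config above carries no "optimizer" key, so any code reading older files like this one needs a fallback. A minimal sketch of that handling (not the project's actual loader, and with a hypothetical filename), defaulting to the dropdown's initial value:

import json

with open('marty_mcfly_lora.json') as f:  # hypothetical filename
    config = json.load(f)

# Configs saved without the new field fall back to the GUI default.
optimizer = config.get('optimizer', 'AdamW')
print(optimizer)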

View File

@ -13,6 +13,7 @@ gradio==3.16.2
altair==4.2.2
easygui==0.98.3
tk==0.1.0
lion-pytorch==0.0.6
# for BLIP captioning
requests==2.28.2
timm==0.6.12
@ -21,6 +22,6 @@ fairscale==0.4.13
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0
xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
# xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
# for kohya_ss library
.
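The new lion-pytorch==0.0.6 pin supplies the optimizer the underlying training scripts are expected to use when --use_lion_optimizer is passed. Outside this commit's GUI code, its use looks roughly like the following sketch (assuming the package's documented Lion class; the learning rate is illustrative only):

import torch
from lion_pytorch import Lion

model = torch.nn.Linear(4, 2)
# Lion is typically configured with a smaller learning rate than AdamW.
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()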

View File

@ -95,6 +95,7 @@ def save_configuration(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -195,6 +196,7 @@ def open_configuration(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -275,6 +277,7 @@ def train_model(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -434,6 +437,7 @@ def train_model(
seed=seed,
caption_extension=caption_extension,
cache_latents=cache_latents,
optimizer=optimizer,
)
run_cmd += run_cmd_advanced_training(
@ -623,6 +627,7 @@ def ti_tab(
seed,
caption_extension,
cache_latents,
optimizer,
) = gradio_training(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
@ -756,6 +761,7 @@ def ti_tab(
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,
]
button_open_config.click(