Adding support for Lion optimizer in gui
This commit is contained in:
parent
bb57c1a36e
commit
758bfe85dc
@ -89,6 +89,7 @@ def save_configuration(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -179,6 +180,7 @@ def open_configuration(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -253,6 +255,7 @@ def train_model(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -397,6 +400,7 @@ def train_model(
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
caption_extension=caption_extension,
|
caption_extension=caption_extension,
|
||||||
cache_latents=cache_latents,
|
cache_latents=cache_latents,
|
||||||
|
optimizer=optimizer
|
||||||
)
|
)
|
||||||
|
|
||||||
run_cmd += run_cmd_advanced_training(
|
run_cmd += run_cmd_advanced_training(
|
||||||
@ -541,6 +545,7 @@ def dreambooth_tab(
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
|
optimizer,
|
||||||
) = gradio_training(
|
) = gradio_training(
|
||||||
learning_rate_value='1e-5',
|
learning_rate_value='1e-5',
|
||||||
lr_scheduler_value='cosine',
|
lr_scheduler_value='cosine',
|
||||||
@ -668,6 +673,7 @@ def dreambooth_tab(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -85,6 +85,7 @@ def save_configuration(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -181,6 +182,7 @@ def open_config_file(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -262,6 +264,7 @@ def train_model(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -386,6 +389,7 @@ def train_model(
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
caption_extension=caption_extension,
|
caption_extension=caption_extension,
|
||||||
cache_latents=cache_latents,
|
cache_latents=cache_latents,
|
||||||
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
run_cmd += run_cmd_advanced_training(
|
run_cmd += run_cmd_advanced_training(
|
||||||
@ -564,6 +568,7 @@ def finetune_tab():
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
|
optimizer,
|
||||||
) = gradio_training(learning_rate_value='1e-5')
|
) = gradio_training(learning_rate_value='1e-5')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
|
dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
|
||||||
@ -661,6 +666,7 @@ def finetune_tab():
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
@ -445,6 +445,7 @@ def gradio_training(
|
|||||||
value=2,
|
value=2,
|
||||||
)
|
)
|
||||||
seed = gr.Textbox(label='Seed', value=1234)
|
seed = gr.Textbox(label='Seed', value=1234)
|
||||||
|
cache_latents = gr.Checkbox(label='Cache latent', value=True)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
learning_rate = gr.Textbox(
|
learning_rate = gr.Textbox(
|
||||||
label='Learning rate', value=learning_rate_value
|
label='Learning rate', value=learning_rate_value
|
||||||
@ -464,7 +465,15 @@ def gradio_training(
|
|||||||
lr_warmup = gr.Textbox(
|
lr_warmup = gr.Textbox(
|
||||||
label='LR warmup (% of steps)', value=lr_warmup_value
|
label='LR warmup (% of steps)', value=lr_warmup_value
|
||||||
)
|
)
|
||||||
cache_latents = gr.Checkbox(label='Cache latent', value=True)
|
optimizer = gr.Dropdown(
|
||||||
|
label='Optimizer',
|
||||||
|
choices=[
|
||||||
|
'AdamW',
|
||||||
|
'Lion',
|
||||||
|
],
|
||||||
|
value="AdamW",
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
learning_rate,
|
learning_rate,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
@ -478,6 +487,7 @@ def gradio_training(
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
|
optimizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -512,10 +522,34 @@ def run_cmd_training(**kwargs):
|
|||||||
if kwargs.get('caption_extension')
|
if kwargs.get('caption_extension')
|
||||||
else '',
|
else '',
|
||||||
' --cache_latents' if kwargs.get('cache_latents') else '',
|
' --cache_latents' if kwargs.get('cache_latents') else '',
|
||||||
|
' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
|
||||||
]
|
]
|
||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
|
|
||||||
|
# # This function takes a dictionary of keyword arguments and returns a string that can be used to run a command-line training script
|
||||||
|
# def run_cmd_training(**kwargs):
|
||||||
|
# arg_map = {
|
||||||
|
# 'learning_rate': ' --learning_rate="{}"',
|
||||||
|
# 'lr_scheduler': ' --lr_scheduler="{}"',
|
||||||
|
# 'lr_warmup_steps': ' --lr_warmup_steps="{}"',
|
||||||
|
# 'train_batch_size': ' --train_batch_size="{}"',
|
||||||
|
# 'max_train_steps': ' --max_train_steps="{}"',
|
||||||
|
# 'save_every_n_epochs': ' --save_every_n_epochs="{}"',
|
||||||
|
# 'mixed_precision': ' --mixed_precision="{}"',
|
||||||
|
# 'save_precision': ' --save_precision="{}"',
|
||||||
|
# 'seed': ' --seed="{}"',
|
||||||
|
# 'caption_extension': ' --caption_extension="{}"',
|
||||||
|
# 'cache_latents': ' --cache_latents',
|
||||||
|
# 'optimizer': ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
|
||||||
|
# }
|
||||||
|
|
||||||
|
# options = [arg_map[key].format(value) for key, value in kwargs.items() if key in arg_map and value]
|
||||||
|
|
||||||
|
# cmd = ''.join(options)
|
||||||
|
|
||||||
|
# return cmd
|
||||||
|
|
||||||
|
|
||||||
def gradio_advanced_training():
|
def gradio_advanced_training():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -664,3 +698,34 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
]
|
]
|
||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
|
|
||||||
|
# def run_cmd_advanced_training(**kwargs):
|
||||||
|
# arg_map = {
|
||||||
|
# 'max_train_epochs': ' --max_train_epochs="{}"',
|
||||||
|
# 'max_data_loader_n_workers': ' --max_data_loader_n_workers="{}"',
|
||||||
|
# 'max_token_length': ' --max_token_length={}' if int(kwargs.get('max_token_length', 75)) > 75 else '',
|
||||||
|
# 'clip_skip': ' --clip_skip={}' if int(kwargs.get('clip_skip', 1)) > 1 else '',
|
||||||
|
# 'resume': ' --resume="{}"',
|
||||||
|
# 'keep_tokens': ' --keep_tokens="{}"' if int(kwargs.get('keep_tokens', 0)) > 0 else '',
|
||||||
|
# 'caption_dropout_every_n_epochs': ' --caption_dropout_every_n_epochs="{}"' if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 else '',
|
||||||
|
# 'caption_dropout_rate': ' --caption_dropout_rate="{}"' if float(kwargs.get('caption_dropout_rate', 0)) > 0 else '',
|
||||||
|
# 'bucket_reso_steps': ' --bucket_reso_steps={:d}' if int(kwargs.get('bucket_reso_steps', 64)) >= 1 else '',
|
||||||
|
# 'save_state': ' --save_state',
|
||||||
|
# 'mem_eff_attn': ' --mem_eff_attn',
|
||||||
|
# 'color_aug': ' --color_aug',
|
||||||
|
# 'flip_aug': ' --flip_aug',
|
||||||
|
# 'shuffle_caption': ' --shuffle_caption',
|
||||||
|
# 'gradient_checkpointing': ' --gradient_checkpointing',
|
||||||
|
# 'full_fp16': ' --full_fp16',
|
||||||
|
# 'xformers': ' --xformers',
|
||||||
|
# 'use_8bit_adam': ' --use_8bit_adam',
|
||||||
|
# 'persistent_data_loader_workers': ' --persistent_data_loader_workers',
|
||||||
|
# 'bucket_no_upscale': ' --bucket_no_upscale',
|
||||||
|
# 'random_crop': ' --random_crop',
|
||||||
|
# }
|
||||||
|
|
||||||
|
# options = [arg_map[key].format(value) for key, value in kwargs.items() if key in arg_map and value]
|
||||||
|
|
||||||
|
# cmd = ''.join(options)
|
||||||
|
|
||||||
|
# return cmd
|
@ -100,6 +100,7 @@ def save_configuration(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -197,6 +198,7 @@ def open_configuration(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -278,6 +280,7 @@ def train_model(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -457,6 +460,7 @@ def train_model(
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
caption_extension=caption_extension,
|
caption_extension=caption_extension,
|
||||||
cache_latents=cache_latents,
|
cache_latents=cache_latents,
|
||||||
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
run_cmd += run_cmd_advanced_training(
|
run_cmd += run_cmd_advanced_training(
|
||||||
@ -609,6 +613,7 @@ def lora_tab(
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
|
optimizer,
|
||||||
) = gradio_training(
|
) = gradio_training(
|
||||||
learning_rate_value='0.0001',
|
learning_rate_value='0.0001',
|
||||||
lr_scheduler_value='cosine',
|
lr_scheduler_value='cosine',
|
||||||
@ -778,6 +783,7 @@ def lora_tab(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
59
presets/lion_optimizer.json
Normal file
59
presets/lion_optimizer.json
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
{
|
||||||
|
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
|
||||||
|
"v2": false,
|
||||||
|
"v_parameterization": false,
|
||||||
|
"logging_dir": "D:\\dataset\\marty_mcfly\\1985\\lora/log",
|
||||||
|
"train_data_dir": "D:\\dataset\\marty_mcfly\\1985\\lora\\img_gan",
|
||||||
|
"reg_data_dir": "",
|
||||||
|
"output_dir": "D:/lora/sd1.5/marty_mcfly",
|
||||||
|
"max_resolution": "512,512",
|
||||||
|
"learning_rate": "0.00003333",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"lr_warmup": "0",
|
||||||
|
"train_batch_size": 8,
|
||||||
|
"epoch": "1",
|
||||||
|
"save_every_n_epochs": "1",
|
||||||
|
"mixed_precision": "bf16",
|
||||||
|
"save_precision": "fp16",
|
||||||
|
"seed": "1234",
|
||||||
|
"num_cpu_threads_per_process": 2,
|
||||||
|
"cache_latents": false,
|
||||||
|
"caption_extension": "",
|
||||||
|
"enable_bucket": true,
|
||||||
|
"gradient_checkpointing": false,
|
||||||
|
"full_fp16": false,
|
||||||
|
"no_token_padding": false,
|
||||||
|
"stop_text_encoder_training": 0,
|
||||||
|
"use_8bit_adam": false,
|
||||||
|
"xformers": true,
|
||||||
|
"save_model_as": "safetensors",
|
||||||
|
"shuffle_caption": false,
|
||||||
|
"save_state": false,
|
||||||
|
"resume": "",
|
||||||
|
"prior_loss_weight": 1.0,
|
||||||
|
"text_encoder_lr": "0.000016666",
|
||||||
|
"unet_lr": "0.00003333",
|
||||||
|
"network_dim": 128,
|
||||||
|
"lora_network_weights": "",
|
||||||
|
"color_aug": false,
|
||||||
|
"flip_aug": false,
|
||||||
|
"clip_skip": "1",
|
||||||
|
"gradient_accumulation_steps": 1.0,
|
||||||
|
"mem_eff_attn": false,
|
||||||
|
"output_name": "mrtmcfl_v2.0",
|
||||||
|
"model_list": "runwayml/stable-diffusion-v1-5",
|
||||||
|
"max_token_length": "75",
|
||||||
|
"max_train_epochs": "",
|
||||||
|
"max_data_loader_n_workers": "0",
|
||||||
|
"network_alpha": 128,
|
||||||
|
"training_comment": "",
|
||||||
|
"keep_tokens": "0",
|
||||||
|
"lr_scheduler_num_cycles": "",
|
||||||
|
"lr_scheduler_power": "",
|
||||||
|
"persistent_data_loader_workers": false,
|
||||||
|
"bucket_no_upscale": true,
|
||||||
|
"random_crop": true,
|
||||||
|
"bucket_reso_steps": 64.0,
|
||||||
|
"caption_dropout_every_n_epochs": 0.0,
|
||||||
|
"caption_dropout_rate": 0.1
|
||||||
|
}
|
@ -13,6 +13,7 @@ gradio==3.16.2
|
|||||||
altair==4.2.2
|
altair==4.2.2
|
||||||
easygui==0.98.3
|
easygui==0.98.3
|
||||||
tk==0.1.0
|
tk==0.1.0
|
||||||
|
lion-pytorch==0.0.6
|
||||||
# for BLIP captioning
|
# for BLIP captioning
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
timm==0.6.12
|
timm==0.6.12
|
||||||
@ -21,6 +22,6 @@ fairscale==0.4.13
|
|||||||
# tensorflow<2.11
|
# tensorflow<2.11
|
||||||
tensorflow==2.10.1
|
tensorflow==2.10.1
|
||||||
huggingface-hub==0.12.0
|
huggingface-hub==0.12.0
|
||||||
xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
# xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
.
|
.
|
@ -95,6 +95,7 @@ def save_configuration(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -195,6 +196,7 @@ def open_configuration(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -275,6 +277,7 @@ def train_model(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -434,6 +437,7 @@ def train_model(
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
caption_extension=caption_extension,
|
caption_extension=caption_extension,
|
||||||
cache_latents=cache_latents,
|
cache_latents=cache_latents,
|
||||||
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
run_cmd += run_cmd_advanced_training(
|
run_cmd += run_cmd_advanced_training(
|
||||||
@ -623,6 +627,7 @@ def ti_tab(
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
|
optimizer,
|
||||||
) = gradio_training(
|
) = gradio_training(
|
||||||
learning_rate_value='1e-5',
|
learning_rate_value='1e-5',
|
||||||
lr_scheduler_value='cosine',
|
lr_scheduler_value='cosine',
|
||||||
@ -756,6 +761,7 @@ def ti_tab(
|
|||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
Loading…
Reference in New Issue
Block a user