Adding support for Lion optimizer in gui

This commit is contained in:
  parent bb57c1a36e
  commit 758bfe85dc

dreambooth_gui.py
@@ -89,6 +89,7 @@ def save_configuration(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -179,6 +180,7 @@ def open_configuration(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -253,6 +255,7 @@ def train_model(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -397,6 +400,7 @@ def train_model(
         seed=seed,
         caption_extension=caption_extension,
         cache_latents=cache_latents,
+        optimizer=optimizer
     )

     run_cmd += run_cmd_advanced_training(
@@ -541,6 +545,7 @@ def dreambooth_tab(
         seed,
         caption_extension,
         cache_latents,
+        optimizer,
     ) = gradio_training(
         learning_rate_value='1e-5',
         lr_scheduler_value='cosine',
@@ -668,6 +673,7 @@ def dreambooth_tab(
         random_crop,
         bucket_reso_steps,
         caption_dropout_every_n_epochs, caption_dropout_rate,
+        optimizer,
     ]

     button_open_config.click(
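
The save/open handlers above capture their arguments with locals(), so threading optimizer through each signature is enough for the value to round-trip through the saved config JSON. A minimal sketch of that pattern, with a deliberately tiny, hypothetical signature (the real functions take dozens of fields):

    import json

    def save_configuration_sketch(file_path, optimizer='AdamW', seed='1234'):
        # Mirrors `parameters = list(locals().items())` in the diff: every
        # named parameter is captured, including the new 'optimizer'.
        parameters = list(locals().items())
        variables = {name: value for name, value in parameters if name != 'file_path'}
        with open(file_path, 'w') as f:
            json.dump(variables, f, indent=2)

    save_configuration_sketch('my_config.json', optimizer='Lion')
    # my_config.json -> {"optimizer": "Lion", "seed": "1234"}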

finetune_gui.py
@@ -85,6 +85,7 @@ def save_configuration(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -181,6 +182,7 @@ def open_config_file(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -262,6 +264,7 @@ def train_model(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # create caption json file
     if generate_caption_database:
@@ -386,6 +389,7 @@ def train_model(
         seed=seed,
         caption_extension=caption_extension,
         cache_latents=cache_latents,
+        optimizer=optimizer,
     )

     run_cmd += run_cmd_advanced_training(
@@ -564,6 +568,7 @@ def finetune_tab():
         seed,
         caption_extension,
         cache_latents,
+        optimizer,
     ) = gradio_training(learning_rate_value='1e-5')
     with gr.Row():
         dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
@@ -661,6 +666,7 @@ def finetune_tab():
         random_crop,
         bucket_reso_steps,
         caption_dropout_every_n_epochs, caption_dropout_rate,
+        optimizer,
     ]

     button_run.click(train_model, inputs=settings_list)

library/common_gui.py
@@ -445,6 +445,7 @@ def gradio_training(
             value=2,
         )
         seed = gr.Textbox(label='Seed', value=1234)
+        cache_latents = gr.Checkbox(label='Cache latent', value=True)
     with gr.Row():
         learning_rate = gr.Textbox(
             label='Learning rate', value=learning_rate_value
@@ -464,7 +465,15 @@ def gradio_training(
         lr_warmup = gr.Textbox(
             label='LR warmup (% of steps)', value=lr_warmup_value
         )
-        cache_latents = gr.Checkbox(label='Cache latent', value=True)
+        optimizer = gr.Dropdown(
+            label='Optimizer',
+            choices=[
+                'AdamW',
+                'Lion',
+            ],
+            value="AdamW",
+            interactive=True,
+        )
     return (
         learning_rate,
         lr_scheduler,
@@ -478,6 +487,7 @@ def gradio_training(
         seed,
         caption_extension,
         cache_latents,
+        optimizer,
     )


@@ -512,10 +522,34 @@ def run_cmd_training(**kwargs):
         if kwargs.get('caption_extension')
         else '',
         ' --cache_latents' if kwargs.get('cache_latents') else '',
+        ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
     ]
     run_cmd = ''.join(options)
     return run_cmd
+
+# # This function takes a dictionary of keyword arguments and returns a string that can be used to run a command-line training script
+# def run_cmd_training(**kwargs):
+#     arg_map = {
+#         'learning_rate': ' --learning_rate="{}"',
+#         'lr_scheduler': ' --lr_scheduler="{}"',
+#         'lr_warmup_steps': ' --lr_warmup_steps="{}"',
+#         'train_batch_size': ' --train_batch_size="{}"',
+#         'max_train_steps': ' --max_train_steps="{}"',
+#         'save_every_n_epochs': ' --save_every_n_epochs="{}"',
+#         'mixed_precision': ' --mixed_precision="{}"',
+#         'save_precision': ' --save_precision="{}"',
+#         'seed': ' --seed="{}"',
+#         'caption_extension': ' --caption_extension="{}"',
+#         'cache_latents': ' --cache_latents',
+#         'optimizer': ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
+#     }
+
+#     options = [arg_map[key].format(value) for key, value in kwargs.items() if key in arg_map and value]
+
+#     cmd = ''.join(options)
+
+#     return cmd


 def gradio_advanced_training():
     with gr.Row():
@@ -664,3 +698,34 @@ def run_cmd_advanced_training(**kwargs):
     ]
     run_cmd = ''.join(options)
     return run_cmd
+
+# def run_cmd_advanced_training(**kwargs):
+#     arg_map = {
+#         'max_train_epochs': ' --max_train_epochs="{}"',
+#         'max_data_loader_n_workers': ' --max_data_loader_n_workers="{}"',
+#         'max_token_length': ' --max_token_length={}' if int(kwargs.get('max_token_length', 75)) > 75 else '',
+#         'clip_skip': ' --clip_skip={}' if int(kwargs.get('clip_skip', 1)) > 1 else '',
+#         'resume': ' --resume="{}"',
+#         'keep_tokens': ' --keep_tokens="{}"' if int(kwargs.get('keep_tokens', 0)) > 0 else '',
+#         'caption_dropout_every_n_epochs': ' --caption_dropout_every_n_epochs="{}"' if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 else '',
+#         'caption_dropout_rate': ' --caption_dropout_rate="{}"' if float(kwargs.get('caption_dropout_rate', 0)) > 0 else '',
+#         'bucket_reso_steps': ' --bucket_reso_steps={:d}' if int(kwargs.get('bucket_reso_steps', 64)) >= 1 else '',
+#         'save_state': ' --save_state',
+#         'mem_eff_attn': ' --mem_eff_attn',
+#         'color_aug': ' --color_aug',
+#         'flip_aug': ' --flip_aug',
+#         'shuffle_caption': ' --shuffle_caption',
+#         'gradient_checkpointing': ' --gradient_checkpointing',
+#         'full_fp16': ' --full_fp16',
+#         'xformers': ' --xformers',
+#         'use_8bit_adam': ' --use_8bit_adam',
+#         'persistent_data_loader_workers': ' --persistent_data_loader_workers',
+#         'bucket_no_upscale': ' --bucket_no_upscale',
+#         'random_crop': ' --random_crop',
+#     }
+
+#     options = [arg_map[key].format(value) for key, value in kwargs.items() if key in arg_map and value]
+
+#     cmd = ''.join(options)
+
+#     return cmd
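
Net effect of the library/common_gui.py changes: the dropdown's string value travels from gradio_training into run_cmd_training, where only 'Lion' emits a flag ('AdamW' adds nothing, AdamW being the training script's default). A condensed, self-contained sketch of just that mapping (the real function also emits learning-rate, scheduler, and caching flags):

    def run_cmd_training_sketch(**kwargs):
        # Only the optimizer handling from the diff is reproduced here.
        options = [
            ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
        ]
        return ''.join(options)

    assert run_cmd_training_sketch(optimizer='Lion') == ' --use_lion_optimizer'
    assert run_cmd_training_sketch(optimizer='AdamW') == ''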

lora_gui.py
@@ -100,6 +100,7 @@ def save_configuration(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -197,6 +198,7 @@ def open_configuration(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -278,6 +280,7 @@ def train_model(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -457,6 +460,7 @@ def train_model(
         seed=seed,
         caption_extension=caption_extension,
         cache_latents=cache_latents,
+        optimizer=optimizer,
     )

     run_cmd += run_cmd_advanced_training(
@@ -609,6 +613,7 @@ def lora_tab(
         seed,
         caption_extension,
         cache_latents,
+        optimizer,
     ) = gradio_training(
         learning_rate_value='0.0001',
         lr_scheduler_value='cosine',
@@ -778,6 +783,7 @@ def lora_tab(
         random_crop,
         bucket_reso_steps,
         caption_dropout_every_n_epochs, caption_dropout_rate,
+        optimizer,
     ]

     button_open_config.click(
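
Each tab (Dreambooth, finetune, LoRA, and textual inversion below) is wired the same way: it unpacks the component tuple returned by gradio_training and appends the same components, in the same order, to its settings list, so save_configuration, open_configuration, and train_model all receive optimizer positionally. A compressed, hypothetical two-component illustration of that wiring:

    import gradio as gr

    def gradio_training_sketch(learning_rate_value='1e-5'):
        learning_rate = gr.Textbox(label='Learning rate', value=learning_rate_value)
        optimizer = gr.Dropdown(
            label='Optimizer', choices=['AdamW', 'Lion'], value='AdamW', interactive=True
        )
        return (learning_rate, optimizer)

    def train_model_sketch(learning_rate, optimizer):
        print(f'learning_rate={learning_rate}, optimizer={optimizer}')

    with gr.Blocks() as demo:
        learning_rate, optimizer = gradio_training_sketch()
        settings_list = [learning_rate, optimizer]
        button_run = gr.Button('Train model')
        # Component values arrive in train_model_sketch in settings_list order.
        button_run.click(train_model_sketch, inputs=settings_list)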

presets/lion_optimizer.json (new file, 59 lines)
@@ -0,0 +1,59 @@
+{
+  "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
+  "v2": false,
+  "v_parameterization": false,
+  "logging_dir": "D:\\dataset\\marty_mcfly\\1985\\lora/log",
+  "train_data_dir": "D:\\dataset\\marty_mcfly\\1985\\lora\\img_gan",
+  "reg_data_dir": "",
+  "output_dir": "D:/lora/sd1.5/marty_mcfly",
+  "max_resolution": "512,512",
+  "learning_rate": "0.00003333",
+  "lr_scheduler": "cosine",
+  "lr_warmup": "0",
+  "train_batch_size": 8,
+  "epoch": "1",
+  "save_every_n_epochs": "1",
+  "mixed_precision": "bf16",
+  "save_precision": "fp16",
+  "seed": "1234",
+  "num_cpu_threads_per_process": 2,
+  "cache_latents": false,
+  "caption_extension": "",
+  "enable_bucket": true,
+  "gradient_checkpointing": false,
+  "full_fp16": false,
+  "no_token_padding": false,
+  "stop_text_encoder_training": 0,
+  "use_8bit_adam": false,
+  "xformers": true,
+  "save_model_as": "safetensors",
+  "shuffle_caption": false,
+  "save_state": false,
+  "resume": "",
+  "prior_loss_weight": 1.0,
+  "text_encoder_lr": "0.000016666",
+  "unet_lr": "0.00003333",
+  "network_dim": 128,
+  "lora_network_weights": "",
+  "color_aug": false,
+  "flip_aug": false,
+  "clip_skip": "1",
+  "gradient_accumulation_steps": 1.0,
+  "mem_eff_attn": false,
+  "output_name": "mrtmcfl_v2.0",
+  "model_list": "runwayml/stable-diffusion-v1-5",
+  "max_token_length": "75",
+  "max_train_epochs": "",
+  "max_data_loader_n_workers": "0",
+  "network_alpha": 128,
+  "training_comment": "",
+  "keep_tokens": "0",
+  "lr_scheduler_num_cycles": "",
+  "lr_scheduler_power": "",
+  "persistent_data_loader_workers": false,
+  "bucket_no_upscale": true,
+  "random_crop": true,
+  "bucket_reso_steps": 64.0,
+  "caption_dropout_every_n_epochs": 0.0,
+  "caption_dropout_rate": 0.1
+}
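
Note that the preset itself carries no "optimizer" key, so loading it should leave the GUI dropdown at whatever it currently shows (the default being 'AdamW'); selecting 'Lion' in the dropdown remains a manual step. A standalone sanity check, assuming the file is read as plain JSON:

    import json

    with open('presets/lion_optimizer.json') as f:
        preset = json.load(f)

    # No 'optimizer' key is present in the preset as committed.
    print(preset.get('optimizer', 'AdamW'))  # -> AdamW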

requirements.txt
@@ -13,6 +13,7 @@ gradio==3.16.2
 altair==4.2.2
 easygui==0.98.3
 tk==0.1.0
+lion-pytorch==0.0.6
 # for BLIP captioning
 requests==2.28.2
 timm==0.6.12
@@ -21,6 +22,6 @@ fairscale==0.4.13
 # tensorflow<2.11
 tensorflow==2.10.1
 huggingface-hub==0.12.0
-xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
+# xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
 # for kohya_ss library
 .
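
The new lion-pytorch pin is the standalone package providing the optimizer itself; the GUI only emits the --use_lion_optimizer flag shown above. A minimal usage sketch of the package, unrelated to this repo's code:

    import torch
    from lion_pytorch import Lion

    model = torch.nn.Linear(10, 2)
    # Lion is typically run with a noticeably smaller learning rate (and larger
    # weight decay) than AdamW, consistent with the lowered LRs in the preset above.
    opt = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()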

textual_inversion_gui.py
@@ -95,6 +95,7 @@ def save_configuration(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -195,6 +196,7 @@ def open_configuration(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -275,6 +277,7 @@ def train_model(
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
+    optimizer,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -434,6 +437,7 @@ def train_model(
         seed=seed,
         caption_extension=caption_extension,
         cache_latents=cache_latents,
+        optimizer=optimizer,
     )

     run_cmd += run_cmd_advanced_training(
@@ -623,6 +627,7 @@ def ti_tab(
         seed,
         caption_extension,
         cache_latents,
+        optimizer,
     ) = gradio_training(
         learning_rate_value='1e-5',
         lr_scheduler_value='cosine',
@@ -756,6 +761,7 @@ def ti_tab(
         random_crop,
         bucket_reso_steps,
         caption_dropout_every_n_epochs, caption_dropout_rate,
+        optimizer,
     ]

     button_open_config.click(