Move cxommon adv train parm to common gui
This commit is contained in:
parent
abccecb093
commit
123cf4e3c5
@ -77,7 +77,8 @@ def save_configuration(
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,mem_eff_attn,
|
||||||
|
gradient_accumulation_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -158,7 +159,8 @@ def open_configuration(
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,mem_eff_attn,
|
||||||
|
gradient_accumulation_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -223,7 +225,8 @@ def train_model(
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,mem_eff_attn,
|
||||||
|
gradient_accumulation_steps,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -322,24 +325,12 @@ def train_model(
|
|||||||
run_cmd += ' --cache_latents'
|
run_cmd += ' --cache_latents'
|
||||||
if enable_bucket:
|
if enable_bucket:
|
||||||
run_cmd += ' --enable_bucket'
|
run_cmd += ' --enable_bucket'
|
||||||
if gradient_checkpointing:
|
|
||||||
run_cmd += ' --gradient_checkpointing'
|
|
||||||
if full_fp16:
|
|
||||||
run_cmd += ' --full_fp16'
|
|
||||||
if no_token_padding:
|
if no_token_padding:
|
||||||
run_cmd += ' --no_token_padding'
|
run_cmd += ' --no_token_padding'
|
||||||
if use_8bit_adam:
|
if use_8bit_adam:
|
||||||
run_cmd += ' --use_8bit_adam'
|
run_cmd += ' --use_8bit_adam'
|
||||||
if xformers:
|
if xformers:
|
||||||
run_cmd += ' --xformers'
|
run_cmd += ' --xformers'
|
||||||
if shuffle_caption:
|
|
||||||
run_cmd += ' --shuffle_caption'
|
|
||||||
# if save_state:
|
|
||||||
# run_cmd += ' --save_state'
|
|
||||||
if color_aug:
|
|
||||||
run_cmd += ' --color_aug'
|
|
||||||
if flip_aug:
|
|
||||||
run_cmd += ' --flip_aug'
|
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
||||||
)
|
)
|
||||||
@ -353,8 +344,6 @@ def train_model(
|
|||||||
run_cmd += f' --lr_scheduler={lr_scheduler}'
|
run_cmd += f' --lr_scheduler={lr_scheduler}'
|
||||||
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
|
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
|
||||||
run_cmd += f' --max_train_steps={max_train_steps}'
|
run_cmd += f' --max_train_steps={max_train_steps}'
|
||||||
run_cmd += f' --use_8bit_adam'
|
|
||||||
run_cmd += f' --xformers'
|
|
||||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
run_cmd += f' --mixed_precision={mixed_precision}'
|
||||||
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
|
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
|
||||||
run_cmd += f' --seed={seed}'
|
run_cmd += f' --seed={seed}'
|
||||||
@ -372,8 +361,6 @@ def train_model(
|
|||||||
# run_cmd += f' --resume={resume}'
|
# run_cmd += f' --resume={resume}'
|
||||||
if not float(prior_loss_weight) == 1.0:
|
if not float(prior_loss_weight) == 1.0:
|
||||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||||
if int(clip_skip) > 1:
|
|
||||||
run_cmd += f' --clip_skip={str(clip_skip)}'
|
|
||||||
if not vae == '':
|
if not vae == '':
|
||||||
run_cmd += f' --vae="{vae}"'
|
run_cmd += f' --vae="{vae}"'
|
||||||
if not output_name == '':
|
if not output_name == '':
|
||||||
@ -384,12 +371,23 @@ def train_model(
|
|||||||
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
||||||
if not max_data_loader_n_workers == '':
|
if not max_data_loader_n_workers == '':
|
||||||
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
||||||
|
if int(gradient_accumulation_steps) > 1:
|
||||||
|
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
||||||
run_cmd += run_cmd_advanced_training(
|
run_cmd += run_cmd_advanced_training(
|
||||||
max_train_epochs=max_train_epochs,
|
max_train_epochs=max_train_epochs,
|
||||||
max_data_loader_n_workers=max_data_loader_n_workers,
|
max_data_loader_n_workers=max_data_loader_n_workers,
|
||||||
max_token_length=max_token_length,
|
max_token_length=max_token_length,
|
||||||
resume=resume,
|
resume=resume,
|
||||||
save_state=save_state,
|
save_state=save_state,
|
||||||
|
mem_eff_attn=mem_eff_attn,
|
||||||
|
clip_skip=clip_skip,
|
||||||
|
flip_aug=flip_aug,
|
||||||
|
color_aug=color_aug,
|
||||||
|
shuffle_caption=shuffle_caption,
|
||||||
|
gradient_checkpointing=gradient_checkpointing,
|
||||||
|
full_fp16=full_fp16,
|
||||||
|
xformers=xformers,
|
||||||
|
use_8bit_adam=use_8bit_adam,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -668,36 +666,13 @@ def dreambooth_tab(
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
||||||
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
|
||||||
with gr.Accordion('Advanced Configuration', open=False):
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
full_fp16 = gr.Checkbox(
|
|
||||||
label='Full fp16 training (experimental)', value=False
|
|
||||||
)
|
|
||||||
no_token_padding = gr.Checkbox(
|
no_token_padding = gr.Checkbox(
|
||||||
label='No token padding', value=False
|
label='No token padding', value=False
|
||||||
)
|
)
|
||||||
|
gradient_accumulation_steps = gr.Number(
|
||||||
gradient_checkpointing = gr.Checkbox(
|
label='Gradient accumulate steps', value='1'
|
||||||
label='Gradient checkpointing', value=False
|
|
||||||
)
|
|
||||||
|
|
||||||
shuffle_caption = gr.Checkbox(
|
|
||||||
label='Shuffle caption', value=False
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
color_aug = gr.Checkbox(
|
|
||||||
label='Color augmentation', value=False
|
|
||||||
)
|
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
|
||||||
color_aug.change(
|
|
||||||
color_aug_changed,
|
|
||||||
inputs=[color_aug],
|
|
||||||
outputs=[cache_latent],
|
|
||||||
)
|
|
||||||
clip_skip = gr.Slider(
|
|
||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prior_loss_weight = gr.Number(
|
prior_loss_weight = gr.Number(
|
||||||
@ -709,7 +684,27 @@ def dreambooth_tab(
|
|||||||
)
|
)
|
||||||
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||||
vae_button.click(get_any_file_path, outputs=vae)
|
vae_button.click(get_any_file_path, outputs=vae)
|
||||||
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
(
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
|
full_fp16,
|
||||||
|
gradient_checkpointing,
|
||||||
|
shuffle_caption,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
clip_skip,
|
||||||
|
mem_eff_attn,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
|
) = gradio_advanced_training()
|
||||||
|
color_aug.change(
|
||||||
|
color_aug_changed,
|
||||||
|
inputs=[color_aug],
|
||||||
|
outputs=[cache_latent],
|
||||||
|
)
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This section provide Dreambooth tools to help setup your dataset...'
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
@ -763,7 +758,8 @@ def dreambooth_tab(
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,mem_eff_attn,
|
||||||
|
gradient_accumulation_steps,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -13,7 +13,8 @@ from library.common_gui import (
|
|||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
save_inference_file,
|
save_inference_file,
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input,
|
||||||
gradio_advanced_training,run_cmd_advanced_training
|
gradio_advanced_training,run_cmd_advanced_training,
|
||||||
|
color_aug_changed,
|
||||||
)
|
)
|
||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
|
|
||||||
@ -69,7 +70,7 @@ def save_configuration(
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,full_fp16,color_aug,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -154,7 +155,7 @@ def open_config_file(
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,full_fp16,color_aug,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -224,7 +225,7 @@ def train_model(
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,full_fp16,color_aug,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -262,8 +263,8 @@ def train_model(
|
|||||||
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
||||||
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
||||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
run_cmd += f' --mixed_precision={mixed_precision}'
|
||||||
if flip_aug:
|
# if flip_aug:
|
||||||
run_cmd += f' --flip_aug'
|
# run_cmd += f' --flip_aug'
|
||||||
if full_path:
|
if full_path:
|
||||||
run_cmd += f' --full_path'
|
run_cmd += f' --full_path'
|
||||||
|
|
||||||
@ -301,16 +302,6 @@ def train_model(
|
|||||||
run_cmd += ' --v_parameterization'
|
run_cmd += ' --v_parameterization'
|
||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
run_cmd += ' --train_text_encoder'
|
run_cmd += ' --train_text_encoder'
|
||||||
if use_8bit_adam:
|
|
||||||
run_cmd += f' --use_8bit_adam'
|
|
||||||
if xformers:
|
|
||||||
run_cmd += f' --xformers'
|
|
||||||
if gradient_checkpointing:
|
|
||||||
run_cmd += ' --gradient_checkpointing'
|
|
||||||
if mem_eff_attn:
|
|
||||||
run_cmd += ' --mem_eff_attn'
|
|
||||||
if shuffle_caption:
|
|
||||||
run_cmd += ' --shuffle_caption'
|
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
||||||
)
|
)
|
||||||
@ -331,8 +322,6 @@ def train_model(
|
|||||||
run_cmd += f' --save_precision={save_precision}'
|
run_cmd += f' --save_precision={save_precision}'
|
||||||
if not save_model_as == 'same as source model':
|
if not save_model_as == 'same as source model':
|
||||||
run_cmd += f' --save_model_as={save_model_as}'
|
run_cmd += f' --save_model_as={save_model_as}'
|
||||||
if int(clip_skip) > 1:
|
|
||||||
run_cmd += f' --clip_skip={str(clip_skip)}'
|
|
||||||
if int(gradient_accumulation_steps) > 1:
|
if int(gradient_accumulation_steps) > 1:
|
||||||
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
||||||
# if save_state:
|
# if save_state:
|
||||||
@ -349,6 +338,15 @@ def train_model(
|
|||||||
max_token_length=max_token_length,
|
max_token_length=max_token_length,
|
||||||
resume=resume,
|
resume=resume,
|
||||||
save_state=save_state,
|
save_state=save_state,
|
||||||
|
mem_eff_attn=mem_eff_attn,
|
||||||
|
clip_skip=clip_skip,
|
||||||
|
flip_aug=flip_aug,
|
||||||
|
color_aug=color_aug,
|
||||||
|
shuffle_caption=shuffle_caption,
|
||||||
|
gradient_checkpointing=gradient_checkpointing,
|
||||||
|
full_fp16=full_fp16,
|
||||||
|
xformers=xformers,
|
||||||
|
use_8bit_adam=use_8bit_adam,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -565,7 +563,6 @@ def finetune_tab():
|
|||||||
label='Latent metadata filename', value='meta_lat.json'
|
label='Latent metadata filename', value='meta_lat.json'
|
||||||
)
|
)
|
||||||
full_path = gr.Checkbox(label='Use full path', value=True)
|
full_path = gr.Checkbox(label='Use full path', value=True)
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
|
||||||
with gr.Tab('Training parameters'):
|
with gr.Tab('Training parameters'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
||||||
@ -634,25 +631,30 @@ def finetune_tab():
|
|||||||
)
|
)
|
||||||
with gr.Accordion('Advanced parameters', open=False):
|
with gr.Accordion('Advanced parameters', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
|
||||||
clip_skip = gr.Slider(
|
|
||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
|
||||||
)
|
|
||||||
mem_eff_attn = gr.Checkbox(
|
|
||||||
label='Memory efficient attention', value=False
|
|
||||||
)
|
|
||||||
shuffle_caption = gr.Checkbox(
|
|
||||||
label='Shuffle caption', value=False
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
gradient_checkpointing = gr.Checkbox(
|
|
||||||
label='Gradient checkpointing', value=False
|
|
||||||
)
|
|
||||||
gradient_accumulation_steps = gr.Number(
|
gradient_accumulation_steps = gr.Number(
|
||||||
label='Gradient accumulate steps', value='1'
|
label='Gradient accumulate steps', value='1'
|
||||||
)
|
)
|
||||||
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
(
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
|
full_fp16,
|
||||||
|
gradient_checkpointing,
|
||||||
|
shuffle_caption,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
clip_skip,
|
||||||
|
mem_eff_attn,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
|
) = gradio_advanced_training()
|
||||||
|
# color_aug.change(
|
||||||
|
# color_aug_changed,
|
||||||
|
# inputs=[color_aug],
|
||||||
|
# # outputs=[cache_latent], # Not applicable to fine_tune.py
|
||||||
|
# )
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
create_caption = gr.Checkbox(
|
create_caption = gr.Checkbox(
|
||||||
@ -708,7 +710,7 @@ def finetune_tab():
|
|||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,full_fp16,color_aug,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
@ -304,13 +304,28 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
|||||||
|
|
||||||
def gradio_advanced_training():
|
def gradio_advanced_training():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
full_fp16 = gr.Checkbox(
|
||||||
resume = gr.Textbox(
|
label='Full fp16 training (experimental)', value=False
|
||||||
label='Resume from saved training state',
|
)
|
||||||
placeholder='path to "last-state" state folder to resume from',
|
gradient_checkpointing = gr.Checkbox(
|
||||||
|
label='Gradient checkpointing', value=False
|
||||||
|
)
|
||||||
|
shuffle_caption = gr.Checkbox(
|
||||||
|
label='Shuffle caption', value=False
|
||||||
|
)
|
||||||
|
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||||
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
|
with gr.Row():
|
||||||
|
color_aug = gr.Checkbox(
|
||||||
|
label='Color augmentation', value=False
|
||||||
|
)
|
||||||
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
|
clip_skip = gr.Slider(
|
||||||
|
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||||
|
)
|
||||||
|
mem_eff_attn = gr.Checkbox(
|
||||||
|
label='Memory efficient attention', value=False
|
||||||
)
|
)
|
||||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
|
||||||
resume_button.click(get_folder_path, outputs=resume)
|
|
||||||
max_token_length = gr.Dropdown(
|
max_token_length = gr.Dropdown(
|
||||||
label='Max Token Length',
|
label='Max Token Length',
|
||||||
choices=[
|
choices=[
|
||||||
@ -321,6 +336,13 @@ def gradio_advanced_training():
|
|||||||
value='75',
|
value='75',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||||
|
resume = gr.Textbox(
|
||||||
|
label='Resume from saved training state',
|
||||||
|
placeholder='path to "last-state" state folder to resume from',
|
||||||
|
)
|
||||||
|
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||||
|
resume_button.click(get_folder_path, outputs=resume)
|
||||||
max_train_epochs = gr.Textbox(
|
max_train_epochs = gr.Textbox(
|
||||||
label='Max train epoch',
|
label='Max train epoch',
|
||||||
placeholder='(Optional) Override number of epoch',
|
placeholder='(Optional) Override number of epoch',
|
||||||
@ -330,6 +352,15 @@ def gradio_advanced_training():
|
|||||||
placeholder='(Optional) Override number of epoch. Default: 8',
|
placeholder='(Optional) Override number of epoch. Default: 8',
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
|
full_fp16,
|
||||||
|
gradient_checkpointing,
|
||||||
|
shuffle_caption,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
clip_skip,
|
||||||
|
mem_eff_attn,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
@ -343,16 +374,41 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
|
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
|
||||||
if kwargs.get('max_train_epochs')
|
if kwargs.get('max_train_epochs')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
|
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
|
||||||
if kwargs.get('max_data_loader_n_workers')
|
if kwargs.get('max_data_loader_n_workers')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --max_token_length={kwargs.get("max_token_length", "")}'
|
f' --max_token_length={kwargs.get("max_token_length", "")}'
|
||||||
if int(kwargs.get('max_token_length', 0)) > 75
|
if int(kwargs.get('max_token_length', 75)) > 75
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
|
f' --clip_skip={kwargs.get("clip_skip", "")}'
|
||||||
|
if int(kwargs.get('clip_skip', 1)) > 1
|
||||||
|
else '',
|
||||||
|
|
||||||
f' --resume="{kwargs.get("resume", "")}"'
|
f' --resume="{kwargs.get("resume", "")}"'
|
||||||
if kwargs.get('resume')
|
if kwargs.get('resume')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
' --save_state' if kwargs.get('save_state') else '',
|
' --save_state' if kwargs.get('save_state') else '',
|
||||||
|
|
||||||
|
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
|
||||||
|
|
||||||
|
' --color_aug' if kwargs.get('color_aug') else '',
|
||||||
|
|
||||||
|
' --flip_aug' if kwargs.get('flip_aug') else '',
|
||||||
|
|
||||||
|
' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
|
||||||
|
|
||||||
|
' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '',
|
||||||
|
|
||||||
|
' --full_fp16' if kwargs.get('full_fp16') else '',
|
||||||
|
|
||||||
|
' --xformers' if kwargs.get('xformers') else '',
|
||||||
|
|
||||||
|
' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
|
||||||
|
|
||||||
]
|
]
|
||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
|
71
lora_gui.py
71
lora_gui.py
@ -340,26 +340,8 @@ def train_model(
|
|||||||
run_cmd += ' --cache_latents'
|
run_cmd += ' --cache_latents'
|
||||||
if enable_bucket:
|
if enable_bucket:
|
||||||
run_cmd += ' --enable_bucket'
|
run_cmd += ' --enable_bucket'
|
||||||
if gradient_checkpointing:
|
|
||||||
run_cmd += ' --gradient_checkpointing'
|
|
||||||
if full_fp16:
|
|
||||||
run_cmd += ' --full_fp16'
|
|
||||||
if no_token_padding:
|
if no_token_padding:
|
||||||
run_cmd += ' --no_token_padding'
|
run_cmd += ' --no_token_padding'
|
||||||
if use_8bit_adam:
|
|
||||||
run_cmd += ' --use_8bit_adam'
|
|
||||||
if xformers:
|
|
||||||
run_cmd += ' --xformers'
|
|
||||||
if shuffle_caption:
|
|
||||||
run_cmd += ' --shuffle_caption'
|
|
||||||
# if save_state:
|
|
||||||
# run_cmd += ' --save_state'
|
|
||||||
if color_aug:
|
|
||||||
run_cmd += ' --color_aug'
|
|
||||||
if flip_aug:
|
|
||||||
run_cmd += ' --flip_aug'
|
|
||||||
if mem_eff_attn:
|
|
||||||
run_cmd += ' --mem_eff_attn'
|
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
||||||
)
|
)
|
||||||
@ -408,8 +390,6 @@ def train_model(
|
|||||||
run_cmd += f' --network_dim={network_dim}'
|
run_cmd += f' --network_dim={network_dim}'
|
||||||
if not lora_network_weights == '':
|
if not lora_network_weights == '':
|
||||||
run_cmd += f' --network_weights="{lora_network_weights}"'
|
run_cmd += f' --network_weights="{lora_network_weights}"'
|
||||||
if int(clip_skip) > 1:
|
|
||||||
run_cmd += f' --clip_skip={str(clip_skip)}'
|
|
||||||
if int(gradient_accumulation_steps) > 1:
|
if int(gradient_accumulation_steps) > 1:
|
||||||
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
||||||
# if not vae == '':
|
# if not vae == '':
|
||||||
@ -424,6 +404,15 @@ def train_model(
|
|||||||
max_token_length=max_token_length,
|
max_token_length=max_token_length,
|
||||||
resume=resume,
|
resume=resume,
|
||||||
save_state=save_state,
|
save_state=save_state,
|
||||||
|
mem_eff_attn=mem_eff_attn,
|
||||||
|
clip_skip=clip_skip,
|
||||||
|
flip_aug=flip_aug,
|
||||||
|
color_aug=color_aug,
|
||||||
|
shuffle_caption=shuffle_caption,
|
||||||
|
gradient_checkpointing=gradient_checkpointing,
|
||||||
|
full_fp16=full_fp16,
|
||||||
|
xformers=xformers,
|
||||||
|
use_8bit_adam=use_8bit_adam,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -729,53 +718,39 @@ def lora_tab(
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
||||||
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
|
||||||
with gr.Accordion('Advanced Configuration', open=False):
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
full_fp16 = gr.Checkbox(
|
|
||||||
label='Full fp16 training (experimental)', value=False
|
|
||||||
)
|
|
||||||
no_token_padding = gr.Checkbox(
|
no_token_padding = gr.Checkbox(
|
||||||
label='No token padding', value=False
|
label='No token padding', value=False
|
||||||
)
|
)
|
||||||
|
|
||||||
gradient_checkpointing = gr.Checkbox(
|
|
||||||
label='Gradient checkpointing', value=False
|
|
||||||
)
|
|
||||||
gradient_accumulation_steps = gr.Number(
|
gradient_accumulation_steps = gr.Number(
|
||||||
label='Gradient accumulate steps', value='1'
|
label='Gradient accumulate steps', value='1'
|
||||||
)
|
)
|
||||||
|
|
||||||
shuffle_caption = gr.Checkbox(
|
|
||||||
label='Shuffle caption', value=False
|
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prior_loss_weight = gr.Number(
|
prior_loss_weight = gr.Number(
|
||||||
label='Prior loss weight', value=1.0
|
label='Prior loss weight', value=1.0
|
||||||
)
|
)
|
||||||
color_aug = gr.Checkbox(
|
|
||||||
label='Color augmentation', value=False
|
|
||||||
)
|
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
|
||||||
color_aug.change(
|
|
||||||
color_aug_changed,
|
|
||||||
inputs=[color_aug],
|
|
||||||
outputs=[cache_latent],
|
|
||||||
)
|
|
||||||
clip_skip = gr.Slider(
|
|
||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
|
||||||
)
|
|
||||||
mem_eff_attn = gr.Checkbox(
|
|
||||||
label='Memory efficient attention', value=False
|
|
||||||
)
|
|
||||||
(
|
(
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
|
full_fp16,
|
||||||
|
gradient_checkpointing,
|
||||||
|
shuffle_caption,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
clip_skip,
|
||||||
|
mem_eff_attn,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
|
color_aug.change(
|
||||||
|
color_aug_changed,
|
||||||
|
inputs=[color_aug],
|
||||||
|
outputs=[cache_latent],
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
|
Loading…
Reference in New Issue
Block a user