refactor Dreambooth gui code

This commit is contained in:
bmaltais 2023-01-09 08:08:47 -05:00
parent 442eb7a292
commit 402cb51ec0

View File

@ -536,10 +536,10 @@ def UI(username, password):
def dreambooth_tab( def dreambooth_tab(
train_data_dir_input=gr.Textbox(), train_data_dir=gr.Textbox(),
reg_data_dir_input=gr.Textbox(), reg_data_dir=gr.Textbox(),
output_dir_input=gr.Textbox(), output_dir=gr.Textbox(),
logging_dir_input=gr.Textbox(), logging_dir=gr.Textbox(),
): ):
dummy_db_true = gr.Label(value=True, visible=False) dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False) dummy_db_false = gr.Label(value=False, visible=False)
@ -564,24 +564,24 @@ def dreambooth_tab(
with gr.Tab('Source model'): with gr.Tab('Source model'):
# Define the input elements # Define the input elements
with gr.Row(): with gr.Row():
pretrained_model_name_or_path_input = gr.Textbox( pretrained_model_name_or_path = gr.Textbox(
label='Pretrained model name or path', label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model', placeholder='enter the path to custom model or name of pretrained model',
) )
pretrained_model_name_or_path_fille = gr.Button( pretrained_model_name_or_path_file = gr.Button(
document_symbol, elem_id='open_folder_small' document_symbol, elem_id='open_folder_small'
) )
pretrained_model_name_or_path_fille.click( pretrained_model_name_or_path_file.click(
get_any_file_path, get_any_file_path,
inputs=[pretrained_model_name_or_path_input], inputs=[pretrained_model_name_or_path],
outputs=pretrained_model_name_or_path_input, outputs=pretrained_model_name_or_path,
) )
pretrained_model_name_or_path_folder = gr.Button( pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
pretrained_model_name_or_path_folder.click( pretrained_model_name_or_path_folder.click(
get_folder_path, get_folder_path,
outputs=pretrained_model_name_or_path_input, outputs=pretrained_model_name_or_path,
) )
model_list = gr.Dropdown( model_list = gr.Dropdown(
label='(Optional) Model Quick Pick', label='(Optional) Model Quick Pick',
@ -595,7 +595,7 @@ def dreambooth_tab(
'CompVis/stable-diffusion-v1-4', 'CompVis/stable-diffusion-v1-4',
], ],
) )
save_model_as_dropdown = gr.Dropdown( save_model_as = gr.Dropdown(
label='Save trained model as', label='Save trained model as',
choices=[ choices=[
'same as source model', 'same as source model',
@ -607,28 +607,28 @@ def dreambooth_tab(
value='same as source model', value='same as source model',
) )
with gr.Row(): with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True) v2 = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox( v_parameterization = gr.Checkbox(
label='v_parameterization', value=False label='v_parameterization', value=False
) )
pretrained_model_name_or_path_input.change( pretrained_model_name_or_path.change(
remove_doublequote, remove_doublequote,
inputs=[pretrained_model_name_or_path_input], inputs=[pretrained_model_name_or_path],
outputs=[pretrained_model_name_or_path_input], outputs=[pretrained_model_name_or_path],
) )
model_list.change( model_list.change(
set_pretrained_model_name_or_path_input, set_pretrained_model_name_or_path_input,
inputs=[model_list, v2_input, v_parameterization_input], inputs=[model_list, v2, v_parameterization],
outputs=[ outputs=[
pretrained_model_name_or_path_input, pretrained_model_name_or_path,
v2_input, v2,
v_parameterization_input, v_parameterization,
], ],
) )
with gr.Tab('Folders'): with gr.Tab('Folders'):
with gr.Row(): with gr.Row():
train_data_dir_input = gr.Textbox( train_data_dir = gr.Textbox(
label='Image folder', label='Image folder',
placeholder='Folder where the training folders containing the images are located', placeholder='Folder where the training folders containing the images are located',
) )
@ -636,9 +636,9 @@ def dreambooth_tab(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
) )
train_data_dir_input_folder.click( train_data_dir_input_folder.click(
get_folder_path, outputs=train_data_dir_input get_folder_path, outputs=train_data_dir
) )
reg_data_dir_input = gr.Textbox( reg_data_dir = gr.Textbox(
label='Regularisation folder', label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located', placeholder='(Optional) Folder where where the regularization folders containing the images are located',
) )
@ -646,20 +646,20 @@ def dreambooth_tab(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
) )
reg_data_dir_input_folder.click( reg_data_dir_input_folder.click(
get_folder_path, outputs=reg_data_dir_input get_folder_path, outputs=reg_data_dir
) )
with gr.Row(): with gr.Row():
output_dir_input = gr.Textbox( output_dir = gr.Textbox(
label='Output folder', label='Model output folder',
placeholder='Folder to output trained model', placeholder='Folder to output trained model',
) )
output_dir_input_folder = gr.Button( output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
) )
output_dir_input_folder.click( output_dir_input_folder.click(
get_folder_path, outputs=output_dir_input get_folder_path, outputs=output_dir
) )
logging_dir_input = gr.Textbox( logging_dir = gr.Textbox(
label='Logging folder', label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder', placeholder='Optional: enable logging and output TensorBoard log to this folder',
) )
@ -667,32 +667,32 @@ def dreambooth_tab(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
) )
logging_dir_input_folder.click( logging_dir_input_folder.click(
get_folder_path, outputs=logging_dir_input get_folder_path, outputs=logging_dir
) )
train_data_dir_input.change( train_data_dir.change(
remove_doublequote, remove_doublequote,
inputs=[train_data_dir_input], inputs=[train_data_dir],
outputs=[train_data_dir_input], outputs=[train_data_dir],
) )
reg_data_dir_input.change( reg_data_dir.change(
remove_doublequote, remove_doublequote,
inputs=[reg_data_dir_input], inputs=[reg_data_dir],
outputs=[reg_data_dir_input], outputs=[reg_data_dir],
) )
output_dir_input.change( output_dir.change(
remove_doublequote, remove_doublequote,
inputs=[output_dir_input], inputs=[output_dir],
outputs=[output_dir_input], outputs=[output_dir],
) )
logging_dir_input.change( logging_dir.change(
remove_doublequote, remove_doublequote,
inputs=[logging_dir_input], inputs=[logging_dir],
outputs=[logging_dir_input], outputs=[logging_dir],
) )
with gr.Tab('Training parameters'): with gr.Tab('Training parameters'):
with gr.Row(): with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6) learning_rate = gr.Textbox(label='Learning rate', value=1e-6)
lr_scheduler_input = gr.Dropdown( lr_scheduler = gr.Dropdown(
label='LR Scheduler', label='LR Scheduler',
choices=[ choices=[
'constant', 'constant',
@ -704,21 +704,21 @@ def dreambooth_tab(
], ],
value='constant', value='constant',
) )
lr_warmup_input = gr.Textbox(label='LR warmup', value=0) lr_warmup = gr.Textbox(label='LR warmup', value=0)
with gr.Row(): with gr.Row():
train_batch_size_input = gr.Slider( train_batch_size = gr.Slider(
minimum=1, minimum=1,
maximum=32, maximum=32,
label='Train batch size', label='Train batch size',
value=1, value=1,
step=1, step=1,
) )
epoch_input = gr.Textbox(label='Epoch', value=1) epoch = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs_input = gr.Textbox( save_every_n_epochs = gr.Textbox(
label='Save every N epochs', value=1 label='Save every N epochs', value=1
) )
with gr.Row(): with gr.Row():
mixed_precision_input = gr.Dropdown( mixed_precision = gr.Dropdown(
label='Mixed precision', label='Mixed precision',
choices=[ choices=[
'no', 'no',
@ -727,7 +727,7 @@ def dreambooth_tab(
], ],
value='fp16', value='fp16',
) )
save_precision_input = gr.Dropdown( save_precision = gr.Dropdown(
label='Save precision', label='Save precision',
choices=[ choices=[
'float', 'float',
@ -736,7 +736,7 @@ def dreambooth_tab(
], ],
value='fp16', value='fp16',
) )
num_cpu_threads_per_process_input = gr.Slider( num_cpu_threads_per_process = gr.Slider(
minimum=1, minimum=1,
maximum=os.cpu_count(), maximum=os.cpu_count(),
step=1, step=1,
@ -744,18 +744,18 @@ def dreambooth_tab(
value=os.cpu_count(), value=os.cpu_count(),
) )
with gr.Row(): with gr.Row():
seed_input = gr.Textbox(label='Seed', value=1234) seed = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox( max_resolution = gr.Textbox(
label='Max resolution', label='Max resolution',
value='512,512', value='512,512',
placeholder='512,512', placeholder='512,512',
) )
with gr.Row(): with gr.Row():
caption_extention_input = gr.Textbox( caption_extention = gr.Textbox(
label='Caption Extension', label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption', placeholder='(Optional) Extension for caption files. default: .caption',
) )
stop_text_encoder_training_input = gr.Slider( stop_text_encoder_training = gr.Slider(
minimum=0, minimum=0,
maximum=100, maximum=100,
value=0, value=0,
@ -763,24 +763,24 @@ def dreambooth_tab(
label='Stop text encoder training', label='Stop text encoder training',
) )
with gr.Row(): with gr.Row():
enable_bucket_input = gr.Checkbox( enable_bucket = gr.Checkbox(
label='Enable buckets', value=True label='Enable buckets', value=True
) )
cache_latent_input = gr.Checkbox(label='Cache latent', value=True) cache_latent = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam_input = gr.Checkbox( use_8bit_adam = gr.Checkbox(
label='Use 8bit adam', value=True label='Use 8bit adam', value=True
) )
xformers_input = gr.Checkbox(label='Use xformers', value=True) xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Accordion('Advanced Configuration', open=False): with gr.Accordion('Advanced Configuration', open=False):
with gr.Row(): with gr.Row():
full_fp16_input = gr.Checkbox( full_fp16 = gr.Checkbox(
label='Full fp16 training (experimental)', value=False label='Full fp16 training (experimental)', value=False
) )
no_token_padding_input = gr.Checkbox( no_token_padding = gr.Checkbox(
label='No token padding', value=False label='No token padding', value=False
) )
gradient_checkpointing_input = gr.Checkbox( gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False label='Gradient checkpointing', value=False
) )
@ -798,7 +798,7 @@ def dreambooth_tab(
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
inputs=[color_aug], inputs=[color_aug],
outputs=[cache_latent_input], outputs=[cache_latent],
) )
clip_skip = gr.Slider( clip_skip = gr.Slider(
label='Clip skip', value='1', minimum=1, maximum=12, step=1 label='Clip skip', value='1', minimum=1, maximum=12, step=1
@ -824,43 +824,43 @@ def dreambooth_tab(
'This section provide Dreambooth tools to help setup your dataset...' 'This section provide Dreambooth tools to help setup your dataset...'
) )
gradio_dreambooth_folder_creation_tab( gradio_dreambooth_folder_creation_tab(
train_data_dir_input=train_data_dir_input, train_data_dir_input=train_data_dir,
reg_data_dir_input=reg_data_dir_input, reg_data_dir_input=reg_data_dir,
output_dir_input=output_dir_input, output_dir_input=output_dir,
logging_dir_input=logging_dir_input, logging_dir_input=logging_dir,
) )
button_run = gr.Button('Train model') button_run = gr.Button('Train model')
settings_list = [ settings_list = [
pretrained_model_name_or_path_input, pretrained_model_name_or_path,
v2_input, v2,
v_parameterization_input, v_parameterization,
logging_dir_input, logging_dir,
train_data_dir_input, train_data_dir,
reg_data_dir_input, reg_data_dir,
output_dir_input, output_dir,
max_resolution_input, max_resolution,
learning_rate_input, learning_rate,
lr_scheduler_input, lr_scheduler,
lr_warmup_input, lr_warmup,
train_batch_size_input, train_batch_size,
epoch_input, epoch,
save_every_n_epochs_input, save_every_n_epochs,
mixed_precision_input, mixed_precision,
save_precision_input, save_precision,
seed_input, seed,
num_cpu_threads_per_process_input, num_cpu_threads_per_process,
cache_latent_input, cache_latent,
caption_extention_input, caption_extention,
enable_bucket_input, enable_bucket,
gradient_checkpointing_input, gradient_checkpointing,
full_fp16_input, full_fp16,
no_token_padding_input, no_token_padding,
stop_text_encoder_training_input, stop_text_encoder_training,
use_8bit_adam_input, use_8bit_adam,
xformers_input, xformers,
save_model_as_dropdown, save_model_as,
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
@ -895,10 +895,10 @@ def dreambooth_tab(
) )
return ( return (
train_data_dir_input, train_data_dir,
reg_data_dir_input, reg_data_dir,
output_dir_input, output_dir,
logging_dir_input, logging_dir,
) )