refactor Dreambooth gui code

This commit is contained in:
bmaltais 2023-01-09 08:08:47 -05:00
parent 442eb7a292
commit 402cb51ec0

View File

@ -536,10 +536,10 @@ def UI(username, password):
def dreambooth_tab(
train_data_dir_input=gr.Textbox(),
reg_data_dir_input=gr.Textbox(),
output_dir_input=gr.Textbox(),
logging_dir_input=gr.Textbox(),
train_data_dir=gr.Textbox(),
reg_data_dir=gr.Textbox(),
output_dir=gr.Textbox(),
logging_dir=gr.Textbox(),
):
dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False)
@ -564,24 +564,24 @@ def dreambooth_tab(
with gr.Tab('Source model'):
# Define the input elements
with gr.Row():
pretrained_model_name_or_path_input = gr.Textbox(
pretrained_model_name_or_path = gr.Textbox(
label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model',
)
pretrained_model_name_or_path_fille = gr.Button(
pretrained_model_name_or_path_file = gr.Button(
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_fille.click(
pretrained_model_name_or_path_file.click(
get_any_file_path,
inputs=[pretrained_model_name_or_path_input],
outputs=pretrained_model_name_or_path_input,
inputs=[pretrained_model_name_or_path],
outputs=pretrained_model_name_or_path,
)
pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_folder.click(
get_folder_path,
outputs=pretrained_model_name_or_path_input,
outputs=pretrained_model_name_or_path,
)
model_list = gr.Dropdown(
label='(Optional) Model Quick Pick',
@ -595,7 +595,7 @@ def dreambooth_tab(
'CompVis/stable-diffusion-v1-4',
],
)
save_model_as_dropdown = gr.Dropdown(
save_model_as = gr.Dropdown(
label='Save trained model as',
choices=[
'same as source model',
@ -607,28 +607,28 @@ def dreambooth_tab(
value='same as source model',
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox(
v2 = gr.Checkbox(label='v2', value=True)
v_parameterization = gr.Checkbox(
label='v_parameterization', value=False
)
pretrained_model_name_or_path_input.change(
pretrained_model_name_or_path.change(
remove_doublequote,
inputs=[pretrained_model_name_or_path_input],
outputs=[pretrained_model_name_or_path_input],
inputs=[pretrained_model_name_or_path],
outputs=[pretrained_model_name_or_path],
)
model_list.change(
set_pretrained_model_name_or_path_input,
inputs=[model_list, v2_input, v_parameterization_input],
inputs=[model_list, v2, v_parameterization],
outputs=[
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
pretrained_model_name_or_path,
v2,
v_parameterization,
],
)
with gr.Tab('Folders'):
with gr.Row():
train_data_dir_input = gr.Textbox(
train_data_dir = gr.Textbox(
label='Image folder',
placeholder='Folder where the training folders containing the images are located',
)
@ -636,9 +636,9 @@ def dreambooth_tab(
'📂', elem_id='open_folder_small'
)
train_data_dir_input_folder.click(
get_folder_path, outputs=train_data_dir_input
get_folder_path, outputs=train_data_dir
)
reg_data_dir_input = gr.Textbox(
reg_data_dir = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
)
@ -646,20 +646,20 @@ def dreambooth_tab(
'📂', elem_id='open_folder_small'
)
reg_data_dir_input_folder.click(
get_folder_path, outputs=reg_data_dir_input
get_folder_path, outputs=reg_data_dir
)
with gr.Row():
output_dir_input = gr.Textbox(
label='Output folder',
output_dir = gr.Textbox(
label='Model output folder',
placeholder='Folder to output trained model',
)
output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
output_dir_input_folder.click(
get_folder_path, outputs=output_dir_input
get_folder_path, outputs=output_dir
)
logging_dir_input = gr.Textbox(
logging_dir = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
@ -667,32 +667,32 @@ def dreambooth_tab(
'📂', elem_id='open_folder_small'
)
logging_dir_input_folder.click(
get_folder_path, outputs=logging_dir_input
get_folder_path, outputs=logging_dir
)
train_data_dir_input.change(
train_data_dir.change(
remove_doublequote,
inputs=[train_data_dir_input],
outputs=[train_data_dir_input],
inputs=[train_data_dir],
outputs=[train_data_dir],
)
reg_data_dir_input.change(
reg_data_dir.change(
remove_doublequote,
inputs=[reg_data_dir_input],
outputs=[reg_data_dir_input],
inputs=[reg_data_dir],
outputs=[reg_data_dir],
)
output_dir_input.change(
output_dir.change(
remove_doublequote,
inputs=[output_dir_input],
outputs=[output_dir_input],
inputs=[output_dir],
outputs=[output_dir],
)
logging_dir_input.change(
logging_dir.change(
remove_doublequote,
inputs=[logging_dir_input],
outputs=[logging_dir_input],
inputs=[logging_dir],
outputs=[logging_dir],
)
with gr.Tab('Training parameters'):
with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
lr_scheduler_input = gr.Dropdown(
learning_rate = gr.Textbox(label='Learning rate', value=1e-6)
lr_scheduler = gr.Dropdown(
label='LR Scheduler',
choices=[
'constant',
@ -704,21 +704,21 @@ def dreambooth_tab(
],
value='constant',
)
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
lr_warmup = gr.Textbox(label='LR warmup', value=0)
with gr.Row():
train_batch_size_input = gr.Slider(
train_batch_size = gr.Slider(
minimum=1,
maximum=32,
label='Train batch size',
value=1,
step=1,
)
epoch_input = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs_input = gr.Textbox(
epoch = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs = gr.Textbox(
label='Save every N epochs', value=1
)
with gr.Row():
mixed_precision_input = gr.Dropdown(
mixed_precision = gr.Dropdown(
label='Mixed precision',
choices=[
'no',
@ -727,7 +727,7 @@ def dreambooth_tab(
],
value='fp16',
)
save_precision_input = gr.Dropdown(
save_precision = gr.Dropdown(
label='Save precision',
choices=[
'float',
@ -736,7 +736,7 @@ def dreambooth_tab(
],
value='fp16',
)
num_cpu_threads_per_process_input = gr.Slider(
num_cpu_threads_per_process = gr.Slider(
minimum=1,
maximum=os.cpu_count(),
step=1,
@ -744,18 +744,18 @@ def dreambooth_tab(
value=os.cpu_count(),
)
with gr.Row():
seed_input = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox(
seed = gr.Textbox(label='Seed', value=1234)
max_resolution = gr.Textbox(
label='Max resolution',
value='512,512',
placeholder='512,512',
)
with gr.Row():
caption_extention_input = gr.Textbox(
caption_extention = gr.Textbox(
label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption',
)
stop_text_encoder_training_input = gr.Slider(
stop_text_encoder_training = gr.Slider(
minimum=0,
maximum=100,
value=0,
@ -763,24 +763,24 @@ def dreambooth_tab(
label='Stop text encoder training',
)
with gr.Row():
enable_bucket_input = gr.Checkbox(
enable_bucket = gr.Checkbox(
label='Enable buckets', value=True
)
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam_input = gr.Checkbox(
cache_latent = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam = gr.Checkbox(
label='Use 8bit adam', value=True
)
xformers_input = gr.Checkbox(label='Use xformers', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Accordion('Advanced Configuration', open=False):
with gr.Row():
full_fp16_input = gr.Checkbox(
full_fp16 = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
)
no_token_padding_input = gr.Checkbox(
no_token_padding = gr.Checkbox(
label='No token padding', value=False
)
gradient_checkpointing_input = gr.Checkbox(
gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False
)
@ -798,7 +798,7 @@ def dreambooth_tab(
color_aug.change(
color_aug_changed,
inputs=[color_aug],
outputs=[cache_latent_input],
outputs=[cache_latent],
)
clip_skip = gr.Slider(
label='Clip skip', value='1', minimum=1, maximum=12, step=1
@ -824,43 +824,43 @@ def dreambooth_tab(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
train_data_dir_input=train_data_dir,
reg_data_dir_input=reg_data_dir,
output_dir_input=output_dir,
logging_dir_input=logging_dir,
)
button_run = gr.Button('Train model')
settings_list = [
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
cache_latent_input,
caption_extention_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown,
pretrained_model_name_or_path,
v2,
v_parameterization,
logging_dir,
train_data_dir,
reg_data_dir,
output_dir,
max_resolution,
learning_rate,
lr_scheduler,
lr_warmup,
train_batch_size,
epoch,
save_every_n_epochs,
mixed_precision,
save_precision,
seed,
num_cpu_threads_per_process,
cache_latent,
caption_extention,
enable_bucket,
gradient_checkpointing,
full_fp16,
no_token_padding,
stop_text_encoder_training,
use_8bit_adam,
xformers,
save_model_as,
shuffle_caption,
save_state,
resume,
@ -895,10 +895,10 @@ def dreambooth_tab(
)
return (
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
train_data_dir,
reg_data_dir,
output_dir,
logging_dir,
)