From 402cb51ec046c0457a7c1e45136037e7fea6e96c Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 9 Jan 2023 08:08:47 -0500 Subject: [PATCH] refactor Dreambooth gui code --- dreambooth_gui.py | 196 +++++++++++++++++++++++----------------------- 1 file changed, 98 insertions(+), 98 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 75e6868..94352de 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -536,10 +536,10 @@ def UI(username, password): def dreambooth_tab( - train_data_dir_input=gr.Textbox(), - reg_data_dir_input=gr.Textbox(), - output_dir_input=gr.Textbox(), - logging_dir_input=gr.Textbox(), + train_data_dir=gr.Textbox(), + reg_data_dir=gr.Textbox(), + output_dir=gr.Textbox(), + logging_dir=gr.Textbox(), ): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) @@ -564,24 +564,24 @@ def dreambooth_tab( with gr.Tab('Source model'): # Define the input elements with gr.Row(): - pretrained_model_name_or_path_input = gr.Textbox( + pretrained_model_name_or_path = gr.Textbox( label='Pretrained model name or path', placeholder='enter the path to custom model or name of pretrained model', ) - pretrained_model_name_or_path_fille = gr.Button( + pretrained_model_name_or_path_file = gr.Button( document_symbol, elem_id='open_folder_small' ) - pretrained_model_name_or_path_fille.click( + pretrained_model_name_or_path_file.click( get_any_file_path, - inputs=[pretrained_model_name_or_path_input], - outputs=pretrained_model_name_or_path_input, + inputs=[pretrained_model_name_or_path], + outputs=pretrained_model_name_or_path, ) pretrained_model_name_or_path_folder = gr.Button( folder_symbol, elem_id='open_folder_small' ) pretrained_model_name_or_path_folder.click( get_folder_path, - outputs=pretrained_model_name_or_path_input, + outputs=pretrained_model_name_or_path, ) model_list = gr.Dropdown( label='(Optional) Model Quick Pick', @@ -595,7 +595,7 @@ def dreambooth_tab( 'CompVis/stable-diffusion-v1-4', ], ) - save_model_as_dropdown = gr.Dropdown( + save_model_as = gr.Dropdown( label='Save trained model as', choices=[ 'same as source model', @@ -607,28 +607,28 @@ def dreambooth_tab( value='same as source model', ) with gr.Row(): - v2_input = gr.Checkbox(label='v2', value=True) - v_parameterization_input = gr.Checkbox( + v2 = gr.Checkbox(label='v2', value=True) + v_parameterization = gr.Checkbox( label='v_parameterization', value=False ) - pretrained_model_name_or_path_input.change( + pretrained_model_name_or_path.change( remove_doublequote, - inputs=[pretrained_model_name_or_path_input], - outputs=[pretrained_model_name_or_path_input], + inputs=[pretrained_model_name_or_path], + outputs=[pretrained_model_name_or_path], ) model_list.change( set_pretrained_model_name_or_path_input, - inputs=[model_list, v2_input, v_parameterization_input], + inputs=[model_list, v2, v_parameterization], outputs=[ - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, + pretrained_model_name_or_path, + v2, + v_parameterization, ], ) with gr.Tab('Folders'): with gr.Row(): - train_data_dir_input = gr.Textbox( + train_data_dir = gr.Textbox( label='Image folder', placeholder='Folder where the training folders containing the images are located', ) @@ -636,9 +636,9 @@ def dreambooth_tab( '📂', elem_id='open_folder_small' ) train_data_dir_input_folder.click( - get_folder_path, outputs=train_data_dir_input + get_folder_path, outputs=train_data_dir ) - reg_data_dir_input = gr.Textbox( + reg_data_dir = gr.Textbox( label='Regularisation folder', placeholder='(Optional) Folder where where the regularization folders containing the images are located', ) @@ -646,20 +646,20 @@ def dreambooth_tab( '📂', elem_id='open_folder_small' ) reg_data_dir_input_folder.click( - get_folder_path, outputs=reg_data_dir_input + get_folder_path, outputs=reg_data_dir ) with gr.Row(): - output_dir_input = gr.Textbox( - label='Output folder', + output_dir = gr.Textbox( + label='Model output folder', placeholder='Folder to output trained model', ) output_dir_input_folder = gr.Button( '📂', elem_id='open_folder_small' ) output_dir_input_folder.click( - get_folder_path, outputs=output_dir_input + get_folder_path, outputs=output_dir ) - logging_dir_input = gr.Textbox( + logging_dir = gr.Textbox( label='Logging folder', placeholder='Optional: enable logging and output TensorBoard log to this folder', ) @@ -667,32 +667,32 @@ def dreambooth_tab( '📂', elem_id='open_folder_small' ) logging_dir_input_folder.click( - get_folder_path, outputs=logging_dir_input + get_folder_path, outputs=logging_dir ) - train_data_dir_input.change( + train_data_dir.change( remove_doublequote, - inputs=[train_data_dir_input], - outputs=[train_data_dir_input], + inputs=[train_data_dir], + outputs=[train_data_dir], ) - reg_data_dir_input.change( + reg_data_dir.change( remove_doublequote, - inputs=[reg_data_dir_input], - outputs=[reg_data_dir_input], + inputs=[reg_data_dir], + outputs=[reg_data_dir], ) - output_dir_input.change( + output_dir.change( remove_doublequote, - inputs=[output_dir_input], - outputs=[output_dir_input], + inputs=[output_dir], + outputs=[output_dir], ) - logging_dir_input.change( + logging_dir.change( remove_doublequote, - inputs=[logging_dir_input], - outputs=[logging_dir_input], + inputs=[logging_dir], + outputs=[logging_dir], ) with gr.Tab('Training parameters'): with gr.Row(): - learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6) - lr_scheduler_input = gr.Dropdown( + learning_rate = gr.Textbox(label='Learning rate', value=1e-6) + lr_scheduler = gr.Dropdown( label='LR Scheduler', choices=[ 'constant', @@ -704,21 +704,21 @@ def dreambooth_tab( ], value='constant', ) - lr_warmup_input = gr.Textbox(label='LR warmup', value=0) + lr_warmup = gr.Textbox(label='LR warmup', value=0) with gr.Row(): - train_batch_size_input = gr.Slider( + train_batch_size = gr.Slider( minimum=1, maximum=32, label='Train batch size', value=1, step=1, ) - epoch_input = gr.Textbox(label='Epoch', value=1) - save_every_n_epochs_input = gr.Textbox( + epoch = gr.Textbox(label='Epoch', value=1) + save_every_n_epochs = gr.Textbox( label='Save every N epochs', value=1 ) with gr.Row(): - mixed_precision_input = gr.Dropdown( + mixed_precision = gr.Dropdown( label='Mixed precision', choices=[ 'no', @@ -727,7 +727,7 @@ def dreambooth_tab( ], value='fp16', ) - save_precision_input = gr.Dropdown( + save_precision = gr.Dropdown( label='Save precision', choices=[ 'float', @@ -736,7 +736,7 @@ def dreambooth_tab( ], value='fp16', ) - num_cpu_threads_per_process_input = gr.Slider( + num_cpu_threads_per_process = gr.Slider( minimum=1, maximum=os.cpu_count(), step=1, @@ -744,18 +744,18 @@ def dreambooth_tab( value=os.cpu_count(), ) with gr.Row(): - seed_input = gr.Textbox(label='Seed', value=1234) - max_resolution_input = gr.Textbox( + seed = gr.Textbox(label='Seed', value=1234) + max_resolution = gr.Textbox( label='Max resolution', value='512,512', placeholder='512,512', ) with gr.Row(): - caption_extention_input = gr.Textbox( + caption_extention = gr.Textbox( label='Caption Extension', placeholder='(Optional) Extension for caption files. default: .caption', ) - stop_text_encoder_training_input = gr.Slider( + stop_text_encoder_training = gr.Slider( minimum=0, maximum=100, value=0, @@ -763,24 +763,24 @@ def dreambooth_tab( label='Stop text encoder training', ) with gr.Row(): - enable_bucket_input = gr.Checkbox( + enable_bucket = gr.Checkbox( label='Enable buckets', value=True ) - cache_latent_input = gr.Checkbox(label='Cache latent', value=True) - use_8bit_adam_input = gr.Checkbox( + cache_latent = gr.Checkbox(label='Cache latent', value=True) + use_8bit_adam = gr.Checkbox( label='Use 8bit adam', value=True ) - xformers_input = gr.Checkbox(label='Use xformers', value=True) + xformers = gr.Checkbox(label='Use xformers', value=True) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(): - full_fp16_input = gr.Checkbox( + full_fp16 = gr.Checkbox( label='Full fp16 training (experimental)', value=False ) - no_token_padding_input = gr.Checkbox( + no_token_padding = gr.Checkbox( label='No token padding', value=False ) - gradient_checkpointing_input = gr.Checkbox( + gradient_checkpointing = gr.Checkbox( label='Gradient checkpointing', value=False ) @@ -798,7 +798,7 @@ def dreambooth_tab( color_aug.change( color_aug_changed, inputs=[color_aug], - outputs=[cache_latent_input], + outputs=[cache_latent], ) clip_skip = gr.Slider( label='Clip skip', value='1', minimum=1, maximum=12, step=1 @@ -824,43 +824,43 @@ def dreambooth_tab( 'This section provide Dreambooth tools to help setup your dataset...' ) gradio_dreambooth_folder_creation_tab( - train_data_dir_input=train_data_dir_input, - reg_data_dir_input=reg_data_dir_input, - output_dir_input=output_dir_input, - logging_dir_input=logging_dir_input, + train_data_dir_input=train_data_dir, + reg_data_dir_input=reg_data_dir, + output_dir_input=output_dir, + logging_dir_input=logging_dir, ) button_run = gr.Button('Train model') settings_list = [ - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latent, + caption_extention, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + use_8bit_adam, + xformers, + save_model_as, shuffle_caption, save_state, resume, @@ -895,10 +895,10 @@ def dreambooth_tab( ) return ( - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - logging_dir_input, + train_data_dir, + reg_data_dir, + output_dir, + logging_dir, )