From b44f075f603a97209cc3f93028370f28f84829de Mon Sep 17 00:00:00 2001 From: bmaltais Date: Thu, 29 Dec 2022 14:00:02 -0500 Subject: [PATCH] Implement open and save config for LoRA --- library/common_gui.py | 15 +++ lora_gui.py | 247 +++++++++++------------------------------- setup.py | 2 +- 3 files changed, 82 insertions(+), 182 deletions(-) diff --git a/library/common_gui.py b/library/common_gui.py index 7cc6efa..ae1e647 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -20,6 +20,21 @@ def get_file_path(file_path='', defaultextension='.json'): return file_path +def get_any_file_path(file_path=''): + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + file_path = filedialog.askopenfilename() + root.destroy() + + if file_path == '': + file_path = current_file_path + + return file_path + def remove_doublequote(file_path): if file_path != None: diff --git a/lora_gui.py b/lora_gui.py index 0eaa8cc..68ff4da 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -15,6 +15,7 @@ from library.common_gui import ( get_folder_path, remove_doublequote, get_file_path, + get_any_file_path, get_saveasfile_path, ) from library.dreambooth_folder_creation_gui import ( @@ -64,7 +65,7 @@ def save_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim ): original_file_path = file_path @@ -117,6 +118,10 @@ def save_configuration( 'save_state': save_state, 'resume': resume, 'prior_loss_weight': prior_loss_weight, + 'text_encoder_lr': text_encoder_lr, + 'unet_lr': unet_lr, + 'network_train': network_train, + 'network_dim': network_dim } # Save the data to the selected file @@ -159,7 +164,7 @@ def open_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim ): original_file_path = file_path @@ -213,6 +218,10 @@ def open_configuration( my_data.get('save_state', save_state), my_data.get('resume', resume), my_data.get('prior_loss_weight', prior_loss_weight), + my_data.get('text_encoder_lr', text_encoder_lr), + my_data.get('unet_lr', unet_lr), + my_data.get('network_train', network_train), + my_data.get('network_dim', network_dim), ) @@ -548,10 +557,10 @@ def lora_tab( label='Pretrained model name or path', placeholder='enter the path to custom model or name of pretrained model', ) - pretrained_model_name_or_path_fille = gr.Button( + pretrained_model_name_or_path_file = gr.Button( document_symbol, elem_id='open_folder_small' ) - pretrained_model_name_or_path_fille.click( + pretrained_model_name_or_path_file.click( get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input, @@ -586,6 +595,19 @@ def lora_tab( ], value='same as source model', ) + with gr.Row(): + lora_network_weights = gr.Textbox( + label='LoRA network weights', + placeholder='{Optional) Path to existing LoRA network weights to resume training}', + ) + lora_network_weights_file = gr.Button( + document_symbol, elem_id='open_folder_small' + ) + lora_network_weights_file.click( + get_any_file_path, + inputs=[lora_network_weights], + outputs=lora_network_weights, + ) with gr.Row(): v2_input = gr.Checkbox(label='v2', value=True) v_parameterization_input = gr.Checkbox( @@ -812,200 +834,63 @@ def lora_tab( gradio_dataset_balancing_tab() button_run = gr.Button('Train model') + + settings_list = [ + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + logging_dir_input, + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + max_resolution_input, + learning_rate_input, + lr_scheduler_input, + lr_warmup_input, + train_batch_size_input, + epoch_input, + save_every_n_epochs_input, + mixed_precision_input, + save_precision_input, + seed_input, + num_cpu_threads_per_process_input, + cache_latent_input, + caption_extention_input, + enable_bucket_input, + gradient_checkpointing_input, + full_fp16_input, + no_token_padding_input, + stop_text_encoder_training_input, + use_8bit_adam_input, + xformers_input, + save_model_as_dropdown, + shuffle_caption, + save_state, + resume, + prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim + ] button_open_config.click( open_configuration, - inputs=[ - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - ], - outputs=[ - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - ], + inputs=[config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, ) button_save_config.click( save_configuration, - inputs=[ - dummy_db_false, - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - ], + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name], ) button_save_as_config.click( save_configuration, - inputs=[ - dummy_db_true, - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - ], + inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name], ) button_run.click( train_model, - inputs=[ - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - text_encoder_lr, unet_lr, network_train, network_dim - ], + inputs=settings_list, ) return ( diff --git a/setup.py b/setup.py index 96d88fb..8965557 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,3 @@ from setuptools import setup, find_packages -setup(name = "library", version="1.0.0", packages = find_packages()) \ No newline at end of file +setup(name = "library", version="1.0.1", packages = find_packages()) \ No newline at end of file