From a65555ea67c8e1519977cb91bfd9ba648350ee51 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Fri, 10 Mar 2023 20:05:38 -0500 Subject: [PATCH] Add support to load a config without opening the UI to get the file name --- dreambooth_gui.py | 19 ++++++++++++++++--- finetune_gui.py | 31 ++++++++++++++++++++++--------- library/common_gui.py | 3 +++ lora_gui.py | 19 ++++++++++++++++--- textual_inversion_gui.py | 19 ++++++++++++++++--- 5 files changed, 73 insertions(+), 18 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index df40784..dee017c 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -152,6 +152,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -213,9 +214,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -231,7 +236,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -506,6 +511,7 @@ def dreambooth_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -775,7 +781,14 @@ def dreambooth_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) diff --git a/finetune_gui.py b/finetune_gui.py index 3ef1cbd..59dffd8 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -149,7 +149,8 @@ def save_configuration( return file_path -def open_config_file( +def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -217,9 +218,13 @@ def open_config_file( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -235,7 +240,7 @@ def open_config_file( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -492,8 +497,8 @@ def remove_doublequote(file_path): def finetune_tab(): - dummy_ft_true = gr.Label(value=True, visible=False) - dummy_ft_false = gr.Label(value=False, visible=False) + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) gr.Markdown('Train a custom model using kohya finetune python code...') ( @@ -501,6 +506,7 @@ def finetune_tab(): button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -770,22 +776,29 @@ def finetune_tab(): button_run.click(train_model, inputs=settings_list) button_open_config.click( - open_config_file, - inputs=[config_file_name] + settings_list, + open_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_save_config.click( save_configuration, - inputs=[dummy_ft_false, config_file_name] + settings_list, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name], show_progress=False, ) button_save_as_config.click( save_configuration, - inputs=[dummy_ft_true, config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name], show_progress=False, ) diff --git a/library/common_gui.py b/library/common_gui.py index b22594f..e200141 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -405,11 +405,14 @@ def gradio_config(): placeholder="type the configuration file path or use the 'Open' button above to select it...", interactive=True, ) + button_load_config = gr.Button('Load 💾', elem_id='open_folder') + config_file_name.change(remove_doublequote, inputs=[config_file_name], outputs=[config_file_name]) return ( button_open_config, button_save_config, button_save_as_config, config_file_name, + button_load_config, ) diff --git a/lora_gui.py b/lora_gui.py index 49918de..23da712 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -168,6 +168,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -239,9 +240,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -257,7 +262,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) # This next section is about making the LoCon parameters visible if LoRA_type = 'Standard' @@ -610,6 +615,7 @@ def lora_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -974,7 +980,14 @@ def lora_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list + [LoCon_row], + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, ) diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index c92bdc0..ed3c33a 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -158,6 +158,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -225,9 +226,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -243,7 +248,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -548,6 +553,7 @@ def ti_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -865,7 +871,14 @@ def ti_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, )