diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 073f23d..4af433e 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -28,8 +28,9 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, is_valid_config, show_message_box, + check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, ) +from library.common_utilities import is_valid_config from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) @@ -228,7 +229,8 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path, filedialog_type="json") + print(f"File path: {file_path}") + file_path = get_file_path_gradio_wrapper(file_path) if not file_path == '' and file_path is not None: with open(file_path, 'r') as f: @@ -836,15 +838,15 @@ def dreambooth_tab( ] button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(*args), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config_file_name] + settings_list, + lambda *args, **kwargs: open_configuration(*args), + inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) diff --git a/finetune_gui.py b/finetune_gui.py index b085928..18c77f9 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -20,7 +20,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, + check_if_model_exist, get_file_path_gradio_wrapper, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -231,9 +231,9 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path) + file_path = get_file_path_gradio_wrapper(file_path) - if not file_path == '' and not file_path == None: + if not file_path == '' and file_path is not None: # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) @@ -799,14 +799,14 @@ def finetune_tab(): button_run.click(train_model, inputs=settings_list) button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, diff --git a/library/common_gui.py b/library/common_gui.py index 857da9b..4f7fb93 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -89,23 +89,6 @@ def check_if_model_exist(output_name, output_dir, save_model_as): return False - -def is_valid_config(data): - # Check if the data is a dictionary - if not isinstance(data, dict): - return False - - # Add checks for expected keys and valid values - # For example, check if 'use_8bit_adam' is a boolean - if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): - return False - - # Add more checks for other keys as needed - - # If all checks pass, return True - return True - - def update_my_data(my_data): # Update the optimizer based on the use_8bit_adam flag use_8bit_adam = my_data.get('use_8bit_adam', False) @@ -155,6 +138,24 @@ def update_my_data(my_data): # # If no extension files were found, return False # return False +def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): + file_extension = os.path.splitext(file_path)[-1].lower() + + filetype_filters = { + 'db': ['.db'], + 'json': ['.json'], + 'lora': ['.pt', '.ckpt', '.safetensors'], + } + + # Find the appropriate filedialog_type based on the file extension + filedialog_type = 'all' + for key, extensions in filetype_filters.items(): + if file_extension in extensions: + filedialog_type = key + break + + return get_file_path(file_path, filedialog_type) + def get_file_path(file_path='', filedialog_type="lora"): current_file_path = file_path diff --git a/library/common_utilities.py b/library/common_utilities.py new file mode 100644 index 0000000..ea1979c --- /dev/null +++ b/library/common_utilities.py @@ -0,0 +1,14 @@ +def is_valid_config(data): + # Check if the data is a dictionary + if not isinstance(data, dict): + return False + + # Add checks for expected keys and valid values + # For example, check if 'use_8bit_adam' is a boolean + if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): + return False + + # Add more checks for other keys as needed + + # If all checks pass, return True + return True diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index dd43b0e..2450eb5 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -4,22 +4,22 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, get_file_path +from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def convert_model( - source_model_input, - source_model_type, - target_model_folder_input, - target_model_name_input, - target_model_type, - target_save_precision_type, + source_model_input, + source_model_type, + target_model_folder_input, + target_model_name_input, + target_model_type, + target_save_precision_type, ): # Check for caption_text_input if source_model_type == '': @@ -61,8 +61,8 @@ def convert_model( run_cmd += f' --{target_save_precision_type}' if ( - target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): run_cmd += f' --reference_model="{source_model_type}"' @@ -72,8 +72,8 @@ def convert_model( run_cmd += f' "{source_model_input}"' if ( - target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): target_model_path = os.path.join( target_model_folder_input, target_model_name_input @@ -95,8 +95,8 @@ def convert_model( subprocess.run(run_cmd) if ( - not target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + not target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): v2_models = [ @@ -180,7 +180,8 @@ def gradio_convert_model_tab(): document_symbol, elem_id='open_folder_small' ) button_source_model_file.click( - get_file_path, + lambda input1, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)), inputs=[source_model_input], outputs=source_model_input, show_progress=False, diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 6bc1f30..dbad3d1 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -4,25 +4,25 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def extract_lora( - model_tuned, - model_org, - save_to, - save_precision, - dim, - v2, - conv_dim, - device, + model_tuned, + model_org, + save_to, + save_precision, + dim, + v2, + conv_dim, + device, ): # Check for caption_text_input if model_tuned == '': @@ -43,7 +43,7 @@ def extract_lora( return run_cmd = ( - f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"' + f'{PYTHON} "{os.path.join("networks", "extract_lora_from_models.py")}"' ) run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_to "{save_to}"' @@ -90,7 +90,8 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_tuned_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[model_tuned, model_ext, model_ext_name], outputs=model_tuned, show_progress=False, @@ -105,7 +106,8 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_org_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[model_org, model_ext, model_ext_name], outputs=model_org, show_progress=False, diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index f95bf96..491e8ac 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # 📂 @@ -136,7 +136,8 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_db_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[db_model, model_ext, model_ext_name], outputs=db_model, show_progress=False, @@ -151,7 +152,8 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_base_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[base_model, model_ext, model_ext_name], outputs=base_model, show_progress=False, diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index e5f6b9b..1a0edf9 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # 📂 @@ -81,7 +81,8 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -96,7 +97,8 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index a47e407..e8321d8 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_file_path, get_saveasfile_path +from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # 📂 @@ -96,7 +96,8 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[model, lora_ext, lora_ext_name], outputs=model, show_progress=False, diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index a2b040c..be127b3 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # 📂 @@ -87,7 +87,8 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -102,7 +103,8 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index a72160e..4acf101 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, + get_file_path, get_file_path_gradio_wrapper, ) PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -68,7 +68,8 @@ def gradio_verify_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_model, lora_ext, lora_ext_name], outputs=lora_model, show_progress=False, diff --git a/lora_gui.py b/lora_gui.py index 0c8129e..8990458 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -28,7 +28,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, + check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, ) from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dreambooth_folder_creation_gui import ( @@ -254,7 +254,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path) + file_path = get_file_path_gradio_wrapper(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -1031,14 +1031,14 @@ def lora_tab( ] button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, ) button_load_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 5c82818..f434494 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -28,7 +28,7 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, + check_if_model_exist, get_file_path_gradio_wrapper, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -240,7 +240,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path) + file_path = get_file_path_gradio_wrapper(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -673,7 +673,7 @@ def ti_tab( ) weights_file_input = gr.Button('📂', elem_id='open_folder_small') weights_file_input.click( - get_file_path, + lambda *args, **kwargs: get_file_path_gradio_wrapper, outputs=weights, show_progress=False, ) @@ -899,14 +899,14 @@ def ti_tab( ] button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False,