From 7b5639cff5f607a77f536b6ed46979a0d1220531 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Thu, 30 Mar 2023 01:40:00 -0700 Subject: [PATCH 1/4] Huge WIP This is a massive WIP and should not be trusted or used right now. However, major milestones have been crossed. Both message boxes and file dialogs are now properly subprocessed and work on macOS. I think by extension, it may work on runpod environments as well, but that remains to be tested. --- .gitignore | 241 +++++++++++- dreambooth_gui.py | 437 +++++++++++----------- library/basic_caption_gui.py | 15 +- library/blip_caption_gui.py | 13 +- library/common_gui.py | 244 ++++++------ library/convert_model_gui.py | 13 +- library/dataset_balancing_gui.py | 11 +- library/dreambooth_folder_creation_gui.py | 13 +- library/extract_lora_gui.py | 21 +- library/extract_lycoris_locon_gui.py | 21 +- library/git_caption_gui.py | 11 +- library/gui_subprocesses.py | 84 +++++ library/merge_lora_gui.py | 21 +- library/resize_lora_gui.py | 19 +- library/sampler_gui.py | 1 - library/svd_merge_lora_gui.py | 21 +- library/tensorboard_gui.py | 6 +- library/verify_lora_gui.py | 13 +- library/wd14_caption_gui.py | 15 +- lora_gui.py | 40 +- textual_inversion_gui.py | 31 +- 21 files changed, 799 insertions(+), 492 deletions(-) create mode 100644 library/gui_subprocesses.py diff --git a/.gitignore b/.gitignore index 71fe116..4fc61ee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,242 @@ -venv -__pycache__ +# Kohya_SS Specifics cudnn_windows .vscode -*.egg-info -build wd14_tagger_model .DS_Store locon gui-user.bat -gui-user.ps1 \ No newline at end of file +gui-user.ps1 +*.whl* +.idea + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser \ No newline at end of file diff --git a/dreambooth_gui.py b/dreambooth_gui.py index a72acff..6fbdbf0 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -3,13 +3,15 @@ # v3: Add new Utilities tab for Dreambooth folder preparation # v3.1: Adding captionning of images to utilities -import gradio as gr +import argparse import json import math import os -import subprocess import pathlib -import argparse +import subprocess + +import gradio as gr + from library.common_gui import ( get_folder_path, remove_doublequote, @@ -26,88 +28,87 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, + check_if_model_exist, is_valid_config, show_message_box, ) +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) from library.utilities import utilities_tab -from library.sampler_gui import sample_gradio_config, run_cmd_sample -from easygui import msgbox folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ save_style_symbol = '\U0001f4be' # ๐Ÿ’พ -document_symbol = '\U0001F4C4' # ๐Ÿ“„ +document_symbol = '\U0001F4C4' # ๐Ÿ“„ def save_configuration( - save_as, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + 
train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -134,10 +135,10 @@ def save_configuration( name: value for name, value in parameters # locals().items() if name - not in [ - 'file_path', - 'save_as', - ] + not in [ + 'file_path', + 'save_as', + ] } # Extract the destination directory from the file path @@ -155,67 +156,67 @@ def save_configuration( def open_configuration( - ask_for_file, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + 
additional_parameters, + vae_batch_size, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -225,17 +226,20 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path) + file_path = get_file_path(file_path, filedialog_type="json") - if not file_path == '' and not file_path == None: - # load variables from JSON file + if not file_path == '' and file_path is not None: with open(file_path, 'r') as f: my_data = json.load(f) - print('Loading config...') - # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True - my_data = update_my_data(my_data) + if is_valid_config(my_data): + print('Loading config...') + my_data = update_my_data(my_data) + else: + print("Invalid configuration file.") + my_data = {} + show_message_box("Invalid configuration file.") else: - file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + file_path = original_file_path my_data = {} values = [file_path] @@ -247,85 +251,85 @@ def open_configuration( def train_model( - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training_pct, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, # Keep this. Yes, it is unused here but required given the common list used - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training_pct, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. 
Yes, it is unused here but required given the common list used + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, ): if pretrained_model_name_or_path == '': - msgbox('Source model information is missing') + show_message_box('Source model information is missing') return if train_data_dir == '': - msgbox('Image folder path is missing') + show_message_box('Image folder path is missing') return if not os.path.exists(train_data_dir): - msgbox('Image folder does not exist') + show_message_box('Image folder does not exist') return if reg_data_dir != '': if not os.path.exists(reg_data_dir): - msgbox('Regularisation folder does not exist') + show_message_box('Regularisation folder does not exist') return if output_dir == '': - msgbox('Output folder path is missing') + show_message_box('Output folder path is missing') return if check_if_model_exist(output_name, output_dir, save_model_as): @@ -351,7 +355,8 @@ def train_model( try: repeats = int(folder.split('_')[0]) except ValueError: - print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m') + print('\033[33mSubfolder', folder, + 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m') continue # Count the number of images in the folder @@ -360,12 +365,12 @@ def train_model( f for f in os.listdir(os.path.join(train_data_dir, folder)) if f.endswith('.jpg') - or f.endswith('.jpeg') - or f.endswith('.png') - or f.endswith('.webp') + or f.endswith('.jpeg') + or f.endswith('.png') + or f.endswith('.webp') ] ) - + if num_images == 0: print(f'{folder} folder contain no images, skipping...') else: @@ -525,10 +530,10 @@ def train_model( def dreambooth_tab( - train_data_dir=gr.Textbox(), - reg_data_dir=gr.Textbox(), - output_dir=gr.Textbox(), - logging_dir=gr.Textbox(), + train_data_dir=gr.Textbox(), + reg_data_dir=gr.Textbox(), + output_dir=gr.Textbox(), + logging_dir=gr.Textbox(), ): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index b2d208d..d36dae1 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -1,8 +1,9 @@ -import gradio as gr -from easygui import msgbox -import subprocess -from .common_gui import get_folder_path, add_pre_postfix, find_replace import os +import subprocess + +import gradio as gr + +from .common_gui import get_folder_path, add_pre_postfix, find_replace def caption_images( @@ -17,11 +18,11 @@ def caption_images( ): # Check for images_dir if not images_dir: - msgbox('Image folder is missing...') + show_message_box('Image folder is missing...') return if not caption_ext: - msgbox('Please provide an extension for the caption files.') + show_message_box('Please provide an extension for the caption files.') return if caption_text: @@ -60,7 +61,7 @@ def caption_images( ) else: if prefix or postfix: - msgbox( + show_message_box( 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' 
) diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py index 2e0081d..35fd513 100644 --- a/library/blip_caption_gui.py +++ b/library/blip_caption_gui.py @@ -1,7 +1,8 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os +import subprocess + +import gradio as gr + from .common_gui import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -21,16 +22,16 @@ def caption_images( ): # Check for caption_text_input # if caption_text_input == "": - # msgbox("Caption text is missing...") + # show_message_box("Caption text is missing...") # return # Check for images_dir_input if train_data_dir == '': - msgbox('Image folder is missing...') + show_message_box('Image folder is missing...') return if caption_file_ext == '': - msgbox('Please provide an extension for the caption files.') + show_message_box('Please provide an extension for the caption files.') return print(f'Captioning files in {train_data_dir}...') diff --git a/library/common_gui.py b/library/common_gui.py index 25e7379..f47dfcc 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -1,14 +1,17 @@ -from tkinter import filedialog, Tk -from easygui import msgbox import os -import gradio as gr -import easygui import shutil +import subprocess +from tkinter import filedialog, Tk + +import easygui +import gradio as gr + +from library.gui_subprocesses import save_file_dialog folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ save_style_symbol = '\U0001f4be' # ๐Ÿ’พ -document_symbol = '\U0001F4C4' # ๐Ÿ“„ +document_symbol = '\U0001F4C4' # ๐Ÿ“„ # define a list of substrings to search for v2 base models V2_BASE_MODELS = [ @@ -34,6 +37,31 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_ENVIRONMENT'] +def open_file_dialog(initial_dir, initial_file, file_types="all"): + current_directory = os.path.dirname(os.path.abspath(__file__)) + + args = ["python", f"{current_directory}/gui_subprocesses.py", "file_dialog"] + if initial_dir: + args.append(initial_dir) + if initial_file: + args.append(initial_file) + if file_types: + args.append(file_types) + + file_path = subprocess.check_output(args).decode("utf-8").strip() + return file_path + + +def show_message_box(message, title=""): + current_directory = os.path.dirname(os.path.abspath(__file__)) + + args = ["python", f"{current_directory}/gui_subprocesses.py", "msgbox", message] + if title: + args.append(title) + + subprocess.run(args) + + def check_if_model_exist(output_name, output_dir, save_model_as): if save_model_as in ['diffusers', 'diffusers_safetendors']: ckpt_folder = os.path.join(output_dir, output_name) @@ -62,6 +90,22 @@ def check_if_model_exist(output_name, output_dir, save_model_as): return False +def is_valid_config(data): + # Check if the data is a dictionary + if not isinstance(data, dict): + return False + + # Add checks for expected keys and valid values + # For example, check if 'use_8bit_adam' is a boolean + if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): + return False + + # Add more checks for other keys as needed + + # If all checks pass, return True + return True + + def update_my_data(my_data): # Update the optimizer based on the use_8bit_adam flag use_8bit_adam = my_data.get('use_8bit_adam', False) @@ -87,8 +131,8 @@ def update_my_data(my_data): # Update model save choices due to changes for LoRA and TI training if ( - (my_data.get('LoRA_type') or 
my_data.get('num_vectors_per_token')) - and my_data.get('save_model_as') not in ['safetensors', 'ckpt'] + (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) + and my_data.get('save_model_as') not in ['safetensors', 'ckpt'] ): message = ( 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}' @@ -102,11 +146,6 @@ def update_my_data(my_data): return my_data -def get_dir_and_file(file_path): - dir_path, file_name = os.path.split(file_path) - return (dir_path, file_name) - - # def has_ext_files(directory, extension): # # Iterate through all the files in the directory # for file in os.listdir(directory): @@ -117,67 +156,36 @@ def get_dir_and_file(file_path): # return False -def get_file_path( - file_path='', default_extension='.json', extension_name='Config files' -): - if not any(var in os.environ for var in FILE_ENV_EXCLUSION): - current_file_path = file_path - # print(f'current file path: {current_file_path}') +def get_file_path(file_path='', filedialog_type="lora"): + current_file_path = file_path - initial_dir, initial_file = get_dir_and_file(file_path) + initial_dir, initial_file = os.path.split(file_path) + file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type) - # Create a hidden Tkinter root window - root = Tk() - root.wm_attributes('-topmost', 1) - root.withdraw() - - # Show the open file dialog and get the selected file path - file_path = filedialog.askopenfilename( - filetypes=( - (extension_name, f'*{default_extension}'), - ('All files', '*.*'), - ), - defaultextension=default_extension, - initialfile=initial_file, - initialdir=initial_dir, - ) - - # Destroy the hidden root window - root.destroy() - - # If no file is selected, use the current file path - if not file_path: - file_path = current_file_path - current_file_path = file_path - # print(f'current file path: {current_file_path}') + # If no file is selected, use the current file path + if not file_path: + file_path = current_file_path + current_file_path = file_path return file_path + def get_any_file_path(file_path=''): - if not any(var in os.environ for var in FILE_ENV_EXCLUSION): - current_file_path = file_path - # print(f'current file path: {current_file_path}') + current_file_path = file_path + # print(f'current file path: {current_file_path}') - initial_dir, initial_file = get_dir_and_file(file_path) + initial_dir, initial_file = os.path.split(file_path) + file_path = open_file_dialog(initial_dir, initial_file, "all") - root = Tk() - root.wm_attributes('-topmost', 1) - root.withdraw() - file_path = filedialog.askopenfilename( - initialdir=initial_dir, - initialfile=initial_file, - ) - root.destroy() - - if file_path == '': - file_path = current_file_path + if file_path == '': + file_path = current_file_path return file_path def remove_doublequote(file_path): - if file_path != None: + if file_path is not None: file_path = file_path.replace('"', '') return file_path @@ -196,62 +204,37 @@ def remove_doublequote(file_path): # ) -def get_folder_path(folder_path=''): - if not any(var in os.environ for var in FILE_ENV_EXCLUSION): - current_folder_path = folder_path +def get_folder_path(folder_path='', filedialog_type="directory"): + current_folder_path = folder_path - initial_dir, initial_file = get_dir_and_file(folder_path) + initial_dir, initial_file = os.path.split(folder_path) + file_path = open_file_dialog(initial_dir, initial_file, filedialog_type) - root = Tk() - root.wm_attributes('-topmost', 1) - root.withdraw() - 
folder_path = filedialog.askdirectory(initialdir=initial_dir) - root.destroy() - - if folder_path == '': - folder_path = current_folder_path + if folder_path == '': + folder_path = current_folder_path return folder_path def get_saveasfile_path( - file_path='', defaultextension='.json', extension_name='Config files' + file_path='', filedialog_type="json" ): - if not any(var in os.environ for var in FILE_ENV_EXCLUSION): - current_file_path = file_path - # print(f'current file path: {current_file_path}') + current_file_path = file_path - initial_dir, initial_file = get_dir_and_file(file_path) + initial_dir, initial_file = os.path.split(file_path) + save_file_path = save_file_dialog(initial_dir, initial_file, filedialog_type) - root = Tk() - root.wm_attributes('-topmost', 1) - root.withdraw() - save_file_path = filedialog.asksaveasfile( - filetypes=( - (f'{extension_name}', f'{defaultextension}'), - ('All files', '*'), - ), - defaultextension=defaultextension, - initialdir=initial_dir, - initialfile=initial_file, - ) - root.destroy() - - # print(save_file_path) - - if save_file_path == None: - file_path = current_file_path - else: - print(save_file_path.name) - file_path = save_file_path.name - - # print(file_path) + if save_file_path is None: + file_path = current_file_path + else: + print(save_file_path.name) + file_path = save_file_path.name return file_path def get_saveasfilename_path( - file_path='', extensions='*', extension_name='Config files' + file_path='', extensions='*', extension_name='Config files' ): if not any(var in os.environ for var in FILE_ENV_EXCLUSION): current_file_path = file_path @@ -280,10 +263,10 @@ def get_saveasfilename_path( def add_pre_postfix( - folder: str = '', - prefix: str = '', - postfix: str = '', - caption_file_ext: str = '.caption', + folder: str = '', + prefix: str = '', + postfix: str = '', + caption_file_ext: str = '.caption', ) -> None: """ Add prefix and/or postfix to the content of caption files within a folder. @@ -343,10 +326,10 @@ def has_ext_files(folder_path: str, file_extension: str) -> bool: def find_replace( - folder_path: str = '', - caption_file_ext: str = '.caption', - search_text: str = '', - replace_text: str = '', + folder_path: str = '', + caption_file_ext: str = '.caption', + search_text: str = '', + replace_text: str = '', ) -> None: """ Find and replace text in caption files within a folder. @@ -360,7 +343,7 @@ def find_replace( print('Running caption find/replace') if not has_ext_files(folder_path, caption_file_ext): - msgbox( + show_message_box( f'No files with extension {caption_file_ext} were found in {folder_path}...' ) return @@ -374,7 +357,7 @@ def find_replace( for caption_file in caption_files: with open( - os.path.join(folder_path, caption_file), 'r', errors='ignore' + os.path.join(folder_path, caption_file), 'r', errors='ignore' ) as f: content = f.read() @@ -386,7 +369,7 @@ def find_replace( def color_aug_changed(color_aug): if color_aug: - msgbox( + show_message_box( 'Disabling "Cache latent" because "Color augmentation" has been selected...' 
) return gr.Checkbox.update(value=False, interactive=False) @@ -427,7 +410,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): def set_pretrained_model_name_or_path_input( - model_list, pretrained_model_name_or_path, v2, v_parameterization + model_list, pretrained_model_name_or_path, v2, v_parameterization ): # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list if str(model_list) in V2_BASE_MODELS: @@ -452,9 +435,9 @@ def set_pretrained_model_name_or_path_input( if model_list == 'custom': if ( - str(pretrained_model_name_or_path) in V1_MODELS - or str(pretrained_model_name_or_path) in V2_BASE_MODELS - or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS + str(pretrained_model_name_or_path) in V1_MODELS + or str(pretrained_model_name_or_path) in V2_BASE_MODELS + or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS ): pretrained_model_name_or_path = '' v2 = False @@ -481,12 +464,11 @@ def set_v2_checkbox(model_list, v2, v_parameterization): def set_model_list( - model_list, - pretrained_model_name_or_path, - v2, - v_parameterization, + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, ): - if not pretrained_model_name_or_path in ALL_PRESET_MODELS: model_list = 'custom' else: @@ -529,7 +511,7 @@ def gradio_config(): def get_pretrained_model_name_or_path_file( - model_list, pretrained_model_name_or_path + model_list, pretrained_model_name_or_path ): pretrained_model_name_or_path = get_any_file_path( pretrained_model_name_or_path @@ -537,13 +519,13 @@ def get_pretrained_model_name_or_path_file( set_model_list(model_list, pretrained_model_name_or_path) -def gradio_source_model(save_model_as_choices = [ - 'same as source model', - 'ckpt', - 'diffusers', - 'diffusers_safetensors', - 'safetensors', - ]): +def gradio_source_model(save_model_as_choices=[ + 'same as source model', + 'ckpt', + 'diffusers', + 'diffusers_safetensors', + 'safetensors', +]): with gr.Tab('Source model'): # Define the input elements with gr.Row(): @@ -648,9 +630,9 @@ def gradio_source_model(save_model_as_choices = [ def gradio_training( - learning_rate_value='1e-6', - lr_scheduler_value='constant', - lr_warmup_value='0', + learning_rate_value='1e-6', + lr_scheduler_value='constant', + lr_warmup_value='0', ): with gr.Row(): train_batch_size = gr.Slider( diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index aaa39b8..dd43b0e 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -1,8 +1,9 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os import shutil +import subprocess + +import gradio as gr + from .common_gui import get_folder_path, get_file_path folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -22,7 +23,7 @@ def convert_model( ): # Check for caption_text_input if source_model_type == '': - msgbox('Invalid source model type') + show_message_box('Invalid source model type') return # Check if source model exist @@ -31,14 +32,14 @@ def convert_model( elif os.path.isdir(source_model_input): print('The provided model is a folder') else: - msgbox('The provided source model is neither a file nor a folder') + show_message_box('The provided source model is neither a file nor a folder') return # Check if source model exist if os.path.isdir(target_model_folder_input): print('The provided model folder exist') else: - msgbox('The provided target folder does not exist') + show_message_box('The provided target 
folder does not exist') return run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"' diff --git a/library/dataset_balancing_gui.py b/library/dataset_balancing_gui.py index 2e6bc98..34292f6 100644 --- a/library/dataset_balancing_gui.py +++ b/library/dataset_balancing_gui.py @@ -1,9 +1,12 @@ import os import re + import gradio as gr -from easygui import msgbox, boolbox +from easygui import boolbox + from .common_gui import get_folder_path + # def select_folder(): # # Open a file dialog to select a directory # folder = filedialog.askdirectory() @@ -16,14 +19,14 @@ def dataset_balancing(concept_repeats, folder, insecure): if not concept_repeats > 0: # Display an error message if the total number of repeats is not a valid integer - msgbox('Please enter a valid integer for the total number of repeats.') + show_message_box('Please enter a valid integer for the total number of repeats.') return concept_repeats = int(concept_repeats) # Check if folder exist if folder == '' or not os.path.isdir(folder): - msgbox('Please enter a valid folder for balancing.') + show_message_box('Please enter a valid folder for balancing.') return pattern = re.compile(r'^\d+_.+$') @@ -85,7 +88,7 @@ def dataset_balancing(concept_repeats, folder, insecure): f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...' ) - msgbox('Dataset balancing completed...') + show_message_box('Dataset balancing completed...') def warning(insecure): diff --git a/library/dreambooth_folder_creation_gui.py b/library/dreambooth_folder_creation_gui.py index b5d5ff4..53ec71a 100644 --- a/library/dreambooth_folder_creation_gui.py +++ b/library/dreambooth_folder_creation_gui.py @@ -1,8 +1,9 @@ -import gradio as gr -from easygui import diropenbox, msgbox -from .common_gui import get_folder_path -import shutil import os +import shutil + +import gradio as gr + +from .common_gui import get_folder_path def copy_info_to_Folders_tab(training_folder): @@ -39,12 +40,12 @@ def dreambooth_folder_preparation( # Check for instance prompt if util_instance_prompt_input == '': - msgbox('Instance prompt missing...') + show_message_box('Instance prompt missing...') return # Check for class prompt if util_class_prompt_input == '': - msgbox('Class prompt missing...') + show_message_box('Class prompt missing...') return # Create the training_dir path diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 5f48686..6a5df1b 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -1,11 +1,10 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os +import subprocess + +import gradio as gr + from .common_gui import ( - get_saveasfilename_path, - get_any_file_path, - get_file_path, + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -27,20 +26,20 @@ def extract_lora( ): # Check for caption_text_input if model_tuned == '': - msgbox('Invalid finetuned model file') + show_message_box('Invalid finetuned model file') return if model_org == '': - msgbox('Invalid base model file') + show_message_box('Invalid base model file') return # Check if source model exist if not os.path.isfile(model_tuned): - msgbox('The provided finetuned model is not a file') + show_message_box('The provided finetuned model is not a file') return if not os.path.isfile(model_org): - msgbox('The provided base model is not a file') + show_message_box('The provided base model is not a file') return run_cmd = ( @@ -121,7 +120,7 @@ def gradio_extract_lora_tab(): folder_symbol, 
elem_id='open_folder_small' ) button_save_to.click( - get_saveasfilename_path, + get_saveasfile_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index 13575bb..f95bf96 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -1,11 +1,10 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os +import subprocess + +import gradio as gr + from .common_gui import ( - get_saveasfilename_path, - get_any_file_path, - get_file_path, + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -36,20 +35,20 @@ def extract_lycoris_locon( ): # Check for caption_text_input if db_model == '': - msgbox('Invalid finetuned model file') + show_message_box('Invalid finetuned model file') return if base_model == '': - msgbox('Invalid base model file') + show_message_box('Invalid base model file') return # Check if source model exist if not os.path.isfile(db_model): - msgbox('The provided finetuned model is not a file') + show_message_box('The provided finetuned model is not a file') return if not os.path.isfile(base_model): - msgbox('The provided base model is not a file') + show_message_box('The provided base model is not a file') return run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"' @@ -167,7 +166,7 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_output_name.click( - get_saveasfilename_path, + get_saveasfile_path, inputs=[output_name, lora_ext, lora_ext_name], outputs=output_name, show_progress=False, diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py index 9aaf3d9..4ef87b4 100644 --- a/library/git_caption_gui.py +++ b/library/git_caption_gui.py @@ -1,7 +1,8 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os +import subprocess + +import gradio as gr + from .common_gui import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -19,11 +20,11 @@ def caption_images( ): # Check for images_dir_input if train_data_dir == '': - msgbox('Image folder is missing...') + show_message_box('Image folder is missing...') return if caption_ext == '': - msgbox('Please provide an extension for the caption files.') + show_message_box('Please provide an extension for the caption files.') return print(f'GIT captioning files in {train_data_dir}...') diff --git a/library/gui_subprocesses.py b/library/gui_subprocesses.py new file mode 100644 index 0000000..87e65fc --- /dev/null +++ b/library/gui_subprocesses.py @@ -0,0 +1,84 @@ +import sys +import tkinter as tk +from tkinter import filedialog, messagebox + + +def open_file_dialog(initial_dir=None, initial_file=None, file_types="all"): + file_type_filters = { + "all": [("All files", "*.*")], + "video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")], + "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")], + "json": [("JSON files", "*.json")], + "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], + "directory": [], + } + + if file_types in file_type_filters: + filters = file_type_filters[file_types] + else: + filters = file_type_filters["all"] + + if file_types == "directory": + return filedialog.askdirectory(initialdir=initial_dir) + else: + return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) + + +def 
save_file_dialog(initial_dir, initial_file, files_type="all"): + root = tk.Tk() + root.withdraw() + + filetypes_switch = { + "all": [("All files", "*.*")], + "video": [("Video files", "*.mp4;*.avi;*.mkv;*.webm;*.flv;*.mov;*.wmv")], + "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff;*.ico")], + "json": [("JSON files", "*.json")], + "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], + } + + filetypes = filetypes_switch.get(files_type, filetypes_switch["all"]) + save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filetypes, + defaultextension=filetypes) + + root.destroy() + + return save_file_path + + +def show_message_box(_message, _title="Message", _level="info"): + root = tk.Tk() + root.withdraw() + + message_type = { + "warning": messagebox.showwarning, + "error": messagebox.showerror, + "info": messagebox.showinfo, + "question": messagebox.askquestion, + "okcancel": messagebox.askokcancel, + "retrycancel": messagebox.askretrycancel, + "yesno": messagebox.askyesno, + "yesnocancel": messagebox.askyesnocancel + } + + if _level in message_type: + message_type[_level](_title, _message) + else: + messagebox.showinfo(_title, _message) + + root.destroy() + + +if __name__ == '__main__': + mode = sys.argv[1] + + if mode == 'file_dialog': + starting_dir = sys.argv[2] if len(sys.argv) > 2 else None + starting_file = sys.argv[3] if len(sys.argv) > 3 else None + file_class = sys.argv[2] if len(sys.argv) > 2 else None + file_path = open_file_dialog(starting_dir, starting_file, file_class) + print(file_path) + + elif mode == 'msgbox': + message = sys.argv[2] + title = sys.argv[3] if len(sys.argv) > 3 else "" + show_message_box(message, title) diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index 21cd16a..e5f6b9b 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -1,11 +1,10 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os +import subprocess + +import gradio as gr + from .common_gui import ( - get_saveasfilename_path, - get_any_file_path, - get_file_path, + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -25,20 +24,20 @@ def merge_lora( ): # Check for caption_text_input if lora_a_model == '': - msgbox('Invalid model A file') + show_message_box('Invalid model A file') return if lora_b_model == '': - msgbox('Invalid model B file') + show_message_box('Invalid model B file') return # Check if source model exist if not os.path.isfile(lora_a_model): - msgbox('The provided model A is not a file') + show_message_box('The provided model A is not a file') return if not os.path.isfile(lora_b_model): - msgbox('The provided model B is not a file') + show_message_box('The provided model B is not a file') return ratio_a = ratio @@ -122,7 +121,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfilename_path, + get_saveasfile_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index 05a7734..a47e407 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -1,8 +1,9 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os -from .common_gui import get_saveasfilename_path, get_file_path +import subprocess + +import gradio as gr + +from .common_gui import get_file_path, get_saveasfile_path PYTHON = 'python3' if os.name == 'posix' else 
'./venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -23,24 +24,24 @@ def resize_lora( ): # Check for caption_text_input if model == '': - msgbox('Invalid model file') + show_message_box('Invalid model file') return # Check if source model exist if not os.path.isfile(model): - msgbox('The provided model is not a file') + show_message_box('The provided model is not a file') return if dynamic_method == 'sv_ratio': if float(dynamic_param) < 2: - msgbox( + show_message_box( f'Dynamic parameter for {dynamic_method} need to be 2 or greater...' ) return if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative': if float(dynamic_param) < 0 or float(dynamic_param) > 1: - msgbox( + show_message_box( f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...' ) return @@ -134,7 +135,7 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfilename_path, + get_saveasfile_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/sampler_gui.py b/library/sampler_gui.py index ce95313..c1ba146 100644 --- a/library/sampler_gui.py +++ b/library/sampler_gui.py @@ -1,7 +1,6 @@ import tempfile import os import gradio as gr -from easygui import msgbox folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index 042be2e..a2b040c 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -1,11 +1,10 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os +import subprocess + +import gradio as gr + from .common_gui import ( - get_saveasfilename_path, - get_any_file_path, - get_file_path, + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -28,20 +27,20 @@ def svd_merge_lora( ): # Check for caption_text_input if lora_a_model == '': - msgbox('Invalid model A file') + show_message_box('Invalid model A file') return if lora_b_model == '': - msgbox('Invalid model B file') + show_message_box('Invalid model B file') return # Check if source model exist if not os.path.isfile(lora_a_model): - msgbox('The provided model A is not a file') + show_message_box('The provided model A is not a file') return if not os.path.isfile(lora_b_model): - msgbox('The provided model B is not a file') + show_message_box('The provided model B is not a file') return ratio_a = ratio @@ -144,7 +143,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfilename_path, + get_saveasfile_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/tensorboard_gui.py b/library/tensorboard_gui.py index d08a02d..f2dd74f 100644 --- a/library/tensorboard_gui.py +++ b/library/tensorboard_gui.py @@ -1,9 +1,9 @@ import os -import gradio as gr -from easygui import msgbox import subprocess import time +import gradio as gr + tensorboard_proc = None # I know... 
bad but heh TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe' @@ -13,7 +13,7 @@ def start_tensorboard(logging_dir): if not os.listdir(logging_dir): print('Error: log folder is empty') - msgbox(msg='Error: log folder is empty') + show_message_box(msg='Error: log folder is empty') return run_cmd = [f'{TENSORBOARD}', '--logdir', f'{logging_dir}'] diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index a7a0bf9..a72160e 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -1,10 +1,9 @@ -import gradio as gr -from easygui import msgbox -import subprocess import os +import subprocess + +import gradio as gr + from .common_gui import ( - get_saveasfilename_path, - get_any_file_path, get_file_path, ) @@ -20,12 +19,12 @@ def verify_lora( ): # verify for caption_text_input if lora_model == '': - msgbox('Invalid model A file') + show_message_box('Invalid model A file') return # verify if source model exist if not os.path.isfile(lora_model): - msgbox('The provided model A is not a file') + show_message_box('The provided model A is not a file') return run_cmd = [ diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index 1970849..a9f7742 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -1,8 +1,9 @@ -import gradio as gr -from easygui import msgbox -import subprocess -from .common_gui import get_folder_path import os +import subprocess + +import gradio as gr + +from .common_gui import get_folder_path def replace_underscore_with_space(folder_path, file_extension): @@ -20,16 +21,16 @@ def caption_images( ): # Check for caption_text_input # if caption_text_input == "": - # msgbox("Caption text is missing...") + # show_message_box("Caption text is missing...") # return # Check for images_dir_input if train_data_dir == '': - msgbox('Image folder is missing...') + show_message_box('Image folder is missing...') return if caption_extension == '': - msgbox('Please provide an extension for the caption files.') + show_message_box('Please provide an extension for the caption files.') return print(f'Captioning files in {train_data_dir}...') diff --git a/lora_gui.py b/lora_gui.py index e7ada89..819927c 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -3,14 +3,15 @@ # v3: Add new Utilities tab for Dreambooth folder preparation # v3.1: Adding captionning of images to utilities -import gradio as gr -import easygui +import argparse import json import math import os -import subprocess import pathlib -import argparse +import subprocess + +import gradio as gr + from library.common_gui import ( get_folder_path, remove_doublequote, @@ -27,24 +28,23 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, + check_if_model_exist, show_message_box, ) +from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) +from library.merge_lora_gui import gradio_merge_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample +from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab -from library.merge_lora_gui import gradio_merge_lora_tab -from library.svd_merge_lora_gui import 
gradio_svd_merge_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab -from library.resize_lora_gui import gradio_resize_lora_tab -from library.sampler_gui import sample_gradio_config, run_cmd_sample -from easygui import msgbox folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -353,35 +353,35 @@ def train_model( print_only_bool = True if print_only.get('label') == 'True' else False if pretrained_model_name_or_path == '': - msgbox('Source model information is missing') + show_message_box('Source model information is missing') return if train_data_dir == '': - msgbox('Image folder path is missing') + show_message_box('Image folder path is missing') return if not os.path.exists(train_data_dir): - msgbox('Image folder does not exist') + show_message_box('Image folder does not exist') return if reg_data_dir != '': if not os.path.exists(reg_data_dir): - msgbox('Regularisation folder does not exist') + show_message_box('Regularisation folder does not exist') return if output_dir == '': - msgbox('Output folder path is missing') + show_message_box('Output folder path is missing') return if int(bucket_reso_steps) < 1: - msgbox('Bucket resolution steps need to be greater than 0') + show_message_box('Bucket resolution steps need to be greater than 0') return if not os.path.exists(output_dir): os.makedirs(output_dir) if stop_text_encoder_training_pct > 0: - msgbox( + show_message_box( 'Output "stop text encoder training" is not yet supported. Ignoring' ) stop_text_encoder_training_pct = 0 @@ -396,7 +396,7 @@ def train_model( unet_lr = 0 # if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0): - # msgbox( + # show_message_box( # 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided' # ) # return @@ -532,7 +532,7 @@ def train_model( run_cmd += f' --network_train_unet_only' else: if float(text_encoder_lr) == 0: - msgbox('Please input learning rate values.') + show_message_box('Please input learning rate values.') return run_cmd += f' --network_dim={network_dim}' diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 7b8c19c..f48904a 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -3,13 +3,15 @@ # v3: Add new Utilities tab for Dreambooth folder preparation # v3.1: Adding captionning of images to utilities -import gradio as gr +import argparse import json import math import os -import subprocess import pathlib -import argparse +import subprocess + +import gradio as gr + from library.common_gui import ( get_folder_path, remove_doublequote, @@ -28,17 +30,16 @@ from library.common_gui import ( update_my_data, check_if_model_exist, ) +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) from library.utilities import utilities_tab -from library.sampler_gui import sample_gradio_config, run_cmd_sample -from easygui import msgbox folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -323,32 +324,32 @@ def train_model( additional_parameters,vae_batch_size, ): if pretrained_model_name_or_path == '': - msgbox('Source model information is missing') + show_message_box('Source model information is missing') return if train_data_dir == '': - msgbox('Image folder path is missing') + 
show_message_box('Image folder path is missing') return if not os.path.exists(train_data_dir): - msgbox('Image folder does not exist') + show_message_box('Image folder does not exist') return if reg_data_dir != '': if not os.path.exists(reg_data_dir): - msgbox('Regularisation folder does not exist') + show_message_box('Regularisation folder does not exist') return if output_dir == '': - msgbox('Output folder path is missing') + show_message_box('Output folder path is missing') return if token_string == '': - msgbox('Token string is missing') + show_message_box('Token string is missing') return if init_word == '': - msgbox('Init word is missing') + show_message_box('Init word is missing') return if not os.path.exists(output_dir): From e5b83df675c32d165cf1f6a3c8e536f57875c5ee Mon Sep 17 00:00:00 2001 From: JSTayco Date: Thu, 30 Mar 2023 13:13:25 -0700 Subject: [PATCH 2/4] Removed one warning dealing with get_file_path() Using lambdas now to pass in variable amount of arguments from components. This works right now with a few open windows, but saving and possibly loading will be broken right now. They need the lambda treatment next. I also split the JSON validation placeholder to library/common_utilities.py. --- dreambooth_gui.py | 12 ++++++---- finetune_gui.py | 10 ++++---- library/common_gui.py | 35 ++++++++++++++-------------- library/common_utilities.py | 14 +++++++++++ library/convert_model_gui.py | 31 ++++++++++++------------ library/extract_lora_gui.py | 28 +++++++++++----------- library/extract_lycoris_locon_gui.py | 8 ++++--- library/merge_lora_gui.py | 8 ++++--- library/resize_lora_gui.py | 5 ++-- library/svd_merge_lora_gui.py | 8 ++++--- library/verify_lora_gui.py | 5 ++-- lora_gui.py | 8 +++---- textual_inversion_gui.py | 10 ++++---- 13 files changed, 105 insertions(+), 77 deletions(-) create mode 100644 library/common_utilities.py diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 073f23d..4af433e 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -28,8 +28,9 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, is_valid_config, show_message_box, + check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, ) +from library.common_utilities import is_valid_config from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) @@ -228,7 +229,8 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path, filedialog_type="json") + print(f"File path: {file_path}") + file_path = get_file_path_gradio_wrapper(file_path) if not file_path == '' and file_path is not None: with open(file_path, 'r') as f: @@ -836,15 +838,15 @@ def dreambooth_tab( ] button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(*args), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config_file_name] + settings_list, + lambda *args, **kwargs: open_configuration(*args), + inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) diff --git a/finetune_gui.py b/finetune_gui.py index b085928..18c77f9 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -20,7 +20,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, + 
check_if_model_exist, get_file_path_gradio_wrapper, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -231,9 +231,9 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path) + file_path = get_file_path_gradio_wrapper(file_path) - if not file_path == '' and not file_path == None: + if not file_path == '' and file_path is not None: # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) @@ -799,14 +799,14 @@ def finetune_tab(): button_run.click(train_model, inputs=settings_list) button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, diff --git a/library/common_gui.py b/library/common_gui.py index 857da9b..4f7fb93 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -89,23 +89,6 @@ def check_if_model_exist(output_name, output_dir, save_model_as): return False - -def is_valid_config(data): - # Check if the data is a dictionary - if not isinstance(data, dict): - return False - - # Add checks for expected keys and valid values - # For example, check if 'use_8bit_adam' is a boolean - if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): - return False - - # Add more checks for other keys as needed - - # If all checks pass, return True - return True - - def update_my_data(my_data): # Update the optimizer based on the use_8bit_adam flag use_8bit_adam = my_data.get('use_8bit_adam', False) @@ -155,6 +138,24 @@ def update_my_data(my_data): # # If no extension files were found, return False # return False +def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): + file_extension = os.path.splitext(file_path)[-1].lower() + + filetype_filters = { + 'db': ['.db'], + 'json': ['.json'], + 'lora': ['.pt', '.ckpt', '.safetensors'], + } + + # Find the appropriate filedialog_type based on the file extension + filedialog_type = 'all' + for key, extensions in filetype_filters.items(): + if file_extension in extensions: + filedialog_type = key + break + + return get_file_path(file_path, filedialog_type) + def get_file_path(file_path='', filedialog_type="lora"): current_file_path = file_path diff --git a/library/common_utilities.py b/library/common_utilities.py new file mode 100644 index 0000000..ea1979c --- /dev/null +++ b/library/common_utilities.py @@ -0,0 +1,14 @@ +def is_valid_config(data): + # Check if the data is a dictionary + if not isinstance(data, dict): + return False + + # Add checks for expected keys and valid values + # For example, check if 'use_8bit_adam' is a boolean + if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): + return False + + # Add more checks for other keys as needed + + # If all checks pass, return True + return True diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index dd43b0e..2450eb5 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -4,22 +4,22 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, get_file_path +from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = 
'\U0001f504' # ๐Ÿ”„ save_style_symbol = '\U0001f4be' # ๐Ÿ’พ -document_symbol = '\U0001F4C4' # ๐Ÿ“„ +document_symbol = '\U0001F4C4' # ๐Ÿ“„ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def convert_model( - source_model_input, - source_model_type, - target_model_folder_input, - target_model_name_input, - target_model_type, - target_save_precision_type, + source_model_input, + source_model_type, + target_model_folder_input, + target_model_name_input, + target_model_type, + target_save_precision_type, ): # Check for caption_text_input if source_model_type == '': @@ -61,8 +61,8 @@ def convert_model( run_cmd += f' --{target_save_precision_type}' if ( - target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): run_cmd += f' --reference_model="{source_model_type}"' @@ -72,8 +72,8 @@ def convert_model( run_cmd += f' "{source_model_input}"' if ( - target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): target_model_path = os.path.join( target_model_folder_input, target_model_name_input @@ -95,8 +95,8 @@ def convert_model( subprocess.run(run_cmd) if ( - not target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + not target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): v2_models = [ @@ -180,7 +180,8 @@ def gradio_convert_model_tab(): document_symbol, elem_id='open_folder_small' ) button_source_model_file.click( - get_file_path, + lambda input1, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)), inputs=[source_model_input], outputs=source_model_input, show_progress=False, diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 6bc1f30..dbad3d1 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -4,25 +4,25 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ save_style_symbol = '\U0001f4be' # ๐Ÿ’พ -document_symbol = '\U0001F4C4' # ๐Ÿ“„ +document_symbol = '\U0001F4C4' # ๐Ÿ“„ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def extract_lora( - model_tuned, - model_org, - save_to, - save_precision, - dim, - v2, - conv_dim, - device, + model_tuned, + model_org, + save_to, + save_precision, + dim, + v2, + conv_dim, + device, ): # Check for caption_text_input if model_tuned == '': @@ -43,7 +43,7 @@ def extract_lora( return run_cmd = ( - f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"' + f'{PYTHON} "{os.path.join("networks", "extract_lora_from_models.py")}"' ) run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_to "{save_to}"' @@ -90,7 +90,8 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_tuned_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[model_tuned, model_ext, model_ext_name], outputs=model_tuned, show_progress=False, @@ -105,7 +106,8 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_org_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + 
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[model_org, model_ext, model_ext_name], outputs=model_org, show_progress=False, diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index f95bf96..491e8ac 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -136,7 +136,8 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_db_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[db_model, model_ext, model_ext_name], outputs=db_model, show_progress=False, @@ -151,7 +152,8 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_base_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[base_model, model_ext, model_ext_name], outputs=base_model, show_progress=False, diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index e5f6b9b..1a0edf9 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -81,7 +81,8 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -96,7 +97,8 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index a47e407..e8321d8 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_file_path, get_saveasfile_path +from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -96,7 +96,8 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[model, lora_ext, lora_ext_name], outputs=model, show_progress=False, diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index a2b040c..be127b3 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, 
get_saveasfile_path, + get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -87,7 +87,8 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -102,7 +103,8 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index a72160e..4acf101 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr from .common_gui import ( - get_file_path, + get_file_path, get_file_path_gradio_wrapper, ) PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -68,7 +68,8 @@ def gradio_verify_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_model_file.click( - get_file_path, + lambda input1, input2, input3, *args, **kwargs: + get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), inputs=[lora_model, lora_ext, lora_ext_name], outputs=lora_model, show_progress=False, diff --git a/lora_gui.py b/lora_gui.py index 0c8129e..8990458 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -28,7 +28,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, + check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, ) from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dreambooth_folder_creation_gui import ( @@ -254,7 +254,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path) + file_path = get_file_path_gradio_wrapper(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -1031,14 +1031,14 @@ def lora_tab( ] button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, ) button_load_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 5c82818..f434494 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -28,7 +28,7 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, + check_if_model_exist, get_file_path_gradio_wrapper, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -240,7 +240,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path(file_path) + file_path = get_file_path_gradio_wrapper(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -673,7 
+673,7 @@ def ti_tab( ) weights_file_input = gr.Button('๐Ÿ“‚', elem_id='open_folder_small') weights_file_input.click( - get_file_path, + lambda *args, **kwargs: get_file_path_gradio_wrapper, outputs=weights, show_progress=False, ) @@ -899,14 +899,14 @@ def ti_tab( ] button_open_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - open_configuration, + lambda *args, **kwargs: open_configuration(), inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, From eef5becab8d3a2d76b05fc4533da8f59b2168e57 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 14:39:10 -0700 Subject: [PATCH 3/4] All sorts of broken, but I need to commit this for now so I don't lose it. WIP: Using some OOP to reduce imports and centralize some code. No need to remake Tk Windows everywhere for example. Renamed common_gui to common_gui_functions.py to make some of the new code separation more obvious. --- dreambooth_gui.py | 15 +- fine_tune.py | 10 +- finetune_gui.py | 17 +-- kohya_gui.py | 13 +- library/basic_caption_gui.py | 2 +- library/blip_caption_gui.py | 2 +- ...{common_gui.py => common_gui_functions.py} | 53 +++++--- library/common_utilities.py | 32 +++-- library/{config_util.py => config_ml_util.py} | 10 +- library/convert_model_gui.py | 5 +- library/dataset_balancing_gui.py | 2 +- library/dreambooth_folder_creation_gui.py | 2 +- library/extract_lora_gui.py | 9 +- library/extract_lycoris_locon_gui.py | 9 +- library/git_caption_gui.py | 2 +- library/gui_subprocesses.py | 128 +++++++++--------- library/merge_lora_gui.py | 10 +- library/resize_lora_gui.py | 5 +- library/svd_merge_lora_gui.py | 10 +- library/verify_lora_gui.py | 7 +- library/wd14_caption_gui.py | 2 +- lora_gui.py | 6 +- requirements_macos.txt | 2 +- textual_inversion_gui.py | 10 +- train_db - Copy.py | 16 +-- train_db.py | 13 +- train_network - Copy.py | 23 ++-- train_network.py | 23 ++-- train_textual_inversion - Copy.py | 11 +- train_textual_inversion.py | 11 +- 30 files changed, 234 insertions(+), 226 deletions(-) rename library/{common_gui.py => common_gui_functions.py} (96%) rename library/{config_util.py => config_ml_util.py} (100%) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 4af433e..e54704f 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -9,10 +9,11 @@ import math import os import pathlib import subprocess +import sys import gradio as gr -from library.common_gui import ( +from library.common_gui_functions import ( get_folder_path, remove_doublequote, get_file_path, @@ -28,9 +29,9 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, + check_if_model_exist, show_message_box, ) -from library.common_utilities import is_valid_config +from library.common_utilities import CommonUtilities from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) @@ -230,12 +231,12 @@ def open_configuration( if ask_for_file: print(f"File path: {file_path}") - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path, filedialog_type="json") if not file_path == '' and file_path is not None: with open(file_path, 'r') as f: my_data = json.load(f) - if is_valid_config(my_data): + if 
CommonUtilities.is_valid_config(my_data): print('Loading config...') my_data = update_my_data(my_data) else: @@ -838,14 +839,14 @@ def dreambooth_tab( ] button_open_config.click( - lambda *args, **kwargs: open_configuration(*args), + lambda *_args, **kwargs: open_configuration(*_args, **kwargs), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - lambda *args, **kwargs: open_configuration(*args), + lambda *args, **kwargs: open_configuration(*args, **kwargs), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, diff --git a/fine_tune.py b/fine_tune.py index 637a729..ad298fd 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,22 +5,20 @@ import argparse import gc import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight diff --git a/finetune_gui.py b/finetune_gui.py index 18c77f9..4c6e5ce 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -1,17 +1,18 @@ -import gradio as gr +import argparse import json import math import os -import subprocess import pathlib -import argparse -from library.common_gui import ( +import subprocess + +import gradio as gr + +from library.common_gui_functions import ( get_folder_path, get_file_path, get_saveasfile_path, save_inference_file, gradio_advanced_training, - run_cmd_advanced_training, gradio_training, run_cmd_advanced_training, gradio_config, @@ -20,15 +21,15 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, get_file_path_gradio_wrapper, + check_if_model_exist, ) +from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) from library.utilities import utilities_tab -from library.sampler_gui import sample_gradio_config, run_cmd_sample folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -231,7 +232,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path) if not file_path == '' and file_path is not None: # load variables from JSON file diff --git a/kohya_gui.py b/kohya_gui.py index f8e0d8c..732cd78 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -1,17 +1,18 @@ -import gradio as gr -import os import argparse +import os +from pathlib import Path + +import gradio as gr + from dreambooth_gui import dreambooth_tab from finetune_gui import finetune_tab -from textual_inversion_gui import ti_tab -from library.utilities import utilities_tab from library.extract_lora_gui import gradio_extract_lora_tab from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab from library.merge_lora_gui import gradio_merge_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab +from 
library.utilities import utilities_tab from lora_gui import lora_tab - - +from textual_inversion_gui import ti_tab def UI(**kwargs): css = '' diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index d36dae1..0fa7ba6 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, add_pre_postfix, find_replace +from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace def caption_images( diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py index 35fd513..ffe087b 100644 --- a/library/blip_caption_gui.py +++ b/library/blip_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, add_pre_postfix +from .common_gui_functions import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' diff --git a/library/common_gui.py b/library/common_gui_functions.py similarity index 96% rename from library/common_gui.py rename to library/common_gui_functions.py index 4f7fb93..77f4923 100644 --- a/library/common_gui.py +++ b/library/common_gui_functions.py @@ -1,12 +1,14 @@ import os import shutil import subprocess +from contextlib import contextmanager +import tkinter as tk from tkinter import filedialog, Tk import easygui import gradio as gr -from library.gui_subprocesses import save_file_dialog +from library.common_utilities import CommonUtilities folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -37,6 +39,16 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID'] +@contextmanager +def tk_context(): + root = tk.Tk() + root.withdraw() + try: + yield root + finally: + root.destroy() + + def open_file_dialog(initial_dir, initial_file, file_types="all"): current_directory = os.path.dirname(os.path.abspath(__file__)) @@ -89,6 +101,7 @@ def check_if_model_exist(output_name, output_dir, save_model_as): return False + def update_my_data(my_data): # Update the optimizer based on the use_8bit_adam flag use_8bit_adam = my_data.get('use_8bit_adam', False) @@ -138,28 +151,37 @@ def update_my_data(my_data): # # If no extension files were found, return False # return False -def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): +# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): +# file_extension = os.path.splitext(file_path)[-1].lower() +# +# filetype_filters = { +# 'db': ['.db'], +# 'json': ['.json'], +# 'lora': ['.pt', '.ckpt', '.safetensors'], +# } +# +# # Find the appropriate filedialog_type based on the file extension +# filedialog_type = 'all' +# for key, extensions in filetype_filters.items(): +# if file_extension in extensions: +# filedialog_type = key +# break +# +# return get_file_path(file_path, filedialog_type) + + +def get_file_path(file_path='', filedialog_type="lora"): file_extension = os.path.splitext(file_path)[-1].lower() - filetype_filters = { - 'db': ['.db'], - 'json': ['.json'], - 'lora': ['.pt', '.ckpt', '.safetensors'], - } - # Find the appropriate filedialog_type based on the file extension - filedialog_type = 'all' - for key, extensions in filetype_filters.items(): + for key, extensions in CommonUtilities.file_filters.items(): if file_extension in extensions: filedialog_type = key break - return get_file_path(file_path, filedialog_type) - - -def get_file_path(file_path='', 
filedialog_type="lora"): current_file_path = file_path + print(f"File type: {filedialog_type}") initial_dir, initial_file = os.path.split(file_path) file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type) @@ -171,7 +193,6 @@ def get_file_path(file_path='', filedialog_type="lora"): return file_path - def get_any_file_path(file_path=''): current_file_path = file_path # print(f'current file path: {current_file_path}') @@ -823,7 +844,7 @@ def gradio_advanced_training(): xformers = gr.Checkbox(label='Use xformers', value=True) color_aug = gr.Checkbox(label='Color augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False) - min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1) + min_snr_gamma = gr.Slider(label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1) with gr.Row(): bucket_no_upscale = gr.Checkbox( label="Don't upscale bucket resolution", value=True diff --git a/library/common_utilities.py b/library/common_utilities.py index ea1979c..737828b 100644 --- a/library/common_utilities.py +++ b/library/common_utilities.py @@ -1,14 +1,24 @@ -def is_valid_config(data): - # Check if the data is a dictionary - if not isinstance(data, dict): - return False +class CommonUtilities: + file_filters = { + "all": [("All files", "*.*")], + "video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")], + "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")], + "json": [("JSON files", "*.json")], + "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], + "directory": [], + } - # Add checks for expected keys and valid values - # For example, check if 'use_8bit_adam' is a boolean - if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): - return False + def is_valid_config(self, data): + # Check if the data is a dictionary + if not isinstance(data, dict): + return False - # Add more checks for other keys as needed + # Add checks for expected keys and valid values + # For example, check if 'use_8bit_adam' is a boolean + if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): + return False - # If all checks pass, return True - return True + # Add more checks for other keys as needed + + # If all checks pass, return True + return True diff --git a/library/config_util.py b/library/config_ml_util.py similarity index 100% rename from library/config_util.py rename to library/config_ml_util.py index 97bbb4a..af35896 100644 --- a/library/config_util.py +++ b/library/config_ml_util.py @@ -1,13 +1,13 @@ import argparse +import functools +import json +import random from dataclasses import ( asdict, dataclass, ) -import functools -import random -from textwrap import dedent, indent -import json from pathlib import Path +from textwrap import dedent, indent # from toolz import curry from typing import ( List, @@ -19,6 +19,7 @@ from typing import ( import toml import voluptuous +from transformers import CLIPTokenizer from voluptuous import ( Any, ExactSequence, @@ -27,7 +28,6 @@ from voluptuous import ( Required, Schema, ) -from transformers import CLIPTokenizer from . 
import train_util from .train_util import ( diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index 2450eb5..580833d 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper +from .common_gui_functions import get_folder_path, get_file_path folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -180,8 +180,7 @@ def gradio_convert_model_tab(): document_symbol, elem_id='open_folder_small' ) button_source_model_file.click( - lambda input1, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)), + lambda *args, **kwargs: get_file_path(*args), inputs=[source_model_input], outputs=source_model_input, show_progress=False, diff --git a/library/dataset_balancing_gui.py b/library/dataset_balancing_gui.py index 34292f6..f21319b 100644 --- a/library/dataset_balancing_gui.py +++ b/library/dataset_balancing_gui.py @@ -4,7 +4,7 @@ import re import gradio as gr from easygui import boolbox -from .common_gui import get_folder_path +from .common_gui_functions import get_folder_path # def select_folder(): diff --git a/library/dreambooth_folder_creation_gui.py b/library/dreambooth_folder_creation_gui.py index 53ec71a..e554930 100644 --- a/library/dreambooth_folder_creation_gui.py +++ b/library/dreambooth_folder_creation_gui.py @@ -3,7 +3,7 @@ import shutil import gradio as gr -from .common_gui import get_folder_path +from .common_gui_functions import get_folder_path def copy_info_to_Folders_tab(training_folder): diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index dbad3d1..b54e4ff 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -90,8 +90,7 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_tuned_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[model_tuned, model_ext, model_ext_name], outputs=model_tuned, show_progress=False, @@ -107,7 +106,7 @@ def gradio_extract_lora_tab(): ) button_model_org_file.click( lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[model_org, model_ext, model_ext_name], outputs=model_org, show_progress=False, diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index 491e8ac..e8a620b 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -137,7 +137,7 @@ def gradio_extract_lycoris_locon_tab(): ) button_db_model_file.click( lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), 
inputs=[db_model, model_ext, model_ext_name], outputs=db_model, show_progress=False, @@ -152,8 +152,7 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_base_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[base_model, model_ext, model_ext_name], outputs=base_model, show_progress=False, diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py index 4ef87b4..a4cc6d0 100644 --- a/library/git_caption_gui.py +++ b/library/git_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, add_pre_postfix +from .common_gui_functions import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' diff --git a/library/gui_subprocesses.py b/library/gui_subprocesses.py index 87e65fc..2cbdaf2 100644 --- a/library/gui_subprocesses.py +++ b/library/gui_subprocesses.py @@ -1,84 +1,82 @@ +import os +import pathlib import sys import tkinter as tk from tkinter import filedialog, messagebox - -def open_file_dialog(initial_dir=None, initial_file=None, file_types="all"): - file_type_filters = { - "all": [("All files", "*.*")], - "video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")], - "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")], - "json": [("JSON files", "*.json")], - "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], - "directory": [], - } - - if file_types in file_type_filters: - filters = file_type_filters[file_types] - else: - filters = file_type_filters["all"] - - if file_types == "directory": - return filedialog.askdirectory(initialdir=initial_dir) - else: - return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) +from library.common_gui_functions import tk_context +from library.common_utilities import CommonUtilities -def save_file_dialog(initial_dir, initial_file, files_type="all"): - root = tk.Tk() - root.withdraw() +class TkGui: + def __init__(self): + self.file_types = None - filetypes_switch = { - "all": [("All files", "*.*")], - "video": [("Video files", "*.mp4;*.avi;*.mkv;*.webm;*.flv;*.mov;*.wmv")], - "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff;*.ico")], - "json": [("JSON files", "*.json")], - "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], - } + def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"): + print(f"File types: {self.file_types}") + with tk_context(): + self.file_types = file_types + if self.file_types in CommonUtilities.file_filters: + filters = CommonUtilities.file_filters[self.file_types] + else: + filters = CommonUtilities.file_filters["all"] - filetypes = filetypes_switch.get(files_type, filetypes_switch["all"]) - save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filetypes, - defaultextension=filetypes) + if self.file_types == "directory": + return filedialog.askdirectory(initialdir=initial_dir) + else: + return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) - root.destroy() + def save_file_dialog(self, initial_dir, initial_file, file_types="all"): + self.file_types = file_types - return save_file_path + # Use the tk_context function with the 'with' statement + with tk_context(): + if self.file_types in 
CommonUtilities.file_filters: + filters = CommonUtilities.file_filters[self.file_types] + else: + filters = CommonUtilities.file_filters["all"] + save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, + filetypes=filters, defaultextension=".safetensors") -def show_message_box(_message, _title="Message", _level="info"): - root = tk.Tk() - root.withdraw() + return save_file_path - message_type = { - "warning": messagebox.showwarning, - "error": messagebox.showerror, - "info": messagebox.showinfo, - "question": messagebox.askquestion, - "okcancel": messagebox.askokcancel, - "retrycancel": messagebox.askretrycancel, - "yesno": messagebox.askyesno, - "yesnocancel": messagebox.askyesnocancel - } + def show_message_box(_message, _title="Message", _level="info"): + with tk_context(): + message_type = { + "warning": messagebox.showwarning, + "error": messagebox.showerror, + "info": messagebox.showinfo, + "question": messagebox.askquestion, + "okcancel": messagebox.askokcancel, + "retrycancel": messagebox.askretrycancel, + "yesno": messagebox.askyesno, + "yesnocancel": messagebox.askyesnocancel + } - if _level in message_type: - message_type[_level](_title, _message) - else: - messagebox.showinfo(_title, _message) - - root.destroy() + if _level in message_type: + message_type[_level](_title, _message) + else: + messagebox.showinfo(_title, _message) if __name__ == '__main__': - mode = sys.argv[1] + try: + mode = sys.argv[1] - if mode == 'file_dialog': - starting_dir = sys.argv[2] if len(sys.argv) > 2 else None - starting_file = sys.argv[3] if len(sys.argv) > 3 else None - file_class = sys.argv[2] if len(sys.argv) > 2 else None - file_path = open_file_dialog(starting_dir, starting_file, file_class) - print(file_path) + if mode == 'file_dialog': + starting_dir = sys.argv[2] if len(sys.argv) > 2 else None + starting_file = sys.argv[3] if len(sys.argv) > 3 else None + file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4] + gui = TkGui() + file_path = gui.open_file_dialog(starting_dir, starting_file, file_class) + print(file_path) # Make sure to print the result - elif mode == 'msgbox': - message = sys.argv[2] - title = sys.argv[3] if len(sys.argv) > 3 else "" - show_message_box(message, title) + elif mode == 'msgbox': + message = sys.argv[2] + title = sys.argv[3] if len(sys.argv) > 3 else "" + gui = TkGui() + gui.show_message_box(message, title) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index 1a0edf9..e2bd428 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -81,8 +81,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -97,8 +96,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - 
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index e8321d8..9e9c196 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper +from .common_gui_functions import get_file_path, get_saveasfile_path PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -96,8 +96,7 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[model, lora_ext, lora_ext_name], outputs=model, show_progress=False, diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index be127b3..c8a2fe6 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -87,8 +87,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -103,8 +102,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index 4acf101..52626dd 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, ) PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -68,8 +68,7 @@ def gradio_verify_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_model, lora_ext, lora_ext_name], outputs=lora_model, show_progress=False, diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index a9f7742..18103ef 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path +from .common_gui_functions import get_folder_path def replace_underscore_with_space(folder_path, file_extension): diff --git a/lora_gui.py b/lora_gui.py 
index 8990458..03b22d1 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -12,7 +12,7 @@ import subprocess import gradio as gr -from library.common_gui import ( +from library.common_gui_functions import ( get_folder_path, remove_doublequote, get_file_path, @@ -28,7 +28,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, + check_if_model_exist, show_message_box, ) from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dreambooth_folder_creation_gui import ( @@ -254,7 +254,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file diff --git a/requirements_macos.txt b/requirements_macos.txt index 4ee4eec..c8bee08 100644 --- a/requirements_macos.txt +++ b/requirements_macos.txt @@ -27,6 +27,6 @@ huggingface-hub==0.12.0; sys_platform != 'darwin' huggingface-hub==0.13.0; sys_platform == 'darwin' tensorflow==2.10.1; sys_platform != 'darwin' # For locon support -lycoris_lora==0.1.2 +lycoris_lora==0.1.4 # for kohya_ss library . \ No newline at end of file diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index f434494..f8aaa36 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -12,7 +12,7 @@ import subprocess import gradio as gr -from library.common_gui import ( +from library.common_gui_functions import ( get_folder_path, remove_doublequote, get_file_path, @@ -28,7 +28,7 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, get_file_path_gradio_wrapper, + check_if_model_exist, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -240,9 +240,9 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path) - if not file_path == '' and not file_path == None: + if not file_path == '' and file_path is not None: # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) @@ -673,7 +673,7 @@ def ti_tab( ) weights_file_input = gr.Button('๐Ÿ“‚', elem_id='open_folder_small') weights_file_input.click( - lambda *args, **kwargs: get_file_path_gradio_wrapper, + lambda *args, **kwargs: get_file_path(*args), outputs=weights, show_progress=False, ) diff --git a/train_db - Copy.py b/train_db - Copy.py index f441d5d..22ab910 100644 --- a/train_db - Copy.py +++ b/train_db - Copy.py @@ -1,29 +1,27 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc -import time import argparse +import gc import itertools import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight 
+ def train(args): train_util.verify_training_args(args) diff --git a/train_db.py b/train_db.py index b3eead9..37fa38e 100644 --- a/train_db.py +++ b/train_db.py @@ -1,28 +1,25 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc -import time import argparse +import gc import itertools import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight diff --git a/train_network - Copy.py b/train_network - Copy.py index 20ad2c4..a80deab 100644 --- a/train_network - Copy.py +++ b/train_network - Copy.py @@ -1,31 +1,30 @@ -from torch.nn.parallel import DistributedDataParallel as DDP -import importlib import argparse import gc +import importlib +import json import math import os import random import time -import json -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed from diffusers import DDPMScheduler +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -from library.train_util import ( - DreamBoothDataset, -) -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight +from library.train_util import ( + DreamBoothDataset, +) # TODO ไป–ใฎใ‚นใ‚ฏใƒชใƒ—ใƒˆใจๅ…ฑ้€šๅŒ–ใ™ใ‚‹ diff --git a/train_network.py b/train_network.py index 423649e..6def98d 100644 --- a/train_network.py +++ b/train_network.py @@ -1,31 +1,30 @@ -from torch.nn.parallel import DistributedDataParallel as DDP -import importlib import argparse import gc +import importlib +import json import math import os import random import time -import json -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed from diffusers import DDPMScheduler +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -from library.train_util import ( - DreamBoothDataset, -) -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight +from library.train_util import ( + DreamBoothDataset, +) # TODO ไป–ใฎใ‚นใ‚ฏใƒชใƒ—ใƒˆใจๅ…ฑ้€šๅŒ–ใ™ใ‚‹ diff --git a/train_textual_inversion - Copy.py b/train_textual_inversion - Copy.py index 
681bc62..f359952 100644 --- a/train_textual_inversion - Copy.py +++ b/train_textual_inversion - Copy.py @@ -1,24 +1,21 @@ -import importlib import argparse import gc import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [ diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f279370..b6f56a4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,24 +1,21 @@ -import importlib import argparse import gc import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [ From b6d3f10da787f7ea4bb6717743fff8f23ef573c8 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Sat, 1 Apr 2023 17:33:41 -0700 Subject: [PATCH 4/4] WIP File Dialog Behavior --- dreambooth_gui.py | 402 ++++++++++++++++---------------- library/common_gui_functions.py | 32 +-- library/gui_subprocesses.py | 10 +- 3 files changed, 217 insertions(+), 227 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index e54704f..c2185c1 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -50,68 +50,68 @@ document_symbol = '\U0001F4C4' # ๐Ÿ“„ def save_configuration( - save_as, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + 
logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -125,12 +125,12 @@ def save_configuration( file_path = get_saveasfile_path(file_path) else: print('Save...') - if file_path == None or file_path == '': + if file_path is None or file_path == '': file_path = get_saveasfile_path(file_path) # print(file_path) - if file_path == None or file_path == '': + if file_path is None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open action # Return the values of the variables as a dictionary @@ -159,69 +159,73 @@ def save_configuration( def open_configuration( - ask_for_file, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + 
max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): + print("open_configuration called") + print(f"locals length: {len(locals())}") + print(f"locals: {locals()}") + # Get list of function parameters and values parameters = list(locals().items()) @@ -229,9 +233,12 @@ def open_configuration( original_file_path = file_path - if ask_for_file: + if ask_for_file and file_path is not None: print(f"File path: {file_path}") - file_path = get_file_path(file_path, filedialog_type="json") + file_path, canceled = get_file_path(file_path=file_path, filedialog_type="json") + + if canceled: + return (None,) + (None,) * (len(parameters) - 2) if not file_path == '' and file_path is not None: with open(file_path, 'r') as f: @@ -252,70 +259,72 @@ def open_configuration( # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) + # Print the number of returned values + print(f"Returning: {values}") return tuple(values) def train_model( - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training_pct, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, # Keep this. Yes, it is unused here but required given the common list used - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training_pct, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. 
Yes, it is unused here but required given the common list used
+    keep_tokens,
+    persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
+    caption_dropout_every_n_epochs,
+    caption_dropout_rate,
+    optimizer,
+    optimizer_args,
+    noise_offset,
+    sample_every_n_steps,
+    sample_every_n_epochs,
+    sample_sampler,
+    sample_prompts,
+    additional_parameters,
+    vae_batch_size,
+    min_snr_gamma,
 ):
     if pretrained_model_name_or_path == '':
         show_message_box('Source model information is missing')
@@ -346,7 +355,7 @@ def train_model(
         f
         for f in os.listdir(train_data_dir)
         if os.path.isdir(os.path.join(train_data_dir, f))
-        and not f.startswith('.')
+        and not f.startswith('.')
     ]

     # Check if subfolders are present. If not let the user know and return
@@ -378,11 +387,11 @@ def train_model(
            [
                f
                for f, lower_f in (
-                    (file, file.lower())
-                    for file in os.listdir(
-                        os.path.join(train_data_dir, folder)
-                    )
-                )
+                    (file, file.lower())
+                    for file in os.listdir(
+                        os.path.join(train_data_dir, folder)
+                    )
+                )
                if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
            ]
        )
@@ -846,12 +855,13 @@ def dreambooth_tab(
     )

     button_load_config.click(
-        lambda *args, **kwargs: open_configuration(*args, **kwargs),
+        lambda *args, **kwargs: (print("Lambda called"), open_configuration(*args, **kwargs))[1],  # [1]: keep the debug print but return only the configuration values, one per output component
         inputs=[dummy_db_true, config_file_name] + settings_list,
         outputs=[config_file_name] + settings_list,
         show_progress=False,
     )
-
+    # Print the number of expected outputs
+    print(f"Number of expected outputs: {len([config_file_name] + settings_list)}")
     button_save_config.click(
         save_configuration,
         inputs=[dummy_db_false, config_file_name] + settings_list,
diff --git a/library/common_gui_functions.py b/library/common_gui_functions.py
index 77f4923..73eb18d 100644
--- a/library/common_gui_functions.py
+++ b/library/common_gui_functions.py
@@ -151,26 +151,8 @@ def update_my_data(my_data):
 #     # If no extension files were found, return False
 #     return False

-# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
-#     file_extension = os.path.splitext(file_path)[-1].lower()
-#
-#     filetype_filters = {
-#         'db': ['.db'],
-#         'json': ['.json'],
-#         'lora': ['.pt', '.ckpt', '.safetensors'],
-#     }
-#
-#     # Find the appropriate filedialog_type based on the file extension
-#     filedialog_type = 'all'
-#     for key, extensions in filetype_filters.items():
-#         if file_extension in extensions:
-#             filedialog_type = key
-#             break
-#
-#     return get_file_path(file_path, filedialog_type)
-
-def get_file_path(file_path='', filedialog_type="lora"):
+def get_file_path(file_path, initial_dir=None, initial_file=None, filedialog_type="lora"):
     file_extension = os.path.splitext(file_path)[-1].lower()

     # Find the appropriate filedialog_type based on the file extension
@@ -181,16 +163,10 @@ def get_file_path(file_path='', filedialog_type="lora"):
     current_file_path = file_path

-    print(f"File type: {filedialog_type}")
     initial_dir, initial_file = os.path.split(file_path)
-    file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type)
-
-    # If no file is selected, use the current file path
-    if not file_path:
-        file_path = current_file_path
-    current_file_path = file_path
-
-    return file_path
+    result = open_file_dialog(initial_dir=initial_dir, initial_file=initial_file, file_types=filedialog_type)
+    file_path, canceled = result[:2]
+    return file_path, canceled


 def get_any_file_path(file_path=''):
diff --git a/library/gui_subprocesses.py b/library/gui_subprocesses.py
index 2cbdaf2..2e45b8e 100644
--- a/library/gui_subprocesses.py
+++ b/library/gui_subprocesses.py
@@ -13,7 +13,6 @@ class TkGui:
         self.file_types = None

     def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
-        print(f"File types: {self.file_types}")
         with tk_context():
             self.file_types = file_types
             if self.file_types in CommonUtilities.file_filters:
@@ -22,9 +21,14 @@ class TkGui:
                 filters = CommonUtilities.file_filters["all"]

             if self.file_types == "directory":
-                return filedialog.askdirectory(initialdir=initial_dir)
+                result = filedialog.askdirectory(initialdir=initial_dir)
             else:
-                return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
+                result = filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
+
+            # Return a tuple (file_path, canceled)
+            # file_path: the selected path, or an empty value if nothing was selected
+            # canceled: True if the dialog was dismissed without a selection, False otherwise
+            return result, not result

     def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
         self.file_types = file_types
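
For reference, the (file_path, canceled) contract introduced above lets callers tell a dismissed dialog apart from an empty selection. The sketch below is not part of the patch: the helper name load_config_values and the defaults argument are assumptions used only to show how a caller might consume the patched get_file_path without wiping existing field values on cancel.

    import json

    from library.common_gui_functions import get_file_path  # module path as used in this patch

    def load_config_values(current_path, defaults):
        # Hypothetical caller; get_file_path now returns (file_path, canceled)
        # instead of a bare path.
        file_path, canceled = get_file_path(file_path=current_path, filedialog_type="json")

        if canceled or not file_path:
            # Dialog dismissed: keep the previous path and the current values untouched.
            return current_path, defaults

        with open(file_path, 'r') as f:
            data = json.load(f)

        # Fall back to the existing value for any key missing from the loaded file.
        return file_path, {key: data.get(key, value) for key, value in defaults.items()}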
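
The button_load_config.click wiring in dreambooth_gui.py relies on the loader returning exactly one value per component in [config_file_name] + settings_list; wrapping that result in another tuple (for example, together with the result of a debug print) breaks the mapping. A minimal sketch of that invariant, assuming the Gradio 3.x Blocks API and placeholder components rather than the ones defined in this repository:

    import gradio as gr

    with gr.Blocks() as demo:
        config_file_name = gr.Textbox(label='Config file')
        learning_rate = gr.Textbox(label='Learning rate')
        settings_list = [learning_rate]

        dummy_true = gr.Checkbox(value=True, visible=False)
        button_load_config = gr.Button('Load config')

        def fake_open_configuration(ask_for_file, file_path, lr):
            # Must return one value per output component, in the same order as `outputs`.
            return file_path, lr

        button_load_config.click(
            fake_open_configuration,
            inputs=[dummy_true, config_file_name] + settings_list,
            outputs=[config_file_name] + settings_list,
            show_progress=False,
        )

    # demo.launch()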