From eef5becab8d3a2d76b05fc4533da8f59b2168e57 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 14:39:10 -0700 Subject: [PATCH] All sorts of broken, but I need to commit this for now so I don't lose it. WIP: Using some OOP to reduce imports and centralize some code. No need to remake Tk Windows everywhere for example. Renamed common_gui to common_gui_functions.py to make some of the new code separation more obvious. --- dreambooth_gui.py | 15 +- fine_tune.py | 10 +- finetune_gui.py | 17 +-- kohya_gui.py | 13 +- library/basic_caption_gui.py | 2 +- library/blip_caption_gui.py | 2 +- ...{common_gui.py => common_gui_functions.py} | 53 +++++--- library/common_utilities.py | 32 +++-- library/{config_util.py => config_ml_util.py} | 10 +- library/convert_model_gui.py | 5 +- library/dataset_balancing_gui.py | 2 +- library/dreambooth_folder_creation_gui.py | 2 +- library/extract_lora_gui.py | 9 +- library/extract_lycoris_locon_gui.py | 9 +- library/git_caption_gui.py | 2 +- library/gui_subprocesses.py | 128 +++++++++--------- library/merge_lora_gui.py | 10 +- library/resize_lora_gui.py | 5 +- library/svd_merge_lora_gui.py | 10 +- library/verify_lora_gui.py | 7 +- library/wd14_caption_gui.py | 2 +- lora_gui.py | 6 +- requirements_macos.txt | 2 +- textual_inversion_gui.py | 10 +- train_db - Copy.py | 16 +-- train_db.py | 13 +- train_network - Copy.py | 23 ++-- train_network.py | 23 ++-- train_textual_inversion - Copy.py | 11 +- train_textual_inversion.py | 11 +- 30 files changed, 234 insertions(+), 226 deletions(-) rename library/{common_gui.py => common_gui_functions.py} (96%) rename library/{config_util.py => config_ml_util.py} (100%) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 4af433e..e54704f 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -9,10 +9,11 @@ import math import os import pathlib import subprocess +import sys import gradio as gr -from library.common_gui import ( +from library.common_gui_functions import ( get_folder_path, remove_doublequote, get_file_path, @@ -28,9 +29,9 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, + check_if_model_exist, show_message_box, ) -from library.common_utilities import is_valid_config +from library.common_utilities import CommonUtilities from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) @@ -230,12 +231,12 @@ def open_configuration( if ask_for_file: print(f"File path: {file_path}") - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path, filedialog_type="json") if not file_path == '' and file_path is not None: with open(file_path, 'r') as f: my_data = json.load(f) - if is_valid_config(my_data): + if CommonUtilities.is_valid_config(my_data): print('Loading config...') my_data = update_my_data(my_data) else: @@ -838,14 +839,14 @@ def dreambooth_tab( ] button_open_config.click( - lambda *args, **kwargs: open_configuration(*args), + lambda *_args, **kwargs: open_configuration(*_args, **kwargs), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - lambda *args, **kwargs: open_configuration(*args), + lambda *args, **kwargs: open_configuration(*args, **kwargs), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, diff --git a/fine_tune.py b/fine_tune.py index 637a729..ad298fd 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,22 +5,20 @@ import argparse import gc import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight diff --git a/finetune_gui.py b/finetune_gui.py index 18c77f9..4c6e5ce 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -1,17 +1,18 @@ -import gradio as gr +import argparse import json import math import os -import subprocess import pathlib -import argparse -from library.common_gui import ( +import subprocess + +import gradio as gr + +from library.common_gui_functions import ( get_folder_path, get_file_path, get_saveasfile_path, save_inference_file, gradio_advanced_training, - run_cmd_advanced_training, gradio_training, run_cmd_advanced_training, gradio_config, @@ -20,15 +21,15 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, get_file_path_gradio_wrapper, + check_if_model_exist, ) +from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) from library.utilities import utilities_tab -from library.sampler_gui import sample_gradio_config, run_cmd_sample folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -231,7 +232,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path) if not file_path == '' and file_path is not None: # load variables from JSON file diff --git a/kohya_gui.py b/kohya_gui.py index f8e0d8c..732cd78 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -1,17 +1,18 @@ -import gradio as gr -import os import argparse +import os +from pathlib import Path + +import gradio as gr + from dreambooth_gui import dreambooth_tab from finetune_gui import finetune_tab -from textual_inversion_gui import ti_tab -from library.utilities import utilities_tab from library.extract_lora_gui import gradio_extract_lora_tab from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab from library.merge_lora_gui import gradio_merge_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab +from library.utilities import utilities_tab from lora_gui import lora_tab - - +from textual_inversion_gui import ti_tab def UI(**kwargs): css = '' diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index d36dae1..0fa7ba6 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, add_pre_postfix, find_replace +from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace def caption_images( diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py index 35fd513..ffe087b 100644 --- a/library/blip_caption_gui.py +++ b/library/blip_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, add_pre_postfix +from .common_gui_functions import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' diff --git a/library/common_gui.py b/library/common_gui_functions.py similarity index 96% rename from library/common_gui.py rename to library/common_gui_functions.py index 4f7fb93..77f4923 100644 --- a/library/common_gui.py +++ b/library/common_gui_functions.py @@ -1,12 +1,14 @@ import os import shutil import subprocess +from contextlib import contextmanager +import tkinter as tk from tkinter import filedialog, Tk import easygui import gradio as gr -from library.gui_subprocesses import save_file_dialog +from library.common_utilities import CommonUtilities folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -37,6 +39,16 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID'] +@contextmanager +def tk_context(): + root = tk.Tk() + root.withdraw() + try: + yield root + finally: + root.destroy() + + def open_file_dialog(initial_dir, initial_file, file_types="all"): current_directory = os.path.dirname(os.path.abspath(__file__)) @@ -89,6 +101,7 @@ def check_if_model_exist(output_name, output_dir, save_model_as): return False + def update_my_data(my_data): # Update the optimizer based on the use_8bit_adam flag use_8bit_adam = my_data.get('use_8bit_adam', False) @@ -138,28 +151,37 @@ def update_my_data(my_data): # # If no extension files were found, return False # return False -def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): +# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): +# file_extension = os.path.splitext(file_path)[-1].lower() +# +# filetype_filters = { +# 'db': ['.db'], +# 'json': ['.json'], +# 'lora': ['.pt', '.ckpt', '.safetensors'], +# } +# +# # Find the appropriate filedialog_type based on the file extension +# filedialog_type = 'all' +# for key, extensions in filetype_filters.items(): +# if file_extension in extensions: +# filedialog_type = key +# break +# +# return get_file_path(file_path, filedialog_type) + + +def get_file_path(file_path='', filedialog_type="lora"): file_extension = os.path.splitext(file_path)[-1].lower() - filetype_filters = { - 'db': ['.db'], - 'json': ['.json'], - 'lora': ['.pt', '.ckpt', '.safetensors'], - } - # Find the appropriate filedialog_type based on the file extension - filedialog_type = 'all' - for key, extensions in filetype_filters.items(): + for key, extensions in CommonUtilities.file_filters.items(): if file_extension in extensions: filedialog_type = key break - return get_file_path(file_path, filedialog_type) - - -def get_file_path(file_path='', filedialog_type="lora"): current_file_path = file_path + print(f"File type: {filedialog_type}") initial_dir, initial_file = os.path.split(file_path) file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type) @@ -171,7 +193,6 @@ def get_file_path(file_path='', filedialog_type="lora"): return file_path - def get_any_file_path(file_path=''): current_file_path = file_path # print(f'current file path: {current_file_path}') @@ -823,7 +844,7 @@ def gradio_advanced_training(): xformers = gr.Checkbox(label='Use xformers', value=True) color_aug = gr.Checkbox(label='Color augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False) - min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1) + min_snr_gamma = gr.Slider(label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1) with gr.Row(): bucket_no_upscale = gr.Checkbox( label="Don't upscale bucket resolution", value=True diff --git a/library/common_utilities.py b/library/common_utilities.py index ea1979c..737828b 100644 --- a/library/common_utilities.py +++ b/library/common_utilities.py @@ -1,14 +1,24 @@ -def is_valid_config(data): - # Check if the data is a dictionary - if not isinstance(data, dict): - return False +class CommonUtilities: + file_filters = { + "all": [("All files", "*.*")], + "video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")], + "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")], + "json": [("JSON files", "*.json")], + "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], + "directory": [], + } - # Add checks for expected keys and valid values - # For example, check if 'use_8bit_adam' is a boolean - if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): - return False + def is_valid_config(self, data): + # Check if the data is a dictionary + if not isinstance(data, dict): + return False - # Add more checks for other keys as needed + # Add checks for expected keys and valid values + # For example, check if 'use_8bit_adam' is a boolean + if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): + return False - # If all checks pass, return True - return True + # Add more checks for other keys as needed + + # If all checks pass, return True + return True diff --git a/library/config_util.py b/library/config_ml_util.py similarity index 100% rename from library/config_util.py rename to library/config_ml_util.py index 97bbb4a..af35896 100644 --- a/library/config_util.py +++ b/library/config_ml_util.py @@ -1,13 +1,13 @@ import argparse +import functools +import json +import random from dataclasses import ( asdict, dataclass, ) -import functools -import random -from textwrap import dedent, indent -import json from pathlib import Path +from textwrap import dedent, indent # from toolz import curry from typing import ( List, @@ -19,6 +19,7 @@ from typing import ( import toml import voluptuous +from transformers import CLIPTokenizer from voluptuous import ( Any, ExactSequence, @@ -27,7 +28,6 @@ from voluptuous import ( Required, Schema, ) -from transformers import CLIPTokenizer from . import train_util from .train_util import ( diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index 2450eb5..580833d 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -4,7 +4,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper +from .common_gui_functions import get_folder_path, get_file_path folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ refresh_symbol = '\U0001f504' # ๐Ÿ”„ @@ -180,8 +180,7 @@ def gradio_convert_model_tab(): document_symbol, elem_id='open_folder_small' ) button_source_model_file.click( - lambda input1, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)), + lambda *args, **kwargs: get_file_path(*args), inputs=[source_model_input], outputs=source_model_input, show_progress=False, diff --git a/library/dataset_balancing_gui.py b/library/dataset_balancing_gui.py index 34292f6..f21319b 100644 --- a/library/dataset_balancing_gui.py +++ b/library/dataset_balancing_gui.py @@ -4,7 +4,7 @@ import re import gradio as gr from easygui import boolbox -from .common_gui import get_folder_path +from .common_gui_functions import get_folder_path # def select_folder(): diff --git a/library/dreambooth_folder_creation_gui.py b/library/dreambooth_folder_creation_gui.py index 53ec71a..e554930 100644 --- a/library/dreambooth_folder_creation_gui.py +++ b/library/dreambooth_folder_creation_gui.py @@ -3,7 +3,7 @@ import shutil import gradio as gr -from .common_gui import get_folder_path +from .common_gui_functions import get_folder_path def copy_info_to_Folders_tab(training_folder): diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index dbad3d1..b54e4ff 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -90,8 +90,7 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_tuned_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[model_tuned, model_ext, model_ext_name], outputs=model_tuned, show_progress=False, @@ -107,7 +106,7 @@ def gradio_extract_lora_tab(): ) button_model_org_file.click( lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[model_org, model_ext, model_ext_name], outputs=model_org, show_progress=False, diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index 491e8ac..e8a620b 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -137,7 +137,7 @@ def gradio_extract_lycoris_locon_tab(): ) button_db_model_file.click( lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[db_model, model_ext, model_ext_name], outputs=db_model, show_progress=False, @@ -152,8 +152,7 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_base_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[base_model, model_ext, model_ext_name], outputs=base_model, show_progress=False, diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py index 4ef87b4..a4cc6d0 100644 --- a/library/git_caption_gui.py +++ b/library/git_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path, add_pre_postfix +from .common_gui_functions import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' diff --git a/library/gui_subprocesses.py b/library/gui_subprocesses.py index 87e65fc..2cbdaf2 100644 --- a/library/gui_subprocesses.py +++ b/library/gui_subprocesses.py @@ -1,84 +1,82 @@ +import os +import pathlib import sys import tkinter as tk from tkinter import filedialog, messagebox - -def open_file_dialog(initial_dir=None, initial_file=None, file_types="all"): - file_type_filters = { - "all": [("All files", "*.*")], - "video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")], - "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")], - "json": [("JSON files", "*.json")], - "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], - "directory": [], - } - - if file_types in file_type_filters: - filters = file_type_filters[file_types] - else: - filters = file_type_filters["all"] - - if file_types == "directory": - return filedialog.askdirectory(initialdir=initial_dir) - else: - return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) +from library.common_gui_functions import tk_context +from library.common_utilities import CommonUtilities -def save_file_dialog(initial_dir, initial_file, files_type="all"): - root = tk.Tk() - root.withdraw() +class TkGui: + def __init__(self): + self.file_types = None - filetypes_switch = { - "all": [("All files", "*.*")], - "video": [("Video files", "*.mp4;*.avi;*.mkv;*.webm;*.flv;*.mov;*.wmv")], - "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff;*.ico")], - "json": [("JSON files", "*.json")], - "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], - } + def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"): + print(f"File types: {self.file_types}") + with tk_context(): + self.file_types = file_types + if self.file_types in CommonUtilities.file_filters: + filters = CommonUtilities.file_filters[self.file_types] + else: + filters = CommonUtilities.file_filters["all"] - filetypes = filetypes_switch.get(files_type, filetypes_switch["all"]) - save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filetypes, - defaultextension=filetypes) + if self.file_types == "directory": + return filedialog.askdirectory(initialdir=initial_dir) + else: + return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) - root.destroy() + def save_file_dialog(self, initial_dir, initial_file, file_types="all"): + self.file_types = file_types - return save_file_path + # Use the tk_context function with the 'with' statement + with tk_context(): + if self.file_types in CommonUtilities.file_filters: + filters = CommonUtilities.file_filters[self.file_types] + else: + filters = CommonUtilities.file_filters["all"] + save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, + filetypes=filters, defaultextension=".safetensors") -def show_message_box(_message, _title="Message", _level="info"): - root = tk.Tk() - root.withdraw() + return save_file_path - message_type = { - "warning": messagebox.showwarning, - "error": messagebox.showerror, - "info": messagebox.showinfo, - "question": messagebox.askquestion, - "okcancel": messagebox.askokcancel, - "retrycancel": messagebox.askretrycancel, - "yesno": messagebox.askyesno, - "yesnocancel": messagebox.askyesnocancel - } + def show_message_box(_message, _title="Message", _level="info"): + with tk_context(): + message_type = { + "warning": messagebox.showwarning, + "error": messagebox.showerror, + "info": messagebox.showinfo, + "question": messagebox.askquestion, + "okcancel": messagebox.askokcancel, + "retrycancel": messagebox.askretrycancel, + "yesno": messagebox.askyesno, + "yesnocancel": messagebox.askyesnocancel + } - if _level in message_type: - message_type[_level](_title, _message) - else: - messagebox.showinfo(_title, _message) - - root.destroy() + if _level in message_type: + message_type[_level](_title, _message) + else: + messagebox.showinfo(_title, _message) if __name__ == '__main__': - mode = sys.argv[1] + try: + mode = sys.argv[1] - if mode == 'file_dialog': - starting_dir = sys.argv[2] if len(sys.argv) > 2 else None - starting_file = sys.argv[3] if len(sys.argv) > 3 else None - file_class = sys.argv[2] if len(sys.argv) > 2 else None - file_path = open_file_dialog(starting_dir, starting_file, file_class) - print(file_path) + if mode == 'file_dialog': + starting_dir = sys.argv[2] if len(sys.argv) > 2 else None + starting_file = sys.argv[3] if len(sys.argv) > 3 else None + file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4] + gui = TkGui() + file_path = gui.open_file_dialog(starting_dir, starting_file, file_class) + print(file_path) # Make sure to print the result - elif mode == 'msgbox': - message = sys.argv[2] - title = sys.argv[3] if len(sys.argv) > 3 else "" - show_message_box(message, title) + elif mode == 'msgbox': + message = sys.argv[2] + title = sys.argv[3] if len(sys.argv) > 3 else "" + gui = TkGui() + gui.show_message_box(message, title) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index 1a0edf9..e2bd428 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -81,8 +81,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -97,8 +96,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index e8321d8..9e9c196 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper +from .common_gui_functions import get_file_path, get_saveasfile_path PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -96,8 +96,7 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[model, lora_ext, lora_ext_name], outputs=model, show_progress=False, diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index be127b3..c8a2fe6 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, get_saveasfile_path, ) folder_symbol = '\U0001f4c2' # ๐Ÿ“‚ @@ -87,8 +87,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -103,8 +102,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index 4acf101..52626dd 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -3,8 +3,8 @@ import subprocess import gradio as gr -from .common_gui import ( - get_file_path, get_file_path_gradio_wrapper, +from .common_gui_functions import ( + get_file_path, ) PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -68,8 +68,7 @@ def gradio_verify_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), + lambda *args, **kwargs: get_file_path(*args), inputs=[lora_model, lora_ext, lora_ext_name], outputs=lora_model, show_progress=False, diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index a9f7742..18103ef 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import gradio as gr -from .common_gui import get_folder_path +from .common_gui_functions import get_folder_path def replace_underscore_with_space(folder_path, file_extension): diff --git a/lora_gui.py b/lora_gui.py index 8990458..03b22d1 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -12,7 +12,7 @@ import subprocess import gradio as gr -from library.common_gui import ( +from library.common_gui_functions import ( get_folder_path, remove_doublequote, get_file_path, @@ -28,7 +28,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, + check_if_model_exist, show_message_box, ) from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dreambooth_folder_creation_gui import ( @@ -254,7 +254,7 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file diff --git a/requirements_macos.txt b/requirements_macos.txt index 4ee4eec..c8bee08 100644 --- a/requirements_macos.txt +++ b/requirements_macos.txt @@ -27,6 +27,6 @@ huggingface-hub==0.12.0; sys_platform != 'darwin' huggingface-hub==0.13.0; sys_platform == 'darwin' tensorflow==2.10.1; sys_platform != 'darwin' # For locon support -lycoris_lora==0.1.2 +lycoris_lora==0.1.4 # for kohya_ss library . \ No newline at end of file diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index f434494..f8aaa36 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -12,7 +12,7 @@ import subprocess import gradio as gr -from library.common_gui import ( +from library.common_gui_functions import ( get_folder_path, remove_doublequote, get_file_path, @@ -28,7 +28,7 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, get_file_path_gradio_wrapper, + check_if_model_exist, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -240,9 +240,9 @@ def open_configuration( original_file_path = file_path if ask_for_file: - file_path = get_file_path_gradio_wrapper(file_path) + file_path = get_file_path(file_path) - if not file_path == '' and not file_path == None: + if not file_path == '' and file_path is not None: # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) @@ -673,7 +673,7 @@ def ti_tab( ) weights_file_input = gr.Button('๐Ÿ“‚', elem_id='open_folder_small') weights_file_input.click( - lambda *args, **kwargs: get_file_path_gradio_wrapper, + lambda *args, **kwargs: get_file_path(*args), outputs=weights, show_progress=False, ) diff --git a/train_db - Copy.py b/train_db - Copy.py index f441d5d..22ab910 100644 --- a/train_db - Copy.py +++ b/train_db - Copy.py @@ -1,29 +1,27 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc -import time import argparse +import gc import itertools import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight + def train(args): train_util.verify_training_args(args) diff --git a/train_db.py b/train_db.py index b3eead9..37fa38e 100644 --- a/train_db.py +++ b/train_db.py @@ -1,28 +1,25 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc -import time import argparse +import gc import itertools import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight diff --git a/train_network - Copy.py b/train_network - Copy.py index 20ad2c4..a80deab 100644 --- a/train_network - Copy.py +++ b/train_network - Copy.py @@ -1,31 +1,30 @@ -from torch.nn.parallel import DistributedDataParallel as DDP -import importlib import argparse import gc +import importlib +import json import math import os import random import time -import json -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed from diffusers import DDPMScheduler +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -from library.train_util import ( - DreamBoothDataset, -) -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight +from library.train_util import ( + DreamBoothDataset, +) # TODO ไป–ใฎใ‚นใ‚ฏใƒชใƒ—ใƒˆใจๅ…ฑ้€šๅŒ–ใ™ใ‚‹ diff --git a/train_network.py b/train_network.py index 423649e..6def98d 100644 --- a/train_network.py +++ b/train_network.py @@ -1,31 +1,30 @@ -from torch.nn.parallel import DistributedDataParallel as DDP -import importlib import argparse import gc +import importlib +import json import math import os import random import time -import json -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed from diffusers import DDPMScheduler +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -from library.train_util import ( - DreamBoothDataset, -) -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight +from library.train_util import ( + DreamBoothDataset, +) # TODO ไป–ใฎใ‚นใ‚ฏใƒชใƒ—ใƒˆใจๅ…ฑ้€šๅŒ–ใ™ใ‚‹ diff --git a/train_textual_inversion - Copy.py b/train_textual_inversion - Copy.py index 681bc62..f359952 100644 --- a/train_textual_inversion - Copy.py +++ b/train_textual_inversion - Copy.py @@ -1,24 +1,21 @@ -import importlib import argparse import gc import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [ diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f279370..b6f56a4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,24 +1,21 @@ -import importlib import argparse import gc import math import os -import toml from multiprocessing import Value -from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from tqdm import tqdm +import library.config_ml_util as config_util +import library.custom_train_functions as custom_train_functions import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( +from library.config_ml_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [