All sorts of broken, but I need to commit this for now so I don't lose it. WIP: Using some OOP to reduce imports and centralize some code. No need to remake Tk Windows everywhere for example. Renamed common_gui to common_gui_functions.py to make some of the new code separation more obvious.

This commit is contained in:
JSTayco 2023-03-31 14:39:10 -07:00
parent e5b83df675
commit eef5becab8
30 changed files with 234 additions and 226 deletions

View File

@ -9,10 +9,11 @@ import math
import os import os
import pathlib import pathlib
import subprocess import subprocess
import sys
import gradio as gr import gradio as gr
from library.common_gui import ( from library.common_gui_functions import (
get_folder_path, get_folder_path,
remove_doublequote, remove_doublequote,
get_file_path, get_file_path,
@ -28,9 +29,9 @@ from library.common_gui import (
gradio_source_model, gradio_source_model,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, check_if_model_exist, show_message_box,
) )
from library.common_utilities import is_valid_config from library.common_utilities import CommonUtilities
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
) )
@ -230,12 +231,12 @@ def open_configuration(
if ask_for_file: if ask_for_file:
print(f"File path: {file_path}") print(f"File path: {file_path}")
file_path = get_file_path_gradio_wrapper(file_path) file_path = get_file_path(file_path, filedialog_type="json")
if not file_path == '' and file_path is not None: if not file_path == '' and file_path is not None:
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
my_data = json.load(f) my_data = json.load(f)
if is_valid_config(my_data): if CommonUtilities.is_valid_config(my_data):
print('Loading config...') print('Loading config...')
my_data = update_my_data(my_data) my_data = update_my_data(my_data)
else: else:
@ -838,14 +839,14 @@ def dreambooth_tab(
] ]
button_open_config.click( button_open_config.click(
lambda *args, **kwargs: open_configuration(*args), lambda *_args, **kwargs: open_configuration(*_args, **kwargs),
inputs=[dummy_db_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )
button_load_config.click( button_load_config.click(
lambda *args, **kwargs: open_configuration(*args), lambda *args, **kwargs: open_configuration(*args, **kwargs),
inputs=[dummy_db_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,

View File

@ -5,22 +5,20 @@ import argparse
import gc import gc
import math import math
import os import os
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util from library.config_ml_util import (
from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight from library.custom_train_functions import apply_snr_weight

View File

@ -1,17 +1,18 @@
import gradio as gr import argparse
import json import json
import math import math
import os import os
import subprocess
import pathlib import pathlib
import argparse import subprocess
from library.common_gui import (
import gradio as gr
from library.common_gui_functions import (
get_folder_path, get_folder_path,
get_file_path, get_file_path,
get_saveasfile_path, get_saveasfile_path,
save_inference_file, save_inference_file,
gradio_advanced_training, gradio_advanced_training,
run_cmd_advanced_training,
gradio_training, gradio_training,
run_cmd_advanced_training, run_cmd_advanced_training,
gradio_config, gradio_config,
@ -20,15 +21,15 @@ from library.common_gui import (
run_cmd_training, run_cmd_training,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, get_file_path_gradio_wrapper, check_if_model_exist,
) )
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.tensorboard_gui import ( from library.tensorboard_gui import (
gradio_tensorboard, gradio_tensorboard,
start_tensorboard, start_tensorboard,
stop_tensorboard, stop_tensorboard,
) )
from library.utilities import utilities_tab from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -231,7 +232,7 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path_gradio_wrapper(file_path) file_path = get_file_path(file_path)
if not file_path == '' and file_path is not None: if not file_path == '' and file_path is not None:
# load variables from JSON file # load variables from JSON file

View File

@ -1,17 +1,18 @@
import gradio as gr
import os
import argparse import argparse
import os
from pathlib import Path
import gradio as gr
from dreambooth_gui import dreambooth_tab from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from library.extract_lora_gui import gradio_extract_lora_tab from library.extract_lora_gui import gradio_extract_lora_tab
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
from library.merge_lora_gui import gradio_merge_lora_tab from library.merge_lora_gui import gradio_merge_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab
from library.utilities import utilities_tab
from lora_gui import lora_tab from lora_gui import lora_tab
from textual_inversion_gui import ti_tab
def UI(**kwargs): def UI(**kwargs):
css = '' css = ''

View File

@ -3,7 +3,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_folder_path, add_pre_postfix, find_replace from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace
def caption_images( def caption_images(

View File

@ -3,7 +3,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_folder_path, add_pre_postfix from .common_gui_functions import get_folder_path, add_pre_postfix
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'

View File

@ -1,12 +1,14 @@
import os import os
import shutil import shutil
import subprocess import subprocess
from contextlib import contextmanager
import tkinter as tk
from tkinter import filedialog, Tk from tkinter import filedialog, Tk
import easygui import easygui
import gradio as gr import gradio as gr
from library.gui_subprocesses import save_file_dialog from library.common_utilities import CommonUtilities
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -37,6 +39,16 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID'] FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
@contextmanager
def tk_context():
root = tk.Tk()
root.withdraw()
try:
yield root
finally:
root.destroy()
def open_file_dialog(initial_dir, initial_file, file_types="all"): def open_file_dialog(initial_dir, initial_file, file_types="all"):
current_directory = os.path.dirname(os.path.abspath(__file__)) current_directory = os.path.dirname(os.path.abspath(__file__))
@ -89,6 +101,7 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
return False return False
def update_my_data(my_data): def update_my_data(my_data):
# Update the optimizer based on the use_8bit_adam flag # Update the optimizer based on the use_8bit_adam flag
use_8bit_adam = my_data.get('use_8bit_adam', False) use_8bit_adam = my_data.get('use_8bit_adam', False)
@ -138,28 +151,37 @@ def update_my_data(my_data):
# # If no extension files were found, return False # # If no extension files were found, return False
# return False # return False
def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): # def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
# file_extension = os.path.splitext(file_path)[-1].lower()
#
# filetype_filters = {
# 'db': ['.db'],
# 'json': ['.json'],
# 'lora': ['.pt', '.ckpt', '.safetensors'],
# }
#
# # Find the appropriate filedialog_type based on the file extension
# filedialog_type = 'all'
# for key, extensions in filetype_filters.items():
# if file_extension in extensions:
# filedialog_type = key
# break
#
# return get_file_path(file_path, filedialog_type)
def get_file_path(file_path='', filedialog_type="lora"):
file_extension = os.path.splitext(file_path)[-1].lower() file_extension = os.path.splitext(file_path)[-1].lower()
filetype_filters = {
'db': ['.db'],
'json': ['.json'],
'lora': ['.pt', '.ckpt', '.safetensors'],
}
# Find the appropriate filedialog_type based on the file extension # Find the appropriate filedialog_type based on the file extension
filedialog_type = 'all' for key, extensions in CommonUtilities.file_filters.items():
for key, extensions in filetype_filters.items():
if file_extension in extensions: if file_extension in extensions:
filedialog_type = key filedialog_type = key
break break
return get_file_path(file_path, filedialog_type)
def get_file_path(file_path='', filedialog_type="lora"):
current_file_path = file_path current_file_path = file_path
print(f"File type: {filedialog_type}")
initial_dir, initial_file = os.path.split(file_path) initial_dir, initial_file = os.path.split(file_path)
file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type) file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type)
@ -171,7 +193,6 @@ def get_file_path(file_path='', filedialog_type="lora"):
return file_path return file_path
def get_any_file_path(file_path=''): def get_any_file_path(file_path=''):
current_file_path = file_path current_file_path = file_path
# print(f'current file path: {current_file_path}') # print(f'current file path: {current_file_path}')
@ -823,7 +844,7 @@ def gradio_advanced_training():
xformers = gr.Checkbox(label='Use xformers', value=True) xformers = gr.Checkbox(label='Use xformers', value=True)
color_aug = gr.Checkbox(label='Color augmentation', value=False) color_aug = gr.Checkbox(label='Color augmentation', value=False)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1) min_snr_gamma = gr.Slider(label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1)
with gr.Row(): with gr.Row():
bucket_no_upscale = gr.Checkbox( bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=True label="Don't upscale bucket resolution", value=True

View File

@ -1,14 +1,24 @@
def is_valid_config(data): class CommonUtilities:
# Check if the data is a dictionary file_filters = {
if not isinstance(data, dict): "all": [("All files", "*.*")],
return False "video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
"json": [("JSON files", "*.json")],
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
"directory": [],
}
# Add checks for expected keys and valid values def is_valid_config(self, data):
# For example, check if 'use_8bit_adam' is a boolean # Check if the data is a dictionary
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): if not isinstance(data, dict):
return False return False
# Add more checks for other keys as needed # Add checks for expected keys and valid values
# For example, check if 'use_8bit_adam' is a boolean
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
return False
# If all checks pass, return True # Add more checks for other keys as needed
return True
# If all checks pass, return True
return True

View File

@ -1,13 +1,13 @@
import argparse import argparse
import functools
import json
import random
from dataclasses import ( from dataclasses import (
asdict, asdict,
dataclass, dataclass,
) )
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path from pathlib import Path
from textwrap import dedent, indent
# from toolz import curry # from toolz import curry
from typing import ( from typing import (
List, List,
@ -19,6 +19,7 @@ from typing import (
import toml import toml
import voluptuous import voluptuous
from transformers import CLIPTokenizer
from voluptuous import ( from voluptuous import (
Any, Any,
ExactSequence, ExactSequence,
@ -27,7 +28,6 @@ from voluptuous import (
Required, Required,
Schema, Schema,
) )
from transformers import CLIPTokenizer
from . import train_util from . import train_util
from .train_util import ( from .train_util import (

View File

@ -4,7 +4,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper from .common_gui_functions import get_folder_path, get_file_path
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -180,8 +180,7 @@ def gradio_convert_model_tab():
document_symbol, elem_id='open_folder_small' document_symbol, elem_id='open_folder_small'
) )
button_source_model_file.click( button_source_model_file.click(
lambda input1, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)),
inputs=[source_model_input], inputs=[source_model_input],
outputs=source_model_input, outputs=source_model_input,
show_progress=False, show_progress=False,

View File

@ -4,7 +4,7 @@ import re
import gradio as gr import gradio as gr
from easygui import boolbox from easygui import boolbox
from .common_gui import get_folder_path from .common_gui_functions import get_folder_path
# def select_folder(): # def select_folder():

View File

@ -3,7 +3,7 @@ import shutil
import gradio as gr import gradio as gr
from .common_gui import get_folder_path from .common_gui_functions import get_folder_path
def copy_info_to_Folders_tab(training_folder): def copy_info_to_Folders_tab(training_folder):

View File

@ -3,8 +3,8 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui_functions import (
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, get_file_path, get_saveasfile_path,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -90,8 +90,7 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_model_tuned_file.click( button_model_tuned_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[model_tuned, model_ext, model_ext_name], inputs=[model_tuned, model_ext, model_ext_name],
outputs=model_tuned, outputs=model_tuned,
show_progress=False, show_progress=False,
@ -107,7 +106,7 @@ def gradio_extract_lora_tab():
) )
button_model_org_file.click( button_model_org_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), lambda *args, **kwargs: get_file_path(*args),
inputs=[model_org, model_ext, model_ext_name], inputs=[model_org, model_ext, model_ext_name],
outputs=model_org, outputs=model_org,
show_progress=False, show_progress=False,

View File

@ -3,8 +3,8 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui_functions import (
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, get_file_path, get_saveasfile_path,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -137,7 +137,7 @@ def gradio_extract_lycoris_locon_tab():
) )
button_db_model_file.click( button_db_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)), lambda *args, **kwargs: get_file_path(*args),
inputs=[db_model, model_ext, model_ext_name], inputs=[db_model, model_ext, model_ext_name],
outputs=db_model, outputs=db_model,
show_progress=False, show_progress=False,
@ -152,8 +152,7 @@ def gradio_extract_lycoris_locon_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_base_model_file.click( button_base_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[base_model, model_ext, model_ext_name], inputs=[base_model, model_ext, model_ext_name],
outputs=base_model, outputs=base_model,
show_progress=False, show_progress=False,

View File

@ -3,7 +3,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_folder_path, add_pre_postfix from .common_gui_functions import get_folder_path, add_pre_postfix
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'

View File

@ -1,84 +1,82 @@
import os
import pathlib
import sys import sys
import tkinter as tk import tkinter as tk
from tkinter import filedialog, messagebox from tkinter import filedialog, messagebox
from library.common_gui_functions import tk_context
def open_file_dialog(initial_dir=None, initial_file=None, file_types="all"): from library.common_utilities import CommonUtilities
file_type_filters = {
"all": [("All files", "*.*")],
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
"json": [("JSON files", "*.json")],
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
"directory": [],
}
if file_types in file_type_filters:
filters = file_type_filters[file_types]
else:
filters = file_type_filters["all"]
if file_types == "directory":
return filedialog.askdirectory(initialdir=initial_dir)
else:
return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
def save_file_dialog(initial_dir, initial_file, files_type="all"): class TkGui:
root = tk.Tk() def __init__(self):
root.withdraw() self.file_types = None
filetypes_switch = { def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
"all": [("All files", "*.*")], print(f"File types: {self.file_types}")
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.webm;*.flv;*.mov;*.wmv")], with tk_context():
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff;*.ico")], self.file_types = file_types
"json": [("JSON files", "*.json")], if self.file_types in CommonUtilities.file_filters:
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], filters = CommonUtilities.file_filters[self.file_types]
} else:
filters = CommonUtilities.file_filters["all"]
filetypes = filetypes_switch.get(files_type, filetypes_switch["all"]) if self.file_types == "directory":
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filetypes, return filedialog.askdirectory(initialdir=initial_dir)
defaultextension=filetypes) else:
return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
root.destroy() def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
self.file_types = file_types
return save_file_path # Use the tk_context function with the 'with' statement
with tk_context():
if self.file_types in CommonUtilities.file_filters:
filters = CommonUtilities.file_filters[self.file_types]
else:
filters = CommonUtilities.file_filters["all"]
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file,
filetypes=filters, defaultextension=".safetensors")
def show_message_box(_message, _title="Message", _level="info"): return save_file_path
root = tk.Tk()
root.withdraw()
message_type = { def show_message_box(_message, _title="Message", _level="info"):
"warning": messagebox.showwarning, with tk_context():
"error": messagebox.showerror, message_type = {
"info": messagebox.showinfo, "warning": messagebox.showwarning,
"question": messagebox.askquestion, "error": messagebox.showerror,
"okcancel": messagebox.askokcancel, "info": messagebox.showinfo,
"retrycancel": messagebox.askretrycancel, "question": messagebox.askquestion,
"yesno": messagebox.askyesno, "okcancel": messagebox.askokcancel,
"yesnocancel": messagebox.askyesnocancel "retrycancel": messagebox.askretrycancel,
} "yesno": messagebox.askyesno,
"yesnocancel": messagebox.askyesnocancel
}
if _level in message_type: if _level in message_type:
message_type[_level](_title, _message) message_type[_level](_title, _message)
else: else:
messagebox.showinfo(_title, _message) messagebox.showinfo(_title, _message)
root.destroy()
if __name__ == '__main__': if __name__ == '__main__':
mode = sys.argv[1] try:
mode = sys.argv[1]
if mode == 'file_dialog': if mode == 'file_dialog':
starting_dir = sys.argv[2] if len(sys.argv) > 2 else None starting_dir = sys.argv[2] if len(sys.argv) > 2 else None
starting_file = sys.argv[3] if len(sys.argv) > 3 else None starting_file = sys.argv[3] if len(sys.argv) > 3 else None
file_class = sys.argv[2] if len(sys.argv) > 2 else None file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4]
file_path = open_file_dialog(starting_dir, starting_file, file_class) gui = TkGui()
print(file_path) file_path = gui.open_file_dialog(starting_dir, starting_file, file_class)
print(file_path) # Make sure to print the result
elif mode == 'msgbox': elif mode == 'msgbox':
message = sys.argv[2] message = sys.argv[2]
title = sys.argv[3] if len(sys.argv) > 3 else "" title = sys.argv[3] if len(sys.argv) > 3 else ""
show_message_box(message, title) gui = TkGui()
gui.show_message_box(message, title)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)

View File

@ -3,8 +3,8 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui_functions import (
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, get_file_path, get_saveasfile_path,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -81,8 +81,7 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_a_model_file.click( button_lora_a_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_a_model, lora_ext, lora_ext_name], inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model, outputs=lora_a_model,
show_progress=False, show_progress=False,
@ -97,8 +96,7 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_b_model_file.click( button_lora_b_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_b_model, lora_ext, lora_ext_name], inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model, outputs=lora_b_model,
show_progress=False, show_progress=False,

View File

@ -3,7 +3,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper from .common_gui_functions import get_file_path, get_saveasfile_path
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -96,8 +96,7 @@ def gradio_resize_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_a_model_file.click( button_lora_a_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[model, lora_ext, lora_ext_name], inputs=[model, lora_ext, lora_ext_name],
outputs=model, outputs=model,
show_progress=False, show_progress=False,

View File

@ -3,8 +3,8 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui_functions import (
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper, get_file_path, get_saveasfile_path,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -87,8 +87,7 @@ def gradio_svd_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_a_model_file.click( button_lora_a_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_a_model, lora_ext, lora_ext_name], inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model, outputs=lora_a_model,
show_progress=False, show_progress=False,
@ -103,8 +102,7 @@ def gradio_svd_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_b_model_file.click( button_lora_b_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_b_model, lora_ext, lora_ext_name], inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model, outputs=lora_b_model,
show_progress=False, show_progress=False,

View File

@ -3,8 +3,8 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui_functions import (
get_file_path, get_file_path_gradio_wrapper, get_file_path,
) )
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
@ -68,8 +68,7 @@ def gradio_verify_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_model_file.click( button_lora_model_file.click(
lambda input1, input2, input3, *args, **kwargs: lambda *args, **kwargs: get_file_path(*args),
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_model, lora_ext, lora_ext_name], inputs=[lora_model, lora_ext, lora_ext_name],
outputs=lora_model, outputs=lora_model,
show_progress=False, show_progress=False,

View File

@ -3,7 +3,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_folder_path from .common_gui_functions import get_folder_path
def replace_underscore_with_space(folder_path, file_extension): def replace_underscore_with_space(folder_path, file_extension):

View File

@ -12,7 +12,7 @@ import subprocess
import gradio as gr import gradio as gr
from library.common_gui import ( from library.common_gui_functions import (
get_folder_path, get_folder_path,
remove_doublequote, remove_doublequote,
get_file_path, get_file_path,
@ -28,7 +28,7 @@ from library.common_gui import (
run_cmd_training, run_cmd_training,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper, check_if_model_exist, show_message_box,
) )
from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
@ -254,7 +254,7 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path_gradio_wrapper(file_path) file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file

View File

@ -27,6 +27,6 @@ huggingface-hub==0.12.0; sys_platform != 'darwin'
huggingface-hub==0.13.0; sys_platform == 'darwin' huggingface-hub==0.13.0; sys_platform == 'darwin'
tensorflow==2.10.1; sys_platform != 'darwin' tensorflow==2.10.1; sys_platform != 'darwin'
# For locon support # For locon support
lycoris_lora==0.1.2 lycoris_lora==0.1.4
# for kohya_ss library # for kohya_ss library
. .

View File

@ -12,7 +12,7 @@ import subprocess
import gradio as gr import gradio as gr
from library.common_gui import ( from library.common_gui_functions import (
get_folder_path, get_folder_path,
remove_doublequote, remove_doublequote,
get_file_path, get_file_path,
@ -28,7 +28,7 @@ from library.common_gui import (
gradio_source_model, gradio_source_model,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, get_file_path_gradio_wrapper, check_if_model_exist,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -240,9 +240,9 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path_gradio_wrapper(file_path) file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and file_path is not None:
# load variables from JSON file # load variables from JSON file
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
my_data = json.load(f) my_data = json.load(f)
@ -673,7 +673,7 @@ def ti_tab(
) )
weights_file_input = gr.Button('📂', elem_id='open_folder_small') weights_file_input = gr.Button('📂', elem_id='open_folder_small')
weights_file_input.click( weights_file_input.click(
lambda *args, **kwargs: get_file_path_gradio_wrapper, lambda *args, **kwargs: get_file_path(*args),
outputs=weights, outputs=weights,
show_progress=False, show_progress=False,
) )

View File

@ -1,29 +1,27 @@
# DreamBooth training # DreamBooth training
# XXX dropped option: fine_tune # XXX dropped option: fine_tune
import gc
import time
import argparse import argparse
import gc
import itertools import itertools
import math import math
import os import os
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util from library.config_ml_util import (
from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight
def train(args): def train(args):
train_util.verify_training_args(args) train_util.verify_training_args(args)

View File

@ -1,28 +1,25 @@
# DreamBooth training # DreamBooth training
# XXX dropped option: fine_tune # XXX dropped option: fine_tune
import gc
import time
import argparse import argparse
import gc
import itertools import itertools
import math import math
import os import os
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util from library.config_ml_util import (
from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight from library.custom_train_functions import apply_snr_weight

View File

@ -1,31 +1,30 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse import argparse
import gc import gc
import importlib
import json
import math import math
import os import os
import random import random
import time import time
import json
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
from library.train_util import ( from library.config_ml_util import (
DreamBoothDataset,
)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight from library.train_util import (
DreamBoothDataset,
)
# TODO 他のスクリプトと共通化する # TODO 他のスクリプトと共通化する

View File

@ -1,31 +1,30 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse import argparse
import gc import gc
import importlib
import json
import math import math
import os import os
import random import random
import time import time
import json
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
from library.train_util import ( from library.config_ml_util import (
DreamBoothDataset,
)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight from library.train_util import (
DreamBoothDataset,
)
# TODO 他のスクリプトと共通化する # TODO 他のスクリプトと共通化する

View File

@ -1,24 +1,21 @@
import importlib
import argparse import argparse
import gc import gc
import math import math
import os import os
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util from library.config_ml_util import (
from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight from library.custom_train_functions import apply_snr_weight
imagenet_templates_small = [ imagenet_templates_small = [

View File

@ -1,24 +1,21 @@
import importlib
import argparse import argparse
import gc import gc
import math import math
import os import os
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util from library.config_ml_util import (
from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight from library.custom_train_functions import apply_snr_weight
imagenet_templates_small = [ imagenet_templates_small = [