All sorts of broken, but I need to commit this for now so I don't lose it. WIP: Using some OOP to reduce imports and centralize some code. No need to remake Tk Windows everywhere for example. Renamed common_gui to common_gui_functions.py to make some of the new code separation more obvious.
This commit is contained in:
parent
e5b83df675
commit
eef5becab8
@ -9,10 +9,11 @@ import math
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from library.common_gui import (
|
||||
from library.common_gui_functions import (
|
||||
get_folder_path,
|
||||
remove_doublequote,
|
||||
get_file_path,
|
||||
@ -28,9 +29,9 @@ from library.common_gui import (
|
||||
gradio_source_model,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
|
||||
check_if_model_exist, show_message_box,
|
||||
)
|
||||
from library.common_utilities import is_valid_config
|
||||
from library.common_utilities import CommonUtilities
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
)
|
||||
@ -230,12 +231,12 @@ def open_configuration(
|
||||
|
||||
if ask_for_file:
|
||||
print(f"File path: {file_path}")
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
file_path = get_file_path(file_path, filedialog_type="json")
|
||||
|
||||
if not file_path == '' and file_path is not None:
|
||||
with open(file_path, 'r') as f:
|
||||
my_data = json.load(f)
|
||||
if is_valid_config(my_data):
|
||||
if CommonUtilities.is_valid_config(my_data):
|
||||
print('Loading config...')
|
||||
my_data = update_my_data(my_data)
|
||||
else:
|
||||
@ -838,14 +839,14 @@ def dreambooth_tab(
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
lambda *args, **kwargs: open_configuration(*args),
|
||||
lambda *_args, **kwargs: open_configuration(*_args, **kwargs),
|
||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_load_config.click(
|
||||
lambda *args, **kwargs: open_configuration(*args),
|
||||
lambda *args, **kwargs: open_configuration(*args, **kwargs),
|
||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
|
10
fine_tune.py
10
fine_tune.py
@ -5,22 +5,20 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.config_ml_util as config_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
from library.config_ml_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
|
||||
|
@ -1,17 +1,18 @@
|
||||
import gradio as gr
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import pathlib
|
||||
import argparse
|
||||
from library.common_gui import (
|
||||
import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from library.common_gui_functions import (
|
||||
get_folder_path,
|
||||
get_file_path,
|
||||
get_saveasfile_path,
|
||||
save_inference_file,
|
||||
gradio_advanced_training,
|
||||
run_cmd_advanced_training,
|
||||
gradio_training,
|
||||
run_cmd_advanced_training,
|
||||
gradio_config,
|
||||
@ -20,15 +21,15 @@ from library.common_gui import (
|
||||
run_cmd_training,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist, get_file_path_gradio_wrapper,
|
||||
check_if_model_exist,
|
||||
)
|
||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||
from library.tensorboard_gui import (
|
||||
gradio_tensorboard,
|
||||
start_tensorboard,
|
||||
stop_tensorboard,
|
||||
)
|
||||
from library.utilities import utilities_tab
|
||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
@ -231,7 +232,7 @@ def open_configuration(
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
file_path = get_file_path(file_path)
|
||||
|
||||
if not file_path == '' and file_path is not None:
|
||||
# load variables from JSON file
|
||||
|
13
kohya_gui.py
13
kohya_gui.py
@ -1,17 +1,18 @@
|
||||
import gradio as gr
|
||||
import os
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from dreambooth_gui import dreambooth_tab
|
||||
from finetune_gui import finetune_tab
|
||||
from textual_inversion_gui import ti_tab
|
||||
from library.utilities import utilities_tab
|
||||
from library.extract_lora_gui import gradio_extract_lora_tab
|
||||
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
||||
from library.merge_lora_gui import gradio_merge_lora_tab
|
||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||
from library.utilities import utilities_tab
|
||||
from lora_gui import lora_tab
|
||||
|
||||
|
||||
from textual_inversion_gui import ti_tab
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
|
@ -3,7 +3,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_folder_path, add_pre_postfix, find_replace
|
||||
from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace
|
||||
|
||||
|
||||
def caption_images(
|
||||
|
@ -3,7 +3,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_folder_path, add_pre_postfix
|
||||
from .common_gui_functions import get_folder_path, add_pre_postfix
|
||||
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
|
||||
|
@ -1,12 +1,14 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from contextlib import contextmanager
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog, Tk
|
||||
|
||||
import easygui
|
||||
import gradio as gr
|
||||
|
||||
from library.gui_subprocesses import save_file_dialog
|
||||
from library.common_utilities import CommonUtilities
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
@ -37,6 +39,16 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
|
||||
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tk_context():
|
||||
root = tk.Tk()
|
||||
root.withdraw()
|
||||
try:
|
||||
yield root
|
||||
finally:
|
||||
root.destroy()
|
||||
|
||||
|
||||
def open_file_dialog(initial_dir, initial_file, file_types="all"):
|
||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
@ -89,6 +101,7 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def update_my_data(my_data):
|
||||
# Update the optimizer based on the use_8bit_adam flag
|
||||
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
||||
@ -138,28 +151,37 @@ def update_my_data(my_data):
|
||||
# # If no extension files were found, return False
|
||||
# return False
|
||||
|
||||
def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
|
||||
# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
|
||||
# file_extension = os.path.splitext(file_path)[-1].lower()
|
||||
#
|
||||
# filetype_filters = {
|
||||
# 'db': ['.db'],
|
||||
# 'json': ['.json'],
|
||||
# 'lora': ['.pt', '.ckpt', '.safetensors'],
|
||||
# }
|
||||
#
|
||||
# # Find the appropriate filedialog_type based on the file extension
|
||||
# filedialog_type = 'all'
|
||||
# for key, extensions in filetype_filters.items():
|
||||
# if file_extension in extensions:
|
||||
# filedialog_type = key
|
||||
# break
|
||||
#
|
||||
# return get_file_path(file_path, filedialog_type)
|
||||
|
||||
|
||||
def get_file_path(file_path='', filedialog_type="lora"):
|
||||
file_extension = os.path.splitext(file_path)[-1].lower()
|
||||
|
||||
filetype_filters = {
|
||||
'db': ['.db'],
|
||||
'json': ['.json'],
|
||||
'lora': ['.pt', '.ckpt', '.safetensors'],
|
||||
}
|
||||
|
||||
# Find the appropriate filedialog_type based on the file extension
|
||||
filedialog_type = 'all'
|
||||
for key, extensions in filetype_filters.items():
|
||||
for key, extensions in CommonUtilities.file_filters.items():
|
||||
if file_extension in extensions:
|
||||
filedialog_type = key
|
||||
break
|
||||
|
||||
return get_file_path(file_path, filedialog_type)
|
||||
|
||||
|
||||
def get_file_path(file_path='', filedialog_type="lora"):
|
||||
current_file_path = file_path
|
||||
|
||||
print(f"File type: {filedialog_type}")
|
||||
initial_dir, initial_file = os.path.split(file_path)
|
||||
file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type)
|
||||
|
||||
@ -171,7 +193,6 @@ def get_file_path(file_path='', filedialog_type="lora"):
|
||||
return file_path
|
||||
|
||||
|
||||
|
||||
def get_any_file_path(file_path=''):
|
||||
current_file_path = file_path
|
||||
# print(f'current file path: {current_file_path}')
|
||||
@ -823,7 +844,7 @@ def gradio_advanced_training():
|
||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||
min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1)
|
||||
min_snr_gamma = gr.Slider(label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1)
|
||||
with gr.Row():
|
||||
bucket_no_upscale = gr.Checkbox(
|
||||
label="Don't upscale bucket resolution", value=True
|
@ -1,14 +1,24 @@
|
||||
def is_valid_config(data):
|
||||
# Check if the data is a dictionary
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
class CommonUtilities:
|
||||
file_filters = {
|
||||
"all": [("All files", "*.*")],
|
||||
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
|
||||
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
|
||||
"json": [("JSON files", "*.json")],
|
||||
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
|
||||
"directory": [],
|
||||
}
|
||||
|
||||
# Add checks for expected keys and valid values
|
||||
# For example, check if 'use_8bit_adam' is a boolean
|
||||
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
|
||||
return False
|
||||
def is_valid_config(self, data):
|
||||
# Check if the data is a dictionary
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
||||
# Add more checks for other keys as needed
|
||||
# Add checks for expected keys and valid values
|
||||
# For example, check if 'use_8bit_adam' is a boolean
|
||||
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
|
||||
return False
|
||||
|
||||
# If all checks pass, return True
|
||||
return True
|
||||
# Add more checks for other keys as needed
|
||||
|
||||
# If all checks pass, return True
|
||||
return True
|
||||
|
@ -1,13 +1,13 @@
|
||||
import argparse
|
||||
import functools
|
||||
import json
|
||||
import random
|
||||
from dataclasses import (
|
||||
asdict,
|
||||
dataclass,
|
||||
)
|
||||
import functools
|
||||
import random
|
||||
from textwrap import dedent, indent
|
||||
import json
|
||||
from pathlib import Path
|
||||
from textwrap import dedent, indent
|
||||
# from toolz import curry
|
||||
from typing import (
|
||||
List,
|
||||
@ -19,6 +19,7 @@ from typing import (
|
||||
|
||||
import toml
|
||||
import voluptuous
|
||||
from transformers import CLIPTokenizer
|
||||
from voluptuous import (
|
||||
Any,
|
||||
ExactSequence,
|
||||
@ -27,7 +28,6 @@ from voluptuous import (
|
||||
Required,
|
||||
Schema,
|
||||
)
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from . import train_util
|
||||
from .train_util import (
|
@ -4,7 +4,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper
|
||||
from .common_gui_functions import get_folder_path, get_file_path
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
@ -180,8 +180,7 @@ def gradio_convert_model_tab():
|
||||
document_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_source_model_file.click(
|
||||
lambda input1, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[source_model_input],
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
|
@ -4,7 +4,7 @@ import re
|
||||
import gradio as gr
|
||||
from easygui import boolbox
|
||||
|
||||
from .common_gui import get_folder_path
|
||||
from .common_gui_functions import get_folder_path
|
||||
|
||||
|
||||
# def select_folder():
|
||||
|
@ -3,7 +3,7 @@ import shutil
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_folder_path
|
||||
from .common_gui_functions import get_folder_path
|
||||
|
||||
|
||||
def copy_info_to_Folders_tab(training_folder):
|
||||
|
@ -3,8 +3,8 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
from .common_gui_functions import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -90,8 +90,7 @@ def gradio_extract_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_model_tuned_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[model_tuned, model_ext, model_ext_name],
|
||||
outputs=model_tuned,
|
||||
show_progress=False,
|
||||
@ -107,7 +106,7 @@ def gradio_extract_lora_tab():
|
||||
)
|
||||
button_model_org_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[model_org, model_ext, model_ext_name],
|
||||
outputs=model_org,
|
||||
show_progress=False,
|
||||
|
@ -3,8 +3,8 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
from .common_gui_functions import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -137,7 +137,7 @@ def gradio_extract_lycoris_locon_tab():
|
||||
)
|
||||
button_db_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[db_model, model_ext, model_ext_name],
|
||||
outputs=db_model,
|
||||
show_progress=False,
|
||||
@ -152,8 +152,7 @@ def gradio_extract_lycoris_locon_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_base_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[base_model, model_ext, model_ext_name],
|
||||
outputs=base_model,
|
||||
show_progress=False,
|
||||
|
@ -3,7 +3,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_folder_path, add_pre_postfix
|
||||
from .common_gui_functions import get_folder_path, add_pre_postfix
|
||||
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
|
||||
|
@ -1,84 +1,82 @@
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog, messagebox
|
||||
|
||||
|
||||
def open_file_dialog(initial_dir=None, initial_file=None, file_types="all"):
|
||||
file_type_filters = {
|
||||
"all": [("All files", "*.*")],
|
||||
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
|
||||
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
|
||||
"json": [("JSON files", "*.json")],
|
||||
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
|
||||
"directory": [],
|
||||
}
|
||||
|
||||
if file_types in file_type_filters:
|
||||
filters = file_type_filters[file_types]
|
||||
else:
|
||||
filters = file_type_filters["all"]
|
||||
|
||||
if file_types == "directory":
|
||||
return filedialog.askdirectory(initialdir=initial_dir)
|
||||
else:
|
||||
return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
|
||||
from library.common_gui_functions import tk_context
|
||||
from library.common_utilities import CommonUtilities
|
||||
|
||||
|
||||
def save_file_dialog(initial_dir, initial_file, files_type="all"):
|
||||
root = tk.Tk()
|
||||
root.withdraw()
|
||||
class TkGui:
|
||||
def __init__(self):
|
||||
self.file_types = None
|
||||
|
||||
filetypes_switch = {
|
||||
"all": [("All files", "*.*")],
|
||||
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.webm;*.flv;*.mov;*.wmv")],
|
||||
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff;*.ico")],
|
||||
"json": [("JSON files", "*.json")],
|
||||
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
|
||||
}
|
||||
def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
|
||||
print(f"File types: {self.file_types}")
|
||||
with tk_context():
|
||||
self.file_types = file_types
|
||||
if self.file_types in CommonUtilities.file_filters:
|
||||
filters = CommonUtilities.file_filters[self.file_types]
|
||||
else:
|
||||
filters = CommonUtilities.file_filters["all"]
|
||||
|
||||
filetypes = filetypes_switch.get(files_type, filetypes_switch["all"])
|
||||
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filetypes,
|
||||
defaultextension=filetypes)
|
||||
if self.file_types == "directory":
|
||||
return filedialog.askdirectory(initialdir=initial_dir)
|
||||
else:
|
||||
return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
|
||||
|
||||
root.destroy()
|
||||
def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
|
||||
self.file_types = file_types
|
||||
|
||||
return save_file_path
|
||||
# Use the tk_context function with the 'with' statement
|
||||
with tk_context():
|
||||
if self.file_types in CommonUtilities.file_filters:
|
||||
filters = CommonUtilities.file_filters[self.file_types]
|
||||
else:
|
||||
filters = CommonUtilities.file_filters["all"]
|
||||
|
||||
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file,
|
||||
filetypes=filters, defaultextension=".safetensors")
|
||||
|
||||
def show_message_box(_message, _title="Message", _level="info"):
|
||||
root = tk.Tk()
|
||||
root.withdraw()
|
||||
return save_file_path
|
||||
|
||||
message_type = {
|
||||
"warning": messagebox.showwarning,
|
||||
"error": messagebox.showerror,
|
||||
"info": messagebox.showinfo,
|
||||
"question": messagebox.askquestion,
|
||||
"okcancel": messagebox.askokcancel,
|
||||
"retrycancel": messagebox.askretrycancel,
|
||||
"yesno": messagebox.askyesno,
|
||||
"yesnocancel": messagebox.askyesnocancel
|
||||
}
|
||||
def show_message_box(_message, _title="Message", _level="info"):
|
||||
with tk_context():
|
||||
message_type = {
|
||||
"warning": messagebox.showwarning,
|
||||
"error": messagebox.showerror,
|
||||
"info": messagebox.showinfo,
|
||||
"question": messagebox.askquestion,
|
||||
"okcancel": messagebox.askokcancel,
|
||||
"retrycancel": messagebox.askretrycancel,
|
||||
"yesno": messagebox.askyesno,
|
||||
"yesnocancel": messagebox.askyesnocancel
|
||||
}
|
||||
|
||||
if _level in message_type:
|
||||
message_type[_level](_title, _message)
|
||||
else:
|
||||
messagebox.showinfo(_title, _message)
|
||||
|
||||
root.destroy()
|
||||
if _level in message_type:
|
||||
message_type[_level](_title, _message)
|
||||
else:
|
||||
messagebox.showinfo(_title, _message)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mode = sys.argv[1]
|
||||
try:
|
||||
mode = sys.argv[1]
|
||||
|
||||
if mode == 'file_dialog':
|
||||
starting_dir = sys.argv[2] if len(sys.argv) > 2 else None
|
||||
starting_file = sys.argv[3] if len(sys.argv) > 3 else None
|
||||
file_class = sys.argv[2] if len(sys.argv) > 2 else None
|
||||
file_path = open_file_dialog(starting_dir, starting_file, file_class)
|
||||
print(file_path)
|
||||
if mode == 'file_dialog':
|
||||
starting_dir = sys.argv[2] if len(sys.argv) > 2 else None
|
||||
starting_file = sys.argv[3] if len(sys.argv) > 3 else None
|
||||
file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4]
|
||||
gui = TkGui()
|
||||
file_path = gui.open_file_dialog(starting_dir, starting_file, file_class)
|
||||
print(file_path) # Make sure to print the result
|
||||
|
||||
elif mode == 'msgbox':
|
||||
message = sys.argv[2]
|
||||
title = sys.argv[3] if len(sys.argv) > 3 else ""
|
||||
show_message_box(message, title)
|
||||
elif mode == 'msgbox':
|
||||
message = sys.argv[2]
|
||||
title = sys.argv[3] if len(sys.argv) > 3 else ""
|
||||
gui = TkGui()
|
||||
gui.show_message_box(message, title)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
@ -3,8 +3,8 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
from .common_gui_functions import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -81,8 +81,7 @@ def gradio_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_a_model,
|
||||
show_progress=False,
|
||||
@ -97,8 +96,7 @@ def gradio_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_b_model,
|
||||
show_progress=False,
|
||||
|
@ -3,7 +3,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper
|
||||
from .common_gui_functions import get_file_path, get_saveasfile_path
|
||||
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -96,8 +96,7 @@ def gradio_resize_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[model, lora_ext, lora_ext_name],
|
||||
outputs=model,
|
||||
show_progress=False,
|
||||
|
@ -3,8 +3,8 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
from .common_gui_functions import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -87,8 +87,7 @@ def gradio_svd_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_a_model,
|
||||
show_progress=False,
|
||||
@ -103,8 +102,7 @@ def gradio_svd_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_b_model,
|
||||
show_progress=False,
|
||||
|
@ -3,8 +3,8 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_file_path_gradio_wrapper,
|
||||
from .common_gui_functions import (
|
||||
get_file_path,
|
||||
)
|
||||
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
@ -68,8 +68,7 @@ def gradio_verify_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_model_file.click(
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
inputs=[lora_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_model,
|
||||
show_progress=False,
|
||||
|
@ -3,7 +3,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_folder_path
|
||||
from .common_gui_functions import get_folder_path
|
||||
|
||||
|
||||
def replace_underscore_with_space(folder_path, file_extension):
|
||||
|
@ -12,7 +12,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from library.common_gui import (
|
||||
from library.common_gui_functions import (
|
||||
get_folder_path,
|
||||
remove_doublequote,
|
||||
get_file_path,
|
||||
@ -28,7 +28,7 @@ from library.common_gui import (
|
||||
run_cmd_training,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
|
||||
check_if_model_exist, show_message_box,
|
||||
)
|
||||
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
@ -254,7 +254,7 @@ def open_configuration(
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
file_path = get_file_path(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
# load variables from JSON file
|
||||
|
@ -27,6 +27,6 @@ huggingface-hub==0.12.0; sys_platform != 'darwin'
|
||||
huggingface-hub==0.13.0; sys_platform == 'darwin'
|
||||
tensorflow==2.10.1; sys_platform != 'darwin'
|
||||
# For locon support
|
||||
lycoris_lora==0.1.2
|
||||
lycoris_lora==0.1.4
|
||||
# for kohya_ss library
|
||||
.
|
@ -12,7 +12,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from library.common_gui import (
|
||||
from library.common_gui_functions import (
|
||||
get_folder_path,
|
||||
remove_doublequote,
|
||||
get_file_path,
|
||||
@ -28,7 +28,7 @@ from library.common_gui import (
|
||||
gradio_source_model,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist, get_file_path_gradio_wrapper,
|
||||
check_if_model_exist,
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
@ -240,9 +240,9 @@ def open_configuration(
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
file_path = get_file_path(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
if not file_path == '' and file_path is not None:
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
my_data = json.load(f)
|
||||
@ -673,7 +673,7 @@ def ti_tab(
|
||||
)
|
||||
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
||||
weights_file_input.click(
|
||||
lambda *args, **kwargs: get_file_path_gradio_wrapper,
|
||||
lambda *args, **kwargs: get_file_path(*args),
|
||||
outputs=weights,
|
||||
show_progress=False,
|
||||
)
|
||||
|
@ -1,29 +1,27 @@
|
||||
# DreamBooth training
|
||||
# XXX dropped option: fine_tune
|
||||
|
||||
import gc
|
||||
import time
|
||||
import argparse
|
||||
import gc
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.config_ml_util as config_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
from library.config_ml_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
|
13
train_db.py
13
train_db.py
@ -1,28 +1,25 @@
|
||||
# DreamBooth training
|
||||
# XXX dropped option: fine_tune
|
||||
|
||||
import gc
|
||||
import time
|
||||
import argparse
|
||||
import gc
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.config_ml_util as config_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
from library.config_ml_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
|
||||
|
@ -1,31 +1,30 @@
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import importlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.config_ml_util as config_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
import library.train_util as train_util
|
||||
from library.train_util import (
|
||||
DreamBoothDataset,
|
||||
)
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
from library.config_ml_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.train_util import (
|
||||
DreamBoothDataset,
|
||||
)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
|
@ -1,31 +1,30 @@
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import importlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.config_ml_util as config_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
import library.train_util as train_util
|
||||
from library.train_util import (
|
||||
DreamBoothDataset,
|
||||
)
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
from library.config_ml_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.train_util import (
|
||||
DreamBoothDataset,
|
||||
)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
|
@ -1,24 +1,21 @@
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.config_ml_util as config_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
from library.config_ml_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
imagenet_templates_small = [
|
||||
|
@ -1,24 +1,21 @@
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.config_ml_util as config_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
from library.config_ml_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
imagenet_templates_small = [
|
||||
|
Loading…
Reference in New Issue
Block a user