Merge pull request #515 from bmaltais/revert-483-macos_gui

Revert "macOS GUI functionality, sub-processed GUI components"
This commit is contained in:
bmaltais 2023-04-01 21:29:01 -04:00 committed by GitHub
commit 55d34ab733
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 597 additions and 936 deletions

239
.gitignore vendored
View File

@ -1,243 +1,12 @@
# Kohya_SS Specifics
venv
__pycache__
cudnn_windows
.vscode
*.egg-info
build
wd14_tagger_model
.DS_Store
locon
gui-user.bat
gui-user.ps1
*.whl*
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
library/__init__.py

View File

@ -3,17 +3,14 @@
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captionning of images to utilities
import argparse
import gradio as gr
import json
import math
import os
import pathlib
import subprocess
import sys
import gradio as gr
from library.common_gui_functions import (
import pathlib
import argparse
from library.common_gui import (
get_folder_path,
remove_doublequote,
get_file_path,
@ -29,19 +26,19 @@ from library.common_gui_functions import (
gradio_source_model,
# set_legacy_8bitadam,
update_my_data,
check_if_model_exist, show_message_box,
check_if_model_exist,
)
from library.common_utilities import CommonUtilities
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
stop_tensorboard,
)
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from easygui import msgbox
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@ -125,12 +122,12 @@ def save_configuration(
file_path = get_saveasfile_path(file_path)
else:
print('Save...')
if file_path is None or file_path == '':
if file_path == None or file_path == '':
file_path = get_saveasfile_path(file_path)
# print(file_path)
if file_path is None or file_path == '':
if file_path == None or file_path == '':
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
# Return the values of the variables as a dictionary
@ -222,10 +219,6 @@ def open_configuration(
vae_batch_size,
min_snr_gamma,
):
print("open_configuration called")
print(f"locals length: {len(locals())}")
print(f"locals: {locals()}")
# Get list of function parameters and values
parameters = list(locals().items())
@ -233,25 +226,18 @@ def open_configuration(
original_file_path = file_path
if ask_for_file and file_path is not None:
print(f"File path: {file_path}")
file_path, canceled = get_file_path(file_path=file_path, filedialog_type="json")
if ask_for_file:
file_path = get_file_path(file_path)
if canceled:
return (None,) + (None,) * (len(parameters) - 2)
if not file_path == '' and file_path is not None:
if not file_path == '' and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
my_data = json.load(f)
if CommonUtilities.is_valid_config(my_data):
print('Loading config...')
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
my_data = update_my_data(my_data)
else:
print("Invalid configuration file.")
my_data = {}
show_message_box("Invalid configuration file.")
else:
file_path = original_file_path
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
my_data = {}
values = [file_path]
@ -259,8 +245,6 @@ def open_configuration(
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
# Print the number of returned values
print(f"Returning: {values}")
return tuple(values)
@ -327,24 +311,24 @@ def train_model(
min_snr_gamma,
):
if pretrained_model_name_or_path == '':
show_message_box('Source model information is missing')
msgbox('Source model information is missing')
return
if train_data_dir == '':
show_message_box('Image folder path is missing')
msgbox('Image folder path is missing')
return
if not os.path.exists(train_data_dir):
show_message_box('Image folder does not exist')
msgbox('Image folder does not exist')
return
if reg_data_dir != '':
if not os.path.exists(reg_data_dir):
show_message_box('Regularisation folder does not exist')
msgbox('Regularisation folder does not exist')
return
if output_dir == '':
show_message_box('Output folder path is missing')
msgbox('Output folder path is missing')
return
if check_if_model_exist(output_name, output_dir, save_model_as):
@ -848,20 +832,19 @@ def dreambooth_tab(
]
button_open_config.click(
lambda *_args, **kwargs: open_configuration(*_args, **kwargs),
open_configuration,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
lambda *args, **kwargs: (print("Lambda called"), open_configuration(*args, **kwargs)),
inputs=[dummy_db_true, config_file_name] + settings_list,
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
# Print the number of expected outputs
print(f"Number of expected outputs: {len([config_file_name] + settings_list)}")
button_save_config.click(
save_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,

View File

@ -5,20 +5,22 @@ import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util
from library.config_ml_util import (
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

View File

@ -1,18 +1,17 @@
import argparse
import gradio as gr
import json
import math
import os
import pathlib
import subprocess
import gradio as gr
from library.common_gui_functions import (
import pathlib
import argparse
from library.common_gui import (
get_folder_path,
get_file_path,
get_saveasfile_path,
save_inference_file,
gradio_advanced_training,
run_cmd_advanced_training,
gradio_training,
run_cmd_advanced_training,
gradio_config,
@ -23,13 +22,13 @@ from library.common_gui_functions import (
update_my_data,
check_if_model_exist,
)
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
stop_tensorboard,
)
from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@ -234,7 +233,7 @@ def open_configuration(
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and file_path is not None:
if not file_path == '' and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
my_data = json.load(f)
@ -800,14 +799,14 @@ def finetune_tab():
button_run.click(train_model, inputs=settings_list)
button_open_config.click(
lambda *args, **kwargs: open_configuration(),
open_configuration,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
lambda *args, **kwargs: open_configuration(),
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,

View File

@ -1,18 +1,17 @@
import argparse
import os
from pathlib import Path
import gradio as gr
import os
import argparse
from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from library.extract_lora_gui import gradio_extract_lora_tab
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from library.utilities import utilities_tab
from lora_gui import lora_tab
from textual_inversion_gui import ti_tab
def UI(**kwargs):
css = ''

View File

@ -1,9 +1,8 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path, add_pre_postfix, find_replace
import os
def caption_images(
@ -18,11 +17,11 @@ def caption_images(
):
# Check for images_dir
if not images_dir:
show_message_box('Image folder is missing...')
msgbox('Image folder is missing...')
return
if not caption_ext:
show_message_box('Please provide an extension for the caption files.')
msgbox('Please provide an extension for the caption files.')
return
if caption_text:
@ -61,7 +60,7 @@ def caption_images(
)
else:
if prefix or postfix:
show_message_box(
msgbox(
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
)

View File

@ -1,9 +1,8 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import get_folder_path, add_pre_postfix
from easygui import msgbox
import subprocess
import os
from .common_gui import get_folder_path, add_pre_postfix
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
@ -22,16 +21,16 @@ def caption_images(
):
# Check for caption_text_input
# if caption_text_input == "":
# show_message_box("Caption text is missing...")
# msgbox("Caption text is missing...")
# return
# Check for images_dir_input
if train_data_dir == '':
show_message_box('Image folder is missing...')
msgbox('Image folder is missing...')
return
if caption_file_ext == '':
show_message_box('Please provide an extension for the caption files.')
msgbox('Please provide an extension for the caption files.')
return
print(f'Captioning files in {train_data_dir}...')

View File

@ -1,14 +1,9 @@
import os
import shutil
import subprocess
from contextlib import contextmanager
import tkinter as tk
from tkinter import filedialog, Tk
import easygui
from easygui import msgbox
import os
import gradio as gr
from library.common_utilities import CommonUtilities
import easygui
import shutil
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@ -39,41 +34,6 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
@contextmanager
def tk_context():
root = tk.Tk()
root.withdraw()
try:
yield root
finally:
root.destroy()
def open_file_dialog(initial_dir, initial_file, file_types="all"):
current_directory = os.path.dirname(os.path.abspath(__file__))
args = ["python", f"{current_directory}/gui_subprocesses.py", "file_dialog"]
if initial_dir:
args.append(initial_dir)
if initial_file:
args.append(initial_file)
if file_types:
args.append(file_types)
file_path = subprocess.check_output(args).decode("utf-8").strip()
return file_path
def show_message_box(message, title=""):
current_directory = os.path.dirname(os.path.abspath(__file__))
args = ["python", f"{current_directory}/gui_subprocesses.py", "msgbox", message]
if title:
args.append(title)
subprocess.run(args)
def check_if_model_exist(output_name, output_dir, save_model_as):
if save_model_as in ['diffusers', 'diffusers_safetendors']:
ckpt_folder = os.path.join(output_dir, output_name)
@ -142,6 +102,11 @@ def update_my_data(my_data):
return my_data
def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name)
# def has_ext_files(directory, extension):
# # Iterate through all the files in the directory
# for file in os.listdir(directory):
@ -152,29 +117,58 @@ def update_my_data(my_data):
# return False
def get_file_path(file_path, initial_dir=None, initial_file=None, filedialog_type="lora"):
file_extension = os.path.splitext(file_path)[-1].lower()
# Find the appropriate filedialog_type based on the file extension
for key, extensions in CommonUtilities.file_filters.items():
if file_extension in extensions:
filedialog_type = key
break
current_file_path = file_path
initial_dir, initial_file = os.path.split(file_path)
result = open_file_dialog(initial_dir=initial_dir, initial_file=initial_file, file_types=filedialog_type)
file_path, canceled = result[:2]
return file_path, canceled
def get_any_file_path(file_path=''):
def get_file_path(
file_path='', default_extension='.json', extension_name='Config files'
):
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = os.path.split(file_path)
file_path = open_file_dialog(initial_dir, initial_file, "all")
initial_dir, initial_file = get_dir_and_file(file_path)
# Create a hidden Tkinter root window
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
# Show the open file dialog and get the selected file path
file_path = filedialog.askopenfilename(
filetypes=(
(extension_name, f'*{default_extension}'),
('All files', '*.*'),
),
defaultextension=default_extension,
initialfile=initial_file,
initialdir=initial_dir,
)
# Destroy the hidden root window
root.destroy()
# If no file is selected, use the current file path
if not file_path:
file_path = current_file_path
current_file_path = file_path
# print(f'current file path: {current_file_path}')
return file_path
def get_any_file_path(file_path=''):
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.askopenfilename(
initialdir=initial_dir,
initialfile=initial_file,
)
root.destroy()
if file_path == '':
file_path = current_file_path
@ -183,7 +177,7 @@ def get_any_file_path(file_path=''):
def remove_doublequote(file_path):
if file_path is not None:
if file_path != None:
file_path = file_path.replace('"', '')
return file_path
@ -202,11 +196,17 @@ def remove_doublequote(file_path):
# )
def get_folder_path(folder_path='', filedialog_type="directory"):
def get_folder_path(folder_path=''):
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
current_folder_path = folder_path
initial_dir, initial_file = os.path.split(folder_path)
file_path = open_file_dialog(initial_dir, initial_file, filedialog_type)
initial_dir, initial_file = get_dir_and_file(folder_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
folder_path = filedialog.askdirectory(initialdir=initial_dir)
root.destroy()
if folder_path == '':
folder_path = current_folder_path
@ -215,19 +215,38 @@ def get_folder_path(folder_path='', filedialog_type="directory"):
def get_saveasfile_path(
file_path='', filedialog_type="json"
file_path='', defaultextension='.json', extension_name='Config files'
):
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = os.path.split(file_path)
save_file_path = save_file_dialog(initial_dir, initial_file, filedialog_type)
initial_dir, initial_file = get_dir_and_file(file_path)
if save_file_path is None:
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
save_file_path = filedialog.asksaveasfile(
filetypes=(
(f'{extension_name}', f'{defaultextension}'),
('All files', '*'),
),
defaultextension=defaultextension,
initialdir=initial_dir,
initialfile=initial_file,
)
root.destroy()
# print(save_file_path)
if save_file_path == None:
file_path = current_file_path
else:
print(save_file_path.name)
file_path = save_file_path.name
# print(file_path)
return file_path
@ -341,7 +360,7 @@ def find_replace(
print('Running caption find/replace')
if not has_ext_files(folder_path, caption_file_ext):
show_message_box(
msgbox(
f'No files with extension {caption_file_ext} were found in {folder_path}...'
)
return
@ -367,7 +386,7 @@ def find_replace(
def color_aug_changed(color_aug):
if color_aug:
show_message_box(
msgbox(
'Disabling "Cache latent" because "Color augmentation" has been selected...'
)
return gr.Checkbox.update(value=False, interactive=False)
@ -467,6 +486,7 @@ def set_model_list(
v2,
v_parameterization,
):
if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
model_list = 'custom'
else:

View File

@ -1,24 +0,0 @@
class CommonUtilities:
file_filters = {
"all": [("All files", "*.*")],
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
"json": [("JSON files", "*.json")],
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
"directory": [],
}
def is_valid_config(self, data):
# Check if the data is a dictionary
if not isinstance(data, dict):
return False
# Add checks for expected keys and valid values
# For example, check if 'use_8bit_adam' is a boolean
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
return False
# Add more checks for other keys as needed
# If all checks pass, return True
return True

View File

@ -1,13 +1,13 @@
import argparse
import functools
import json
import random
from dataclasses import (
asdict,
dataclass,
)
from pathlib import Path
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path
# from toolz import curry
from typing import (
List,
@ -19,7 +19,6 @@ from typing import (
import toml
import voluptuous
from transformers import CLIPTokenizer
from voluptuous import (
Any,
ExactSequence,
@ -28,6 +27,7 @@ from voluptuous import (
Required,
Schema,
)
from transformers import CLIPTokenizer
from . import train_util
from .train_util import (

View File

@ -1,10 +1,9 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
import shutil
import subprocess
import gradio as gr
from .common_gui_functions import get_folder_path, get_file_path
from .common_gui import get_folder_path, get_file_path
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@ -23,7 +22,7 @@ def convert_model(
):
# Check for caption_text_input
if source_model_type == '':
show_message_box('Invalid source model type')
msgbox('Invalid source model type')
return
# Check if source model exist
@ -32,14 +31,14 @@ def convert_model(
elif os.path.isdir(source_model_input):
print('The provided model is a folder')
else:
show_message_box('The provided source model is neither a file nor a folder')
msgbox('The provided source model is neither a file nor a folder')
return
# Check if source model exist
if os.path.isdir(target_model_folder_input):
print('The provided model folder exist')
else:
show_message_box('The provided target folder does not exist')
msgbox('The provided target folder does not exist')
return
run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"'
@ -180,7 +179,7 @@ def gradio_convert_model_tab():
document_symbol, elem_id='open_folder_small'
)
button_source_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[source_model_input],
outputs=source_model_input,
show_progress=False,

View File

@ -1,11 +1,8 @@
import os
import re
import gradio as gr
from easygui import boolbox
from .common_gui_functions import get_folder_path
from easygui import msgbox, boolbox
from .common_gui import get_folder_path
# def select_folder():
# # Open a file dialog to select a directory
@ -19,14 +16,14 @@ def dataset_balancing(concept_repeats, folder, insecure):
if not concept_repeats > 0:
# Display an error message if the total number of repeats is not a valid integer
show_message_box('Please enter a valid integer for the total number of repeats.')
msgbox('Please enter a valid integer for the total number of repeats.')
return
concept_repeats = int(concept_repeats)
# Check if folder exist
if folder == '' or not os.path.isdir(folder):
show_message_box('Please enter a valid folder for balancing.')
msgbox('Please enter a valid folder for balancing.')
return
pattern = re.compile(r'^\d+_.+$')
@ -88,7 +85,7 @@ def dataset_balancing(concept_repeats, folder, insecure):
f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
)
show_message_box('Dataset balancing completed...')
msgbox('Dataset balancing completed...')
def warning(insecure):

View File

@ -1,9 +1,8 @@
import os
import shutil
import gradio as gr
from .common_gui_functions import get_folder_path
from easygui import diropenbox, msgbox
from .common_gui import get_folder_path
import shutil
import os
def copy_info_to_Folders_tab(training_folder):
@ -40,12 +39,12 @@ def dreambooth_folder_preparation(
# Check for instance prompt
if util_instance_prompt_input == '':
show_message_box('Instance prompt missing...')
msgbox('Instance prompt missing...')
return
# Check for class prompt
if util_class_prompt_input == '':
show_message_box('Class prompt missing...')
msgbox('Class prompt missing...')
return
# Create the training_dir path

View File

@ -1,10 +1,11 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import (
get_file_path, get_saveasfile_path,
from easygui import msgbox
import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂
@ -26,20 +27,20 @@ def extract_lora(
):
# Check for caption_text_input
if model_tuned == '':
show_message_box('Invalid finetuned model file')
msgbox('Invalid finetuned model file')
return
if model_org == '':
show_message_box('Invalid base model file')
msgbox('Invalid base model file')
return
# Check if source model exist
if not os.path.isfile(model_tuned):
show_message_box('The provided finetuned model is not a file')
msgbox('The provided finetuned model is not a file')
return
if not os.path.isfile(model_org):
show_message_box('The provided base model is not a file')
msgbox('The provided base model is not a file')
return
run_cmd = (
@ -90,7 +91,7 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_model_tuned_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[model_tuned, model_ext, model_ext_name],
outputs=model_tuned,
show_progress=False,
@ -105,8 +106,7 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_model_org_file.click(
lambda input1, input2, input3, *args, **kwargs:
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[model_org, model_ext, model_ext_name],
outputs=model_org,
show_progress=False,
@ -121,7 +121,7 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfile_path,
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
show_progress=False,

View File

@ -1,10 +1,11 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import (
get_file_path, get_saveasfile_path,
from easygui import msgbox
import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂
@ -35,20 +36,20 @@ def extract_lycoris_locon(
):
# Check for caption_text_input
if db_model == '':
show_message_box('Invalid finetuned model file')
msgbox('Invalid finetuned model file')
return
if base_model == '':
show_message_box('Invalid base model file')
msgbox('Invalid base model file')
return
# Check if source model exist
if not os.path.isfile(db_model):
show_message_box('The provided finetuned model is not a file')
msgbox('The provided finetuned model is not a file')
return
if not os.path.isfile(base_model):
show_message_box('The provided base model is not a file')
msgbox('The provided base model is not a file')
return
run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
@ -136,8 +137,7 @@ def gradio_extract_lycoris_locon_tab():
folder_symbol, elem_id='open_folder_small'
)
button_db_model_file.click(
lambda input1, input2, input3, *args, **kwargs:
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[db_model, model_ext, model_ext_name],
outputs=db_model,
show_progress=False,
@ -152,7 +152,7 @@ def gradio_extract_lycoris_locon_tab():
folder_symbol, elem_id='open_folder_small'
)
button_base_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[base_model, model_ext, model_ext_name],
outputs=base_model,
show_progress=False,
@ -167,7 +167,7 @@ def gradio_extract_lycoris_locon_tab():
folder_symbol, elem_id='open_folder_small'
)
button_output_name.click(
get_saveasfile_path,
get_saveasfilename_path,
inputs=[output_name, lora_ext, lora_ext_name],
outputs=output_name,
show_progress=False,

View File

@ -1,9 +1,8 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import get_folder_path, add_pre_postfix
from easygui import msgbox
import subprocess
import os
from .common_gui import get_folder_path, add_pre_postfix
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
@ -20,11 +19,11 @@ def caption_images(
):
# Check for images_dir_input
if train_data_dir == '':
show_message_box('Image folder is missing...')
msgbox('Image folder is missing...')
return
if caption_ext == '':
show_message_box('Please provide an extension for the caption files.')
msgbox('Please provide an extension for the caption files.')
return
print(f'GIT captioning files in {train_data_dir}...')

View File

@ -1,86 +0,0 @@
import os
import pathlib
import sys
import tkinter as tk
from tkinter import filedialog, messagebox
from library.common_gui_functions import tk_context
from library.common_utilities import CommonUtilities
class TkGui:
def __init__(self):
self.file_types = None
def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
with tk_context():
self.file_types = file_types
if self.file_types in CommonUtilities.file_filters:
filters = CommonUtilities.file_filters[self.file_types]
else:
filters = CommonUtilities.file_filters["all"]
if self.file_types == "directory":
result = filedialog.askdirectory(initialdir=initial_dir)
else:
result = filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
# Return a tuple (file_path, canceled)
# file_path: the selected file path or an empty string if no file is selected
# canceled: True if the user pressed the cancel button, False otherwise
return result, result == ""
def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
self.file_types = file_types
# Use the tk_context function with the 'with' statement
with tk_context():
if self.file_types in CommonUtilities.file_filters:
filters = CommonUtilities.file_filters[self.file_types]
else:
filters = CommonUtilities.file_filters["all"]
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file,
filetypes=filters, defaultextension=".safetensors")
return save_file_path
def show_message_box(_message, _title="Message", _level="info"):
with tk_context():
message_type = {
"warning": messagebox.showwarning,
"error": messagebox.showerror,
"info": messagebox.showinfo,
"question": messagebox.askquestion,
"okcancel": messagebox.askokcancel,
"retrycancel": messagebox.askretrycancel,
"yesno": messagebox.askyesno,
"yesnocancel": messagebox.askyesnocancel
}
if _level in message_type:
message_type[_level](_title, _message)
else:
messagebox.showinfo(_title, _message)
if __name__ == '__main__':
try:
mode = sys.argv[1]
if mode == 'file_dialog':
starting_dir = sys.argv[2] if len(sys.argv) > 2 else None
starting_file = sys.argv[3] if len(sys.argv) > 3 else None
file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4]
gui = TkGui()
file_path = gui.open_file_dialog(starting_dir, starting_file, file_class)
print(file_path) # Make sure to print the result
elif mode == 'msgbox':
message = sys.argv[2]
title = sys.argv[3] if len(sys.argv) > 3 else ""
gui = TkGui()
gui.show_message_box(message, title)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)

View File

@ -1,10 +1,11 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import (
get_file_path, get_saveasfile_path,
from easygui import msgbox
import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂
@ -24,20 +25,20 @@ def merge_lora(
):
# Check for caption_text_input
if lora_a_model == '':
show_message_box('Invalid model A file')
msgbox('Invalid model A file')
return
if lora_b_model == '':
show_message_box('Invalid model B file')
msgbox('Invalid model B file')
return
# Check if source model exist
if not os.path.isfile(lora_a_model):
show_message_box('The provided model A is not a file')
msgbox('The provided model A is not a file')
return
if not os.path.isfile(lora_b_model):
show_message_box('The provided model B is not a file')
msgbox('The provided model B is not a file')
return
ratio_a = ratio
@ -81,7 +82,7 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_lora_a_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model,
show_progress=False,
@ -96,7 +97,7 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_lora_b_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model,
show_progress=False,
@ -121,7 +122,7 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfile_path,
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
show_progress=False,

View File

@ -1,9 +1,8 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import get_file_path, get_saveasfile_path
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_file_path
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
folder_symbol = '\U0001f4c2' # 📂
@ -24,24 +23,24 @@ def resize_lora(
):
# Check for caption_text_input
if model == '':
show_message_box('Invalid model file')
msgbox('Invalid model file')
return
# Check if source model exist
if not os.path.isfile(model):
show_message_box('The provided model is not a file')
msgbox('The provided model is not a file')
return
if dynamic_method == 'sv_ratio':
if float(dynamic_param) < 2:
show_message_box(
msgbox(
f'Dynamic parameter for {dynamic_method} need to be 2 or greater...'
)
return
if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
if float(dynamic_param) < 0 or float(dynamic_param) > 1:
show_message_box(
msgbox(
f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...'
)
return
@ -96,7 +95,7 @@ def gradio_resize_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_lora_a_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[model, lora_ext, lora_ext_name],
outputs=model,
show_progress=False,
@ -135,7 +134,7 @@ def gradio_resize_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfile_path,
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
show_progress=False,

View File

@ -1,6 +1,7 @@
import tempfile
import os
import gradio as gr
from easygui import msgbox
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄

View File

@ -1,10 +1,11 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import (
get_file_path, get_saveasfile_path,
from easygui import msgbox
import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂
@ -27,20 +28,20 @@ def svd_merge_lora(
):
# Check for caption_text_input
if lora_a_model == '':
show_message_box('Invalid model A file')
msgbox('Invalid model A file')
return
if lora_b_model == '':
show_message_box('Invalid model B file')
msgbox('Invalid model B file')
return
# Check if source model exist
if not os.path.isfile(lora_a_model):
show_message_box('The provided model A is not a file')
msgbox('The provided model A is not a file')
return
if not os.path.isfile(lora_b_model):
show_message_box('The provided model B is not a file')
msgbox('The provided model B is not a file')
return
ratio_a = ratio
@ -87,7 +88,7 @@ def gradio_svd_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_lora_a_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model,
show_progress=False,
@ -102,7 +103,7 @@ def gradio_svd_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_lora_b_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model,
show_progress=False,
@ -143,7 +144,7 @@ def gradio_svd_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfile_path,
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
show_progress=False,

View File

@ -1,9 +1,9 @@
import os
import gradio as gr
from easygui import msgbox
import subprocess
import time
import gradio as gr
tensorboard_proc = None # I know... bad but heh
TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe'
@ -13,7 +13,7 @@ def start_tensorboard(logging_dir):
if not os.listdir(logging_dir):
print('Error: log folder is empty')
show_message_box(msg='Error: log folder is empty')
msgbox(msg='Error: log folder is empty')
return
run_cmd = [f'{TENSORBOARD}', '--logdir', f'{logging_dir}']

View File

@ -1,9 +1,10 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import (
from easygui import msgbox
import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
@ -19,12 +20,12 @@ def verify_lora(
):
# verify for caption_text_input
if lora_model == '':
show_message_box('Invalid model A file')
msgbox('Invalid model A file')
return
# verify if source model exist
if not os.path.isfile(lora_model):
show_message_box('The provided model A is not a file')
msgbox('The provided model A is not a file')
return
run_cmd = [
@ -68,7 +69,7 @@ def gradio_verify_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_lora_model_file.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
inputs=[lora_model, lora_ext, lora_ext_name],
outputs=lora_model,
show_progress=False,

View File

@ -1,9 +1,8 @@
import os
import subprocess
import gradio as gr
from .common_gui_functions import get_folder_path
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path
import os
def replace_underscore_with_space(folder_path, file_extension):
@ -21,16 +20,16 @@ def caption_images(
):
# Check for caption_text_input
# if caption_text_input == "":
# show_message_box("Caption text is missing...")
# msgbox("Caption text is missing...")
# return
# Check for images_dir_input
if train_data_dir == '':
show_message_box('Image folder is missing...')
msgbox('Image folder is missing...')
return
if caption_extension == '':
show_message_box('Please provide an extension for the caption files.')
msgbox('Please provide an extension for the caption files.')
return
print(f'Captioning files in {train_data_dir}...')

View File

@ -3,16 +3,15 @@
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captionning of images to utilities
import argparse
import gradio as gr
import easygui
import json
import math
import os
import pathlib
import subprocess
import gradio as gr
from library.common_gui_functions import (
import pathlib
import argparse
from library.common_gui import (
get_folder_path,
remove_doublequote,
get_file_path,
@ -28,23 +27,24 @@ from library.common_gui_functions import (
run_cmd_training,
# set_legacy_8bitadam,
update_my_data,
check_if_model_exist, show_message_box,
check_if_model_exist,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.merge_lora_gui import gradio_merge_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
from library.tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
stop_tensorboard,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
from library.verify_lora_gui import gradio_verify_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from easygui import msgbox
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@ -359,35 +359,35 @@ def train_model(
print_only_bool = True if print_only.get('label') == 'True' else False
if pretrained_model_name_or_path == '':
show_message_box('Source model information is missing')
msgbox('Source model information is missing')
return
if train_data_dir == '':
show_message_box('Image folder path is missing')
msgbox('Image folder path is missing')
return
if not os.path.exists(train_data_dir):
show_message_box('Image folder does not exist')
msgbox('Image folder does not exist')
return
if reg_data_dir != '':
if not os.path.exists(reg_data_dir):
show_message_box('Regularisation folder does not exist')
msgbox('Regularisation folder does not exist')
return
if output_dir == '':
show_message_box('Output folder path is missing')
msgbox('Output folder path is missing')
return
if int(bucket_reso_steps) < 1:
show_message_box('Bucket resolution steps need to be greater than 0')
msgbox('Bucket resolution steps need to be greater than 0')
return
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if stop_text_encoder_training_pct > 0:
show_message_box(
msgbox(
'Output "stop text encoder training" is not yet supported. Ignoring'
)
stop_text_encoder_training_pct = 0
@ -402,7 +402,7 @@ def train_model(
unet_lr = 0
# if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
# show_message_box(
# msgbox(
# 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
# )
# return
@ -540,7 +540,7 @@ def train_model(
run_cmd += f' --network_train_unet_only'
else:
if float(text_encoder_lr) == 0:
show_message_box('Please input learning rate values.')
msgbox('Please input learning rate values.')
return
run_cmd += f' --network_dim={network_dim}'
@ -1031,14 +1031,14 @@ def lora_tab(
]
button_open_config.click(
lambda *args, **kwargs: open_configuration(),
open_configuration,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False,
)
button_load_config.click(
lambda *args, **kwargs: open_configuration(),
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False,

View File

@ -3,16 +3,14 @@
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captionning of images to utilities
import argparse
import gradio as gr
import json
import math
import os
import pathlib
import subprocess
import gradio as gr
from library.common_gui_functions import (
import pathlib
import argparse
from library.common_gui import (
get_folder_path,
remove_doublequote,
get_file_path,
@ -30,16 +28,17 @@ from library.common_gui_functions import (
update_my_data,
check_if_model_exist,
)
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
stop_tensorboard,
)
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from easygui import msgbox
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@ -242,7 +241,7 @@ def open_configuration(
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and file_path is not None:
if not file_path == '' and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
my_data = json.load(f)
@ -330,32 +329,32 @@ def train_model(
min_snr_gamma,
):
if pretrained_model_name_or_path == '':
show_message_box('Source model information is missing')
msgbox('Source model information is missing')
return
if train_data_dir == '':
show_message_box('Image folder path is missing')
msgbox('Image folder path is missing')
return
if not os.path.exists(train_data_dir):
show_message_box('Image folder does not exist')
msgbox('Image folder does not exist')
return
if reg_data_dir != '':
if not os.path.exists(reg_data_dir):
show_message_box('Regularisation folder does not exist')
msgbox('Regularisation folder does not exist')
return
if output_dir == '':
show_message_box('Output folder path is missing')
msgbox('Output folder path is missing')
return
if token_string == '':
show_message_box('Token string is missing')
msgbox('Token string is missing')
return
if init_word == '':
show_message_box('Init word is missing')
msgbox('Init word is missing')
return
if not os.path.exists(output_dir):
@ -673,7 +672,7 @@ def ti_tab(
)
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
weights_file_input.click(
lambda *args, **kwargs: get_file_path(*args),
get_file_path,
outputs=weights,
show_progress=False,
)
@ -899,14 +898,14 @@ def ti_tab(
]
button_open_config.click(
lambda *args, **kwargs: open_configuration(),
open_configuration,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
lambda *args, **kwargs: open_configuration(),
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,

View File

@ -1,25 +1,28 @@
# DreamBooth training
# XXX dropped option: fine_tune
import argparse
import gc
import time
import argparse
import itertools
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util
from library.config_ml_util import (
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

View File

@ -1,21 +1,24 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util
from library.config_ml_util import (
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
imagenet_templates_small = [