Merge branch 'macos_gui' of https://github.com/jstayco/kohya_ss into jstayco-macos_gui
This commit is contained in:
commit
c21962fa82
239
.gitignore
vendored
239
.gitignore
vendored
@ -1,12 +1,243 @@
|
|||||||
venv
|
# Kohya_SS Specifics
|
||||||
__pycache__
|
|
||||||
cudnn_windows
|
cudnn_windows
|
||||||
.vscode
|
.vscode
|
||||||
*.egg-info
|
|
||||||
build
|
|
||||||
wd14_tagger_model
|
wd14_tagger_model
|
||||||
.DS_Store
|
.DS_Store
|
||||||
locon
|
locon
|
||||||
gui-user.bat
|
gui-user.bat
|
||||||
gui-user.ps1
|
gui-user.ps1
|
||||||
|
*.whl*
|
||||||
|
.idea
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
||||||
|
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||||
|
|
||||||
|
# User-specific stuff
|
||||||
|
.idea/**/workspace.xml
|
||||||
|
.idea/**/tasks.xml
|
||||||
|
.idea/**/usage.statistics.xml
|
||||||
|
.idea/**/dictionaries
|
||||||
|
.idea/**/shelf
|
||||||
|
|
||||||
|
# AWS User-specific
|
||||||
|
.idea/**/aws.xml
|
||||||
|
|
||||||
|
# Generated files
|
||||||
|
.idea/**/contentModel.xml
|
||||||
|
|
||||||
|
# Sensitive or high-churn files
|
||||||
|
.idea/**/dataSources/
|
||||||
|
.idea/**/dataSources.ids
|
||||||
|
.idea/**/dataSources.local.xml
|
||||||
|
.idea/**/sqlDataSources.xml
|
||||||
|
.idea/**/dynamic.xml
|
||||||
|
.idea/**/uiDesigner.xml
|
||||||
|
.idea/**/dbnavigator.xml
|
||||||
|
|
||||||
|
# Gradle
|
||||||
|
.idea/**/gradle.xml
|
||||||
|
.idea/**/libraries
|
||||||
|
|
||||||
|
# Gradle and Maven with auto-import
|
||||||
|
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||||
|
# since they will be recreated, and may cause churn. Uncomment if using
|
||||||
|
# auto-import.
|
||||||
|
# .idea/artifacts
|
||||||
|
# .idea/compiler.xml
|
||||||
|
# .idea/jarRepositories.xml
|
||||||
|
# .idea/modules.xml
|
||||||
|
# .idea/*.iml
|
||||||
|
# .idea/modules
|
||||||
|
# *.iml
|
||||||
|
# *.ipr
|
||||||
|
|
||||||
|
# CMake
|
||||||
|
cmake-build-*/
|
||||||
|
|
||||||
|
# Mongo Explorer plugin
|
||||||
|
.idea/**/mongoSettings.xml
|
||||||
|
|
||||||
|
# File-based project format
|
||||||
|
*.iws
|
||||||
|
|
||||||
|
# IntelliJ
|
||||||
|
out/
|
||||||
|
|
||||||
|
# mpeltonen/sbt-idea plugin
|
||||||
|
.idea_modules/
|
||||||
|
|
||||||
|
# JIRA plugin
|
||||||
|
atlassian-ide-plugin.xml
|
||||||
|
|
||||||
|
# Cursive Clojure plugin
|
||||||
|
.idea/replstate.xml
|
||||||
|
|
||||||
|
# SonarLint plugin
|
||||||
|
.idea/sonarlint/
|
||||||
|
|
||||||
|
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||||
|
com_crashlytics_export_strings.xml
|
||||||
|
crashlytics.properties
|
||||||
|
crashlytics-build.properties
|
||||||
|
fabric.properties
|
||||||
|
|
||||||
|
# Editor-based Rest Client
|
||||||
|
.idea/httpRequests
|
||||||
|
|
||||||
|
# Android studio 3.1+ serialized cache file
|
||||||
|
.idea/caches/build_file_checksums.ser
|
||||||
library/__init__.py
|
library/__init__.py
|
||||||
|
@ -3,14 +3,17 @@
|
|||||||
# v3: Add new Utilities tab for Dreambooth folder preparation
|
# v3: Add new Utilities tab for Dreambooth folder preparation
|
||||||
# v3.1: Adding captionning of images to utilities
|
# v3.1: Adding captionning of images to utilities
|
||||||
|
|
||||||
import gradio as gr
|
import argparse
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import argparse
|
import subprocess
|
||||||
from library.common_gui import (
|
import sys
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
@ -26,89 +29,89 @@ from library.common_gui import (
|
|||||||
gradio_source_model,
|
gradio_source_model,
|
||||||
# set_legacy_8bitadam,
|
# set_legacy_8bitadam,
|
||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist,
|
check_if_model_exist, show_message_box,
|
||||||
)
|
)
|
||||||
|
from library.common_utilities import CommonUtilities
|
||||||
|
from library.dreambooth_folder_creation_gui import (
|
||||||
|
gradio_dreambooth_folder_creation_tab,
|
||||||
|
)
|
||||||
|
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||||
from library.tensorboard_gui import (
|
from library.tensorboard_gui import (
|
||||||
gradio_tensorboard,
|
gradio_tensorboard,
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
|
||||||
gradio_dreambooth_folder_creation_tab,
|
|
||||||
)
|
|
||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
|
||||||
from easygui import msgbox
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
|
|
||||||
def save_configuration(
|
def save_configuration(
|
||||||
save_as,
|
save_as,
|
||||||
file_path,
|
file_path,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
logging_dir,
|
logging_dir,
|
||||||
train_data_dir,
|
train_data_dir,
|
||||||
reg_data_dir,
|
reg_data_dir,
|
||||||
output_dir,
|
output_dir,
|
||||||
max_resolution,
|
max_resolution,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
lr_warmup,
|
lr_warmup,
|
||||||
train_batch_size,
|
train_batch_size,
|
||||||
epoch,
|
epoch,
|
||||||
save_every_n_epochs,
|
save_every_n_epochs,
|
||||||
mixed_precision,
|
mixed_precision,
|
||||||
save_precision,
|
save_precision,
|
||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
enable_bucket,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16,
|
full_fp16,
|
||||||
no_token_padding,
|
no_token_padding,
|
||||||
stop_text_encoder_training,
|
stop_text_encoder_training,
|
||||||
# use_8bit_adam,
|
# use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
save_model_as,
|
save_model_as,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
clip_skip,
|
clip_skip,
|
||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs,
|
caption_dropout_every_n_epochs,
|
||||||
caption_dropout_rate,
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,
|
optimizer_args,
|
||||||
noise_offset,
|
noise_offset,
|
||||||
sample_every_n_steps,
|
sample_every_n_steps,
|
||||||
sample_every_n_epochs,
|
sample_every_n_epochs,
|
||||||
sample_sampler,
|
sample_sampler,
|
||||||
sample_prompts,
|
sample_prompts,
|
||||||
additional_parameters,
|
additional_parameters,
|
||||||
vae_batch_size,
|
vae_batch_size,
|
||||||
min_snr_gamma,
|
min_snr_gamma,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -122,12 +125,12 @@ def save_configuration(
|
|||||||
file_path = get_saveasfile_path(file_path)
|
file_path = get_saveasfile_path(file_path)
|
||||||
else:
|
else:
|
||||||
print('Save...')
|
print('Save...')
|
||||||
if file_path == None or file_path == '':
|
if file_path is None or file_path == '':
|
||||||
file_path = get_saveasfile_path(file_path)
|
file_path = get_saveasfile_path(file_path)
|
||||||
|
|
||||||
# print(file_path)
|
# print(file_path)
|
||||||
|
|
||||||
if file_path == None or file_path == '':
|
if file_path is None or file_path == '':
|
||||||
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||||
|
|
||||||
# Return the values of the variables as a dictionary
|
# Return the values of the variables as a dictionary
|
||||||
@ -135,10 +138,10 @@ def save_configuration(
|
|||||||
name: value
|
name: value
|
||||||
for name, value in parameters # locals().items()
|
for name, value in parameters # locals().items()
|
||||||
if name
|
if name
|
||||||
not in [
|
not in [
|
||||||
'file_path',
|
'file_path',
|
||||||
'save_as',
|
'save_as',
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extract the destination directory from the file path
|
# Extract the destination directory from the file path
|
||||||
@ -156,69 +159,73 @@ def save_configuration(
|
|||||||
|
|
||||||
|
|
||||||
def open_configuration(
|
def open_configuration(
|
||||||
ask_for_file,
|
ask_for_file,
|
||||||
file_path,
|
file_path,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
logging_dir,
|
logging_dir,
|
||||||
train_data_dir,
|
train_data_dir,
|
||||||
reg_data_dir,
|
reg_data_dir,
|
||||||
output_dir,
|
output_dir,
|
||||||
max_resolution,
|
max_resolution,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
lr_warmup,
|
lr_warmup,
|
||||||
train_batch_size,
|
train_batch_size,
|
||||||
epoch,
|
epoch,
|
||||||
save_every_n_epochs,
|
save_every_n_epochs,
|
||||||
mixed_precision,
|
mixed_precision,
|
||||||
save_precision,
|
save_precision,
|
||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
enable_bucket,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16,
|
full_fp16,
|
||||||
no_token_padding,
|
no_token_padding,
|
||||||
stop_text_encoder_training,
|
stop_text_encoder_training,
|
||||||
# use_8bit_adam,
|
# use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
save_model_as,
|
save_model_as,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
clip_skip,
|
clip_skip,
|
||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs,
|
caption_dropout_every_n_epochs,
|
||||||
caption_dropout_rate,
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,
|
optimizer_args,
|
||||||
noise_offset,
|
noise_offset,
|
||||||
sample_every_n_steps,
|
sample_every_n_steps,
|
||||||
sample_every_n_epochs,
|
sample_every_n_epochs,
|
||||||
sample_sampler,
|
sample_sampler,
|
||||||
sample_prompts,
|
sample_prompts,
|
||||||
additional_parameters,
|
additional_parameters,
|
||||||
vae_batch_size,
|
vae_batch_size,
|
||||||
min_snr_gamma,
|
min_snr_gamma,
|
||||||
):
|
):
|
||||||
|
print("open_configuration called")
|
||||||
|
print(f"locals length: {len(locals())}")
|
||||||
|
print(f"locals: {locals()}")
|
||||||
|
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
|
|
||||||
@ -226,18 +233,25 @@ def open_configuration(
|
|||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
if ask_for_file:
|
if ask_for_file and file_path is not None:
|
||||||
file_path = get_file_path(file_path)
|
print(f"File path: {file_path}")
|
||||||
|
file_path, canceled = get_file_path(file_path=file_path, filedialog_type="json")
|
||||||
|
|
||||||
if not file_path == '' and not file_path == None:
|
if canceled:
|
||||||
# load variables from JSON file
|
return (None,) + (None,) * (len(parameters) - 2)
|
||||||
|
|
||||||
|
if not file_path == '' and file_path is not None:
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data = json.load(f)
|
my_data = json.load(f)
|
||||||
print('Loading config...')
|
if CommonUtilities.is_valid_config(my_data):
|
||||||
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
print('Loading config...')
|
||||||
my_data = update_my_data(my_data)
|
my_data = update_my_data(my_data)
|
||||||
|
else:
|
||||||
|
print("Invalid configuration file.")
|
||||||
|
my_data = {}
|
||||||
|
show_message_box("Invalid configuration file.")
|
||||||
else:
|
else:
|
||||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
file_path = original_file_path
|
||||||
my_data = {}
|
my_data = {}
|
||||||
|
|
||||||
values = [file_path]
|
values = [file_path]
|
||||||
@ -245,90 +259,92 @@ def open_configuration(
|
|||||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||||
if not key in ['ask_for_file', 'file_path']:
|
if not key in ['ask_for_file', 'file_path']:
|
||||||
values.append(my_data.get(key, value))
|
values.append(my_data.get(key, value))
|
||||||
|
# Print the number of returned values
|
||||||
|
print(f"Returning: {values}")
|
||||||
return tuple(values)
|
return tuple(values)
|
||||||
|
|
||||||
|
|
||||||
def train_model(
|
def train_model(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
logging_dir,
|
logging_dir,
|
||||||
train_data_dir,
|
train_data_dir,
|
||||||
reg_data_dir,
|
reg_data_dir,
|
||||||
output_dir,
|
output_dir,
|
||||||
max_resolution,
|
max_resolution,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
lr_warmup,
|
lr_warmup,
|
||||||
train_batch_size,
|
train_batch_size,
|
||||||
epoch,
|
epoch,
|
||||||
save_every_n_epochs,
|
save_every_n_epochs,
|
||||||
mixed_precision,
|
mixed_precision,
|
||||||
save_precision,
|
save_precision,
|
||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
enable_bucket,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16,
|
full_fp16,
|
||||||
no_token_padding,
|
no_token_padding,
|
||||||
stop_text_encoder_training_pct,
|
stop_text_encoder_training_pct,
|
||||||
# use_8bit_adam,
|
# use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
save_model_as,
|
save_model_as,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
clip_skip,
|
clip_skip,
|
||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs,
|
caption_dropout_every_n_epochs,
|
||||||
caption_dropout_rate,
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,
|
optimizer_args,
|
||||||
noise_offset,
|
noise_offset,
|
||||||
sample_every_n_steps,
|
sample_every_n_steps,
|
||||||
sample_every_n_epochs,
|
sample_every_n_epochs,
|
||||||
sample_sampler,
|
sample_sampler,
|
||||||
sample_prompts,
|
sample_prompts,
|
||||||
additional_parameters,
|
additional_parameters,
|
||||||
vae_batch_size,
|
vae_batch_size,
|
||||||
min_snr_gamma,
|
min_snr_gamma,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
show_message_box('Source model information is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder path is missing')
|
show_message_box('Image folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.exists(train_data_dir):
|
if not os.path.exists(train_data_dir):
|
||||||
msgbox('Image folder does not exist')
|
show_message_box('Image folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
if reg_data_dir != '':
|
if reg_data_dir != '':
|
||||||
if not os.path.exists(reg_data_dir):
|
if not os.path.exists(reg_data_dir):
|
||||||
msgbox('Regularisation folder does not exist')
|
show_message_box('Regularisation folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
if output_dir == '':
|
if output_dir == '':
|
||||||
msgbox('Output folder path is missing')
|
show_message_box('Output folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if check_if_model_exist(output_name, output_dir, save_model_as):
|
if check_if_model_exist(output_name, output_dir, save_model_as):
|
||||||
@ -339,7 +355,7 @@ def train_model(
|
|||||||
f
|
f
|
||||||
for f in os.listdir(train_data_dir)
|
for f in os.listdir(train_data_dir)
|
||||||
if os.path.isdir(os.path.join(train_data_dir, f))
|
if os.path.isdir(os.path.join(train_data_dir, f))
|
||||||
and not f.startswith('.')
|
and not f.startswith('.')
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check if subfolders are present. If not let the user know and return
|
# Check if subfolders are present. If not let the user know and return
|
||||||
@ -371,11 +387,11 @@ def train_model(
|
|||||||
[
|
[
|
||||||
f
|
f
|
||||||
for f, lower_f in (
|
for f, lower_f in (
|
||||||
(file, file.lower())
|
(file, file.lower())
|
||||||
for file in os.listdir(
|
for file in os.listdir(
|
||||||
os.path.join(train_data_dir, folder)
|
os.path.join(train_data_dir, folder)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -546,10 +562,10 @@ def train_model(
|
|||||||
|
|
||||||
|
|
||||||
def dreambooth_tab(
|
def dreambooth_tab(
|
||||||
train_data_dir=gr.Textbox(),
|
train_data_dir=gr.Textbox(),
|
||||||
reg_data_dir=gr.Textbox(),
|
reg_data_dir=gr.Textbox(),
|
||||||
output_dir=gr.Textbox(),
|
output_dir=gr.Textbox(),
|
||||||
logging_dir=gr.Textbox(),
|
logging_dir=gr.Textbox(),
|
||||||
):
|
):
|
||||||
dummy_db_true = gr.Label(value=True, visible=False)
|
dummy_db_true = gr.Label(value=True, visible=False)
|
||||||
dummy_db_false = gr.Label(value=False, visible=False)
|
dummy_db_false = gr.Label(value=False, visible=False)
|
||||||
@ -832,19 +848,20 @@ def dreambooth_tab(
|
|||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_configuration,
|
lambda *_args, **kwargs: open_configuration(*_args, **kwargs),
|
||||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_load_config.click(
|
button_load_config.click(
|
||||||
open_configuration,
|
lambda *args, **kwargs: (print("Lambda called"), open_configuration(*args, **kwargs)),
|
||||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
# Print the number of expected outputs
|
||||||
|
print(f"Number of expected outputs: {len([config_file_name] + settings_list)}")
|
||||||
button_save_config.click(
|
button_save_config.click(
|
||||||
save_configuration,
|
save_configuration,
|
||||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||||
|
10
fine_tune.py
10
fine_tune.py
@ -5,22 +5,20 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,17 +1,18 @@
|
|||||||
import gradio as gr
|
import argparse
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import argparse
|
import subprocess
|
||||||
from library.common_gui import (
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
save_inference_file,
|
save_inference_file,
|
||||||
gradio_advanced_training,
|
gradio_advanced_training,
|
||||||
run_cmd_advanced_training,
|
|
||||||
gradio_training,
|
gradio_training,
|
||||||
run_cmd_advanced_training,
|
run_cmd_advanced_training,
|
||||||
gradio_config,
|
gradio_config,
|
||||||
@ -22,13 +23,13 @@ from library.common_gui import (
|
|||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist,
|
check_if_model_exist,
|
||||||
)
|
)
|
||||||
|
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||||
from library.tensorboard_gui import (
|
from library.tensorboard_gui import (
|
||||||
gradio_tensorboard,
|
gradio_tensorboard,
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -233,7 +234,7 @@ def open_configuration(
|
|||||||
if ask_for_file:
|
if ask_for_file:
|
||||||
file_path = get_file_path(file_path)
|
file_path = get_file_path(file_path)
|
||||||
|
|
||||||
if not file_path == '' and not file_path == None:
|
if not file_path == '' and file_path is not None:
|
||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data = json.load(f)
|
my_data = json.load(f)
|
||||||
@ -799,14 +800,14 @@ def finetune_tab():
|
|||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_configuration,
|
lambda *args, **kwargs: open_configuration(),
|
||||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_load_config.click(
|
button_load_config.click(
|
||||||
open_configuration,
|
lambda *args, **kwargs: open_configuration(),
|
||||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
13
kohya_gui.py
13
kohya_gui.py
@ -1,17 +1,18 @@
|
|||||||
import gradio as gr
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from dreambooth_gui import dreambooth_tab
|
from dreambooth_gui import dreambooth_tab
|
||||||
from finetune_gui import finetune_tab
|
from finetune_gui import finetune_tab
|
||||||
from textual_inversion_gui import ti_tab
|
|
||||||
from library.utilities import utilities_tab
|
|
||||||
from library.extract_lora_gui import gradio_extract_lora_tab
|
from library.extract_lora_gui import gradio_extract_lora_tab
|
||||||
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
||||||
from library.merge_lora_gui import gradio_merge_lora_tab
|
from library.merge_lora_gui import gradio_merge_lora_tab
|
||||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||||
|
from library.utilities import utilities_tab
|
||||||
from lora_gui import lora_tab
|
from lora_gui import lora_tab
|
||||||
|
from textual_inversion_gui import ti_tab
|
||||||
|
|
||||||
def UI(**kwargs):
|
def UI(**kwargs):
|
||||||
css = ''
|
css = ''
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
from .common_gui import get_folder_path, add_pre_postfix, find_replace
|
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace
|
||||||
|
|
||||||
|
|
||||||
def caption_images(
|
def caption_images(
|
||||||
@ -17,11 +18,11 @@ def caption_images(
|
|||||||
):
|
):
|
||||||
# Check for images_dir
|
# Check for images_dir
|
||||||
if not images_dir:
|
if not images_dir:
|
||||||
msgbox('Image folder is missing...')
|
show_message_box('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not caption_ext:
|
if not caption_ext:
|
||||||
msgbox('Please provide an extension for the caption files.')
|
show_message_box('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_text:
|
if caption_text:
|
||||||
@ -60,7 +61,7 @@ def caption_images(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if prefix or postfix:
|
if prefix or postfix:
|
||||||
msgbox(
|
show_message_box(
|
||||||
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
|
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import get_folder_path, add_pre_postfix
|
import subprocess
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import get_folder_path, add_pre_postfix
|
||||||
|
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
|
|
||||||
@ -21,16 +22,16 @@ def caption_images(
|
|||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
# if caption_text_input == "":
|
# if caption_text_input == "":
|
||||||
# msgbox("Caption text is missing...")
|
# show_message_box("Caption text is missing...")
|
||||||
# return
|
# return
|
||||||
|
|
||||||
# Check for images_dir_input
|
# Check for images_dir_input
|
||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder is missing...')
|
show_message_box('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_file_ext == '':
|
if caption_file_ext == '':
|
||||||
msgbox('Please provide an extension for the caption files.')
|
show_message_box('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f'Captioning files in {train_data_dir}...')
|
print(f'Captioning files in {train_data_dir}...')
|
||||||
|
@ -1,14 +1,19 @@
|
|||||||
from tkinter import filedialog, Tk
|
|
||||||
from easygui import msgbox
|
|
||||||
import os
|
import os
|
||||||
import gradio as gr
|
|
||||||
import easygui
|
|
||||||
import shutil
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import tkinter as tk
|
||||||
|
from tkinter import filedialog, Tk
|
||||||
|
|
||||||
|
import easygui
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from library.common_utilities import CommonUtilities
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
# define a list of substrings to search for v2 base models
|
# define a list of substrings to search for v2 base models
|
||||||
V2_BASE_MODELS = [
|
V2_BASE_MODELS = [
|
||||||
@ -34,6 +39,41 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
|
|||||||
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def tk_context():
|
||||||
|
root = tk.Tk()
|
||||||
|
root.withdraw()
|
||||||
|
try:
|
||||||
|
yield root
|
||||||
|
finally:
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
|
|
||||||
|
def open_file_dialog(initial_dir, initial_file, file_types="all"):
|
||||||
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
args = ["python", f"{current_directory}/gui_subprocesses.py", "file_dialog"]
|
||||||
|
if initial_dir:
|
||||||
|
args.append(initial_dir)
|
||||||
|
if initial_file:
|
||||||
|
args.append(initial_file)
|
||||||
|
if file_types:
|
||||||
|
args.append(file_types)
|
||||||
|
|
||||||
|
file_path = subprocess.check_output(args).decode("utf-8").strip()
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
|
def show_message_box(message, title=""):
|
||||||
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
args = ["python", f"{current_directory}/gui_subprocesses.py", "msgbox", message]
|
||||||
|
if title:
|
||||||
|
args.append(title)
|
||||||
|
|
||||||
|
subprocess.run(args)
|
||||||
|
|
||||||
|
|
||||||
def check_if_model_exist(output_name, output_dir, save_model_as):
|
def check_if_model_exist(output_name, output_dir, save_model_as):
|
||||||
if save_model_as in ['diffusers', 'diffusers_safetendors']:
|
if save_model_as in ['diffusers', 'diffusers_safetendors']:
|
||||||
ckpt_folder = os.path.join(output_dir, output_name)
|
ckpt_folder = os.path.join(output_dir, output_name)
|
||||||
@ -87,8 +127,8 @@ def update_my_data(my_data):
|
|||||||
|
|
||||||
# Update model save choices due to changes for LoRA and TI training
|
# Update model save choices due to changes for LoRA and TI training
|
||||||
if (
|
if (
|
||||||
(my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
|
(my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
|
||||||
and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
|
and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
|
||||||
):
|
):
|
||||||
message = (
|
message = (
|
||||||
'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
|
'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
|
||||||
@ -102,11 +142,6 @@ def update_my_data(my_data):
|
|||||||
return my_data
|
return my_data
|
||||||
|
|
||||||
|
|
||||||
def get_dir_and_file(file_path):
|
|
||||||
dir_path, file_name = os.path.split(file_path)
|
|
||||||
return (dir_path, file_name)
|
|
||||||
|
|
||||||
|
|
||||||
# def has_ext_files(directory, extension):
|
# def has_ext_files(directory, extension):
|
||||||
# # Iterate through all the files in the directory
|
# # Iterate through all the files in the directory
|
||||||
# for file in os.listdir(directory):
|
# for file in os.listdir(directory):
|
||||||
@ -117,67 +152,38 @@ def get_dir_and_file(file_path):
|
|||||||
# return False
|
# return False
|
||||||
|
|
||||||
|
|
||||||
def get_file_path(
|
def get_file_path(file_path, initial_dir=None, initial_file=None, filedialog_type="lora"):
|
||||||
file_path='', default_extension='.json', extension_name='Config files'
|
file_extension = os.path.splitext(file_path)[-1].lower()
|
||||||
):
|
|
||||||
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
|
|
||||||
current_file_path = file_path
|
|
||||||
# print(f'current file path: {current_file_path}')
|
|
||||||
|
|
||||||
initial_dir, initial_file = get_dir_and_file(file_path)
|
# Find the appropriate filedialog_type based on the file extension
|
||||||
|
for key, extensions in CommonUtilities.file_filters.items():
|
||||||
|
if file_extension in extensions:
|
||||||
|
filedialog_type = key
|
||||||
|
break
|
||||||
|
|
||||||
# Create a hidden Tkinter root window
|
current_file_path = file_path
|
||||||
root = Tk()
|
|
||||||
root.wm_attributes('-topmost', 1)
|
|
||||||
root.withdraw()
|
|
||||||
|
|
||||||
# Show the open file dialog and get the selected file path
|
initial_dir, initial_file = os.path.split(file_path)
|
||||||
file_path = filedialog.askopenfilename(
|
result = open_file_dialog(initial_dir=initial_dir, initial_file=initial_file, file_types=filedialog_type)
|
||||||
filetypes=(
|
file_path, canceled = result[:2]
|
||||||
(extension_name, f'*{default_extension}'),
|
return file_path, canceled
|
||||||
('All files', '*.*'),
|
|
||||||
),
|
|
||||||
defaultextension=default_extension,
|
|
||||||
initialfile=initial_file,
|
|
||||||
initialdir=initial_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Destroy the hidden root window
|
|
||||||
root.destroy()
|
|
||||||
|
|
||||||
# If no file is selected, use the current file path
|
|
||||||
if not file_path:
|
|
||||||
file_path = current_file_path
|
|
||||||
current_file_path = file_path
|
|
||||||
# print(f'current file path: {current_file_path}')
|
|
||||||
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
|
|
||||||
def get_any_file_path(file_path=''):
|
def get_any_file_path(file_path=''):
|
||||||
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
|
current_file_path = file_path
|
||||||
current_file_path = file_path
|
# print(f'current file path: {current_file_path}')
|
||||||
# print(f'current file path: {current_file_path}')
|
|
||||||
|
|
||||||
initial_dir, initial_file = get_dir_and_file(file_path)
|
initial_dir, initial_file = os.path.split(file_path)
|
||||||
|
file_path = open_file_dialog(initial_dir, initial_file, "all")
|
||||||
|
|
||||||
root = Tk()
|
if file_path == '':
|
||||||
root.wm_attributes('-topmost', 1)
|
file_path = current_file_path
|
||||||
root.withdraw()
|
|
||||||
file_path = filedialog.askopenfilename(
|
|
||||||
initialdir=initial_dir,
|
|
||||||
initialfile=initial_file,
|
|
||||||
)
|
|
||||||
root.destroy()
|
|
||||||
|
|
||||||
if file_path == '':
|
|
||||||
file_path = current_file_path
|
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
def remove_doublequote(file_path):
|
def remove_doublequote(file_path):
|
||||||
if file_path != None:
|
if file_path is not None:
|
||||||
file_path = file_path.replace('"', '')
|
file_path = file_path.replace('"', '')
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
@ -196,62 +202,37 @@ def remove_doublequote(file_path):
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
def get_folder_path(folder_path=''):
|
def get_folder_path(folder_path='', filedialog_type="directory"):
|
||||||
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
|
current_folder_path = folder_path
|
||||||
current_folder_path = folder_path
|
|
||||||
|
|
||||||
initial_dir, initial_file = get_dir_and_file(folder_path)
|
initial_dir, initial_file = os.path.split(folder_path)
|
||||||
|
file_path = open_file_dialog(initial_dir, initial_file, filedialog_type)
|
||||||
|
|
||||||
root = Tk()
|
if folder_path == '':
|
||||||
root.wm_attributes('-topmost', 1)
|
folder_path = current_folder_path
|
||||||
root.withdraw()
|
|
||||||
folder_path = filedialog.askdirectory(initialdir=initial_dir)
|
|
||||||
root.destroy()
|
|
||||||
|
|
||||||
if folder_path == '':
|
|
||||||
folder_path = current_folder_path
|
|
||||||
|
|
||||||
return folder_path
|
return folder_path
|
||||||
|
|
||||||
|
|
||||||
def get_saveasfile_path(
|
def get_saveasfile_path(
|
||||||
file_path='', defaultextension='.json', extension_name='Config files'
|
file_path='', filedialog_type="json"
|
||||||
):
|
):
|
||||||
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
|
current_file_path = file_path
|
||||||
current_file_path = file_path
|
|
||||||
# print(f'current file path: {current_file_path}')
|
|
||||||
|
|
||||||
initial_dir, initial_file = get_dir_and_file(file_path)
|
initial_dir, initial_file = os.path.split(file_path)
|
||||||
|
save_file_path = save_file_dialog(initial_dir, initial_file, filedialog_type)
|
||||||
|
|
||||||
root = Tk()
|
if save_file_path is None:
|
||||||
root.wm_attributes('-topmost', 1)
|
file_path = current_file_path
|
||||||
root.withdraw()
|
else:
|
||||||
save_file_path = filedialog.asksaveasfile(
|
print(save_file_path.name)
|
||||||
filetypes=(
|
file_path = save_file_path.name
|
||||||
(f'{extension_name}', f'{defaultextension}'),
|
|
||||||
('All files', '*'),
|
|
||||||
),
|
|
||||||
defaultextension=defaultextension,
|
|
||||||
initialdir=initial_dir,
|
|
||||||
initialfile=initial_file,
|
|
||||||
)
|
|
||||||
root.destroy()
|
|
||||||
|
|
||||||
# print(save_file_path)
|
|
||||||
|
|
||||||
if save_file_path == None:
|
|
||||||
file_path = current_file_path
|
|
||||||
else:
|
|
||||||
print(save_file_path.name)
|
|
||||||
file_path = save_file_path.name
|
|
||||||
|
|
||||||
# print(file_path)
|
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
def get_saveasfilename_path(
|
def get_saveasfilename_path(
|
||||||
file_path='', extensions='*', extension_name='Config files'
|
file_path='', extensions='*', extension_name='Config files'
|
||||||
):
|
):
|
||||||
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
|
if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
@ -280,10 +261,10 @@ def get_saveasfilename_path(
|
|||||||
|
|
||||||
|
|
||||||
def add_pre_postfix(
|
def add_pre_postfix(
|
||||||
folder: str = '',
|
folder: str = '',
|
||||||
prefix: str = '',
|
prefix: str = '',
|
||||||
postfix: str = '',
|
postfix: str = '',
|
||||||
caption_file_ext: str = '.caption',
|
caption_file_ext: str = '.caption',
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add prefix and/or postfix to the content of caption files within a folder.
|
Add prefix and/or postfix to the content of caption files within a folder.
|
||||||
@ -343,10 +324,10 @@ def has_ext_files(folder_path: str, file_extension: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def find_replace(
|
def find_replace(
|
||||||
folder_path: str = '',
|
folder_path: str = '',
|
||||||
caption_file_ext: str = '.caption',
|
caption_file_ext: str = '.caption',
|
||||||
search_text: str = '',
|
search_text: str = '',
|
||||||
replace_text: str = '',
|
replace_text: str = '',
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Find and replace text in caption files within a folder.
|
Find and replace text in caption files within a folder.
|
||||||
@ -360,7 +341,7 @@ def find_replace(
|
|||||||
print('Running caption find/replace')
|
print('Running caption find/replace')
|
||||||
|
|
||||||
if not has_ext_files(folder_path, caption_file_ext):
|
if not has_ext_files(folder_path, caption_file_ext):
|
||||||
msgbox(
|
show_message_box(
|
||||||
f'No files with extension {caption_file_ext} were found in {folder_path}...'
|
f'No files with extension {caption_file_ext} were found in {folder_path}...'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -374,7 +355,7 @@ def find_replace(
|
|||||||
|
|
||||||
for caption_file in caption_files:
|
for caption_file in caption_files:
|
||||||
with open(
|
with open(
|
||||||
os.path.join(folder_path, caption_file), 'r', errors='ignore'
|
os.path.join(folder_path, caption_file), 'r', errors='ignore'
|
||||||
) as f:
|
) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
@ -386,7 +367,7 @@ def find_replace(
|
|||||||
|
|
||||||
def color_aug_changed(color_aug):
|
def color_aug_changed(color_aug):
|
||||||
if color_aug:
|
if color_aug:
|
||||||
msgbox(
|
show_message_box(
|
||||||
'Disabling "Cache latent" because "Color augmentation" has been selected...'
|
'Disabling "Cache latent" because "Color augmentation" has been selected...'
|
||||||
)
|
)
|
||||||
return gr.Checkbox.update(value=False, interactive=False)
|
return gr.Checkbox.update(value=False, interactive=False)
|
||||||
@ -427,7 +408,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
|||||||
|
|
||||||
|
|
||||||
def set_pretrained_model_name_or_path_input(
|
def set_pretrained_model_name_or_path_input(
|
||||||
model_list, pretrained_model_name_or_path, v2, v_parameterization
|
model_list, pretrained_model_name_or_path, v2, v_parameterization
|
||||||
):
|
):
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||||
if str(model_list) in V2_BASE_MODELS:
|
if str(model_list) in V2_BASE_MODELS:
|
||||||
@ -452,9 +433,9 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
|
|
||||||
if model_list == 'custom':
|
if model_list == 'custom':
|
||||||
if (
|
if (
|
||||||
str(pretrained_model_name_or_path) in V1_MODELS
|
str(pretrained_model_name_or_path) in V1_MODELS
|
||||||
or str(pretrained_model_name_or_path) in V2_BASE_MODELS
|
or str(pretrained_model_name_or_path) in V2_BASE_MODELS
|
||||||
or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS
|
or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS
|
||||||
):
|
):
|
||||||
pretrained_model_name_or_path = ''
|
pretrained_model_name_or_path = ''
|
||||||
v2 = False
|
v2 = False
|
||||||
@ -481,12 +462,11 @@ def set_v2_checkbox(model_list, v2, v_parameterization):
|
|||||||
|
|
||||||
|
|
||||||
def set_model_list(
|
def set_model_list(
|
||||||
model_list,
|
model_list,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
|
if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
|
||||||
model_list = 'custom'
|
model_list = 'custom'
|
||||||
else:
|
else:
|
||||||
@ -529,7 +509,7 @@ def gradio_config():
|
|||||||
|
|
||||||
|
|
||||||
def get_pretrained_model_name_or_path_file(
|
def get_pretrained_model_name_or_path_file(
|
||||||
model_list, pretrained_model_name_or_path
|
model_list, pretrained_model_name_or_path
|
||||||
):
|
):
|
||||||
pretrained_model_name_or_path = get_any_file_path(
|
pretrained_model_name_or_path = get_any_file_path(
|
||||||
pretrained_model_name_or_path
|
pretrained_model_name_or_path
|
||||||
@ -537,13 +517,13 @@ def get_pretrained_model_name_or_path_file(
|
|||||||
set_model_list(model_list, pretrained_model_name_or_path)
|
set_model_list(model_list, pretrained_model_name_or_path)
|
||||||
|
|
||||||
|
|
||||||
def gradio_source_model(save_model_as_choices = [
|
def gradio_source_model(save_model_as_choices=[
|
||||||
'same as source model',
|
'same as source model',
|
||||||
'ckpt',
|
'ckpt',
|
||||||
'diffusers',
|
'diffusers',
|
||||||
'diffusers_safetensors',
|
'diffusers_safetensors',
|
||||||
'safetensors',
|
'safetensors',
|
||||||
]):
|
]):
|
||||||
with gr.Tab('Source model'):
|
with gr.Tab('Source model'):
|
||||||
# Define the input elements
|
# Define the input elements
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -648,9 +628,9 @@ def gradio_source_model(save_model_as_choices = [
|
|||||||
|
|
||||||
|
|
||||||
def gradio_training(
|
def gradio_training(
|
||||||
learning_rate_value='1e-6',
|
learning_rate_value='1e-6',
|
||||||
lr_scheduler_value='constant',
|
lr_scheduler_value='constant',
|
||||||
lr_warmup_value='0',
|
lr_warmup_value='0',
|
||||||
):
|
):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_batch_size = gr.Slider(
|
train_batch_size = gr.Slider(
|
||||||
@ -840,7 +820,7 @@ def gradio_advanced_training():
|
|||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1)
|
min_snr_gamma = gr.Slider(label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
bucket_no_upscale = gr.Checkbox(
|
bucket_no_upscale = gr.Checkbox(
|
||||||
label="Don't upscale bucket resolution", value=True
|
label="Don't upscale bucket resolution", value=True
|
24
library/common_utilities.py
Normal file
24
library/common_utilities.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
class CommonUtilities:
|
||||||
|
file_filters = {
|
||||||
|
"all": [("All files", "*.*")],
|
||||||
|
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
|
||||||
|
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
|
||||||
|
"json": [("JSON files", "*.json")],
|
||||||
|
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
|
||||||
|
"directory": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def is_valid_config(self, data):
|
||||||
|
# Check if the data is a dictionary
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Add checks for expected keys and valid values
|
||||||
|
# For example, check if 'use_8bit_adam' is a boolean
|
||||||
|
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Add more checks for other keys as needed
|
||||||
|
|
||||||
|
# If all checks pass, return True
|
||||||
|
return True
|
@ -1,13 +1,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
import random
|
||||||
from dataclasses import (
|
from dataclasses import (
|
||||||
asdict,
|
asdict,
|
||||||
dataclass,
|
dataclass,
|
||||||
)
|
)
|
||||||
import functools
|
|
||||||
import random
|
|
||||||
from textwrap import dedent, indent
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from textwrap import dedent, indent
|
||||||
# from toolz import curry
|
# from toolz import curry
|
||||||
from typing import (
|
from typing import (
|
||||||
List,
|
List,
|
||||||
@ -19,6 +19,7 @@ from typing import (
|
|||||||
|
|
||||||
import toml
|
import toml
|
||||||
import voluptuous
|
import voluptuous
|
||||||
|
from transformers import CLIPTokenizer
|
||||||
from voluptuous import (
|
from voluptuous import (
|
||||||
Any,
|
Any,
|
||||||
ExactSequence,
|
ExactSequence,
|
||||||
@ -27,7 +28,6 @@ from voluptuous import (
|
|||||||
Required,
|
Required,
|
||||||
Schema,
|
Schema,
|
||||||
)
|
)
|
||||||
from transformers import CLIPTokenizer
|
|
||||||
|
|
||||||
from . import train_util
|
from . import train_util
|
||||||
from .train_util import (
|
from .train_util import (
|
@ -1,28 +1,29 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from .common_gui import get_folder_path, get_file_path
|
import subprocess
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import get_folder_path, get_file_path
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
|
|
||||||
|
|
||||||
def convert_model(
|
def convert_model(
|
||||||
source_model_input,
|
source_model_input,
|
||||||
source_model_type,
|
source_model_type,
|
||||||
target_model_folder_input,
|
target_model_folder_input,
|
||||||
target_model_name_input,
|
target_model_name_input,
|
||||||
target_model_type,
|
target_model_type,
|
||||||
target_save_precision_type,
|
target_save_precision_type,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if source_model_type == '':
|
if source_model_type == '':
|
||||||
msgbox('Invalid source model type')
|
show_message_box('Invalid source model type')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
@ -31,14 +32,14 @@ def convert_model(
|
|||||||
elif os.path.isdir(source_model_input):
|
elif os.path.isdir(source_model_input):
|
||||||
print('The provided model is a folder')
|
print('The provided model is a folder')
|
||||||
else:
|
else:
|
||||||
msgbox('The provided source model is neither a file nor a folder')
|
show_message_box('The provided source model is neither a file nor a folder')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if os.path.isdir(target_model_folder_input):
|
if os.path.isdir(target_model_folder_input):
|
||||||
print('The provided model folder exist')
|
print('The provided model folder exist')
|
||||||
else:
|
else:
|
||||||
msgbox('The provided target folder does not exist')
|
show_message_box('The provided target folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"'
|
run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"'
|
||||||
@ -60,8 +61,8 @@ def convert_model(
|
|||||||
run_cmd += f' --{target_save_precision_type}'
|
run_cmd += f' --{target_save_precision_type}'
|
||||||
|
|
||||||
if (
|
if (
|
||||||
target_model_type == 'diffuser'
|
target_model_type == 'diffuser'
|
||||||
or target_model_type == 'diffuser_safetensors'
|
or target_model_type == 'diffuser_safetensors'
|
||||||
):
|
):
|
||||||
run_cmd += f' --reference_model="{source_model_type}"'
|
run_cmd += f' --reference_model="{source_model_type}"'
|
||||||
|
|
||||||
@ -71,8 +72,8 @@ def convert_model(
|
|||||||
run_cmd += f' "{source_model_input}"'
|
run_cmd += f' "{source_model_input}"'
|
||||||
|
|
||||||
if (
|
if (
|
||||||
target_model_type == 'diffuser'
|
target_model_type == 'diffuser'
|
||||||
or target_model_type == 'diffuser_safetensors'
|
or target_model_type == 'diffuser_safetensors'
|
||||||
):
|
):
|
||||||
target_model_path = os.path.join(
|
target_model_path = os.path.join(
|
||||||
target_model_folder_input, target_model_name_input
|
target_model_folder_input, target_model_name_input
|
||||||
@ -94,8 +95,8 @@ def convert_model(
|
|||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not target_model_type == 'diffuser'
|
not target_model_type == 'diffuser'
|
||||||
or target_model_type == 'diffuser_safetensors'
|
or target_model_type == 'diffuser_safetensors'
|
||||||
):
|
):
|
||||||
|
|
||||||
v2_models = [
|
v2_models = [
|
||||||
@ -179,7 +180,7 @@ def gradio_convert_model_tab():
|
|||||||
document_symbol, elem_id='open_folder_small'
|
document_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_source_model_file.click(
|
button_source_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[source_model_input],
|
inputs=[source_model_input],
|
||||||
outputs=source_model_input,
|
outputs=source_model_input,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from easygui import msgbox, boolbox
|
from easygui import boolbox
|
||||||
from .common_gui import get_folder_path
|
|
||||||
|
from .common_gui_functions import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
# def select_folder():
|
# def select_folder():
|
||||||
# # Open a file dialog to select a directory
|
# # Open a file dialog to select a directory
|
||||||
@ -16,14 +19,14 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||||||
|
|
||||||
if not concept_repeats > 0:
|
if not concept_repeats > 0:
|
||||||
# Display an error message if the total number of repeats is not a valid integer
|
# Display an error message if the total number of repeats is not a valid integer
|
||||||
msgbox('Please enter a valid integer for the total number of repeats.')
|
show_message_box('Please enter a valid integer for the total number of repeats.')
|
||||||
return
|
return
|
||||||
|
|
||||||
concept_repeats = int(concept_repeats)
|
concept_repeats = int(concept_repeats)
|
||||||
|
|
||||||
# Check if folder exist
|
# Check if folder exist
|
||||||
if folder == '' or not os.path.isdir(folder):
|
if folder == '' or not os.path.isdir(folder):
|
||||||
msgbox('Please enter a valid folder for balancing.')
|
show_message_box('Please enter a valid folder for balancing.')
|
||||||
return
|
return
|
||||||
|
|
||||||
pattern = re.compile(r'^\d+_.+$')
|
pattern = re.compile(r'^\d+_.+$')
|
||||||
@ -85,7 +88,7 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||||||
f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
|
f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
|
||||||
)
|
)
|
||||||
|
|
||||||
msgbox('Dataset balancing completed...')
|
show_message_box('Dataset balancing completed...')
|
||||||
|
|
||||||
|
|
||||||
def warning(insecure):
|
def warning(insecure):
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import diropenbox, msgbox
|
|
||||||
from .common_gui import get_folder_path
|
|
||||||
import shutil
|
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
def copy_info_to_Folders_tab(training_folder):
|
def copy_info_to_Folders_tab(training_folder):
|
||||||
@ -39,12 +40,12 @@ def dreambooth_folder_preparation(
|
|||||||
|
|
||||||
# Check for instance prompt
|
# Check for instance prompt
|
||||||
if util_instance_prompt_input == '':
|
if util_instance_prompt_input == '':
|
||||||
msgbox('Instance prompt missing...')
|
show_message_box('Instance prompt missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check for class prompt
|
# Check for class prompt
|
||||||
if util_class_prompt_input == '':
|
if util_class_prompt_input == '':
|
||||||
msgbox('Class prompt missing...')
|
show_message_box('Class prompt missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create the training_dir path
|
# Create the training_dir path
|
||||||
|
@ -1,50 +1,49 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import (
|
import subprocess
|
||||||
get_saveasfilename_path,
|
|
||||||
get_any_file_path,
|
import gradio as gr
|
||||||
get_file_path,
|
|
||||||
|
from .common_gui_functions import (
|
||||||
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
|
|
||||||
|
|
||||||
def extract_lora(
|
def extract_lora(
|
||||||
model_tuned,
|
model_tuned,
|
||||||
model_org,
|
model_org,
|
||||||
save_to,
|
save_to,
|
||||||
save_precision,
|
save_precision,
|
||||||
dim,
|
dim,
|
||||||
v2,
|
v2,
|
||||||
conv_dim,
|
conv_dim,
|
||||||
device,
|
device,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if model_tuned == '':
|
if model_tuned == '':
|
||||||
msgbox('Invalid finetuned model file')
|
show_message_box('Invalid finetuned model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if model_org == '':
|
if model_org == '':
|
||||||
msgbox('Invalid base model file')
|
show_message_box('Invalid base model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if not os.path.isfile(model_tuned):
|
if not os.path.isfile(model_tuned):
|
||||||
msgbox('The provided finetuned model is not a file')
|
show_message_box('The provided finetuned model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isfile(model_org):
|
if not os.path.isfile(model_org):
|
||||||
msgbox('The provided base model is not a file')
|
show_message_box('The provided base model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = (
|
run_cmd = (
|
||||||
f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"'
|
f'{PYTHON} "{os.path.join("networks", "extract_lora_from_models.py")}"'
|
||||||
)
|
)
|
||||||
run_cmd += f' --save_precision {save_precision}'
|
run_cmd += f' --save_precision {save_precision}'
|
||||||
run_cmd += f' --save_to "{save_to}"'
|
run_cmd += f' --save_to "{save_to}"'
|
||||||
@ -91,7 +90,7 @@ def gradio_extract_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_model_tuned_file.click(
|
button_model_tuned_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[model_tuned, model_ext, model_ext_name],
|
inputs=[model_tuned, model_ext, model_ext_name],
|
||||||
outputs=model_tuned,
|
outputs=model_tuned,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -106,7 +105,8 @@ def gradio_extract_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_model_org_file.click(
|
button_model_org_file.click(
|
||||||
get_file_path,
|
lambda input1, input2, input3, *args, **kwargs:
|
||||||
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[model_org, model_ext, model_ext_name],
|
inputs=[model_org, model_ext, model_ext_name],
|
||||||
outputs=model_org,
|
outputs=model_org,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -121,7 +121,7 @@ def gradio_extract_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_save_to.click(
|
button_save_to.click(
|
||||||
get_saveasfilename_path,
|
get_saveasfile_path,
|
||||||
inputs=[save_to, lora_ext, lora_ext_name],
|
inputs=[save_to, lora_ext, lora_ext_name],
|
||||||
outputs=save_to,
|
outputs=save_to,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import (
|
import subprocess
|
||||||
get_saveasfilename_path,
|
|
||||||
get_any_file_path,
|
import gradio as gr
|
||||||
get_file_path,
|
|
||||||
|
from .common_gui_functions import (
|
||||||
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -36,20 +35,20 @@ def extract_lycoris_locon(
|
|||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if db_model == '':
|
if db_model == '':
|
||||||
msgbox('Invalid finetuned model file')
|
show_message_box('Invalid finetuned model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if base_model == '':
|
if base_model == '':
|
||||||
msgbox('Invalid base model file')
|
show_message_box('Invalid base model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if not os.path.isfile(db_model):
|
if not os.path.isfile(db_model):
|
||||||
msgbox('The provided finetuned model is not a file')
|
show_message_box('The provided finetuned model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isfile(base_model):
|
if not os.path.isfile(base_model):
|
||||||
msgbox('The provided base model is not a file')
|
show_message_box('The provided base model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
|
run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
|
||||||
@ -137,7 +136,8 @@ def gradio_extract_lycoris_locon_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_db_model_file.click(
|
button_db_model_file.click(
|
||||||
get_file_path,
|
lambda input1, input2, input3, *args, **kwargs:
|
||||||
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[db_model, model_ext, model_ext_name],
|
inputs=[db_model, model_ext, model_ext_name],
|
||||||
outputs=db_model,
|
outputs=db_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -152,7 +152,7 @@ def gradio_extract_lycoris_locon_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_base_model_file.click(
|
button_base_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[base_model, model_ext, model_ext_name],
|
inputs=[base_model, model_ext, model_ext_name],
|
||||||
outputs=base_model,
|
outputs=base_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -167,7 +167,7 @@ def gradio_extract_lycoris_locon_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_output_name.click(
|
button_output_name.click(
|
||||||
get_saveasfilename_path,
|
get_saveasfile_path,
|
||||||
inputs=[output_name, lora_ext, lora_ext_name],
|
inputs=[output_name, lora_ext, lora_ext_name],
|
||||||
outputs=output_name,
|
outputs=output_name,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import get_folder_path, add_pre_postfix
|
import subprocess
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import get_folder_path, add_pre_postfix
|
||||||
|
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
|
|
||||||
@ -19,11 +20,11 @@ def caption_images(
|
|||||||
):
|
):
|
||||||
# Check for images_dir_input
|
# Check for images_dir_input
|
||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder is missing...')
|
show_message_box('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_ext == '':
|
if caption_ext == '':
|
||||||
msgbox('Please provide an extension for the caption files.')
|
show_message_box('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f'GIT captioning files in {train_data_dir}...')
|
print(f'GIT captioning files in {train_data_dir}...')
|
||||||
|
86
library/gui_subprocesses.py
Normal file
86
library/gui_subprocesses.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
import tkinter as tk
|
||||||
|
from tkinter import filedialog, messagebox
|
||||||
|
|
||||||
|
from library.common_gui_functions import tk_context
|
||||||
|
from library.common_utilities import CommonUtilities
|
||||||
|
|
||||||
|
|
||||||
|
class TkGui:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_types = None
|
||||||
|
|
||||||
|
def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
|
||||||
|
with tk_context():
|
||||||
|
self.file_types = file_types
|
||||||
|
if self.file_types in CommonUtilities.file_filters:
|
||||||
|
filters = CommonUtilities.file_filters[self.file_types]
|
||||||
|
else:
|
||||||
|
filters = CommonUtilities.file_filters["all"]
|
||||||
|
|
||||||
|
if self.file_types == "directory":
|
||||||
|
result = filedialog.askdirectory(initialdir=initial_dir)
|
||||||
|
else:
|
||||||
|
result = filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
|
||||||
|
|
||||||
|
# Return a tuple (file_path, canceled)
|
||||||
|
# file_path: the selected file path or an empty string if no file is selected
|
||||||
|
# canceled: True if the user pressed the cancel button, False otherwise
|
||||||
|
return result, result == ""
|
||||||
|
|
||||||
|
def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
|
||||||
|
self.file_types = file_types
|
||||||
|
|
||||||
|
# Use the tk_context function with the 'with' statement
|
||||||
|
with tk_context():
|
||||||
|
if self.file_types in CommonUtilities.file_filters:
|
||||||
|
filters = CommonUtilities.file_filters[self.file_types]
|
||||||
|
else:
|
||||||
|
filters = CommonUtilities.file_filters["all"]
|
||||||
|
|
||||||
|
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file,
|
||||||
|
filetypes=filters, defaultextension=".safetensors")
|
||||||
|
|
||||||
|
return save_file_path
|
||||||
|
|
||||||
|
def show_message_box(_message, _title="Message", _level="info"):
|
||||||
|
with tk_context():
|
||||||
|
message_type = {
|
||||||
|
"warning": messagebox.showwarning,
|
||||||
|
"error": messagebox.showerror,
|
||||||
|
"info": messagebox.showinfo,
|
||||||
|
"question": messagebox.askquestion,
|
||||||
|
"okcancel": messagebox.askokcancel,
|
||||||
|
"retrycancel": messagebox.askretrycancel,
|
||||||
|
"yesno": messagebox.askyesno,
|
||||||
|
"yesnocancel": messagebox.askyesnocancel
|
||||||
|
}
|
||||||
|
|
||||||
|
if _level in message_type:
|
||||||
|
message_type[_level](_title, _message)
|
||||||
|
else:
|
||||||
|
messagebox.showinfo(_title, _message)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
try:
|
||||||
|
mode = sys.argv[1]
|
||||||
|
|
||||||
|
if mode == 'file_dialog':
|
||||||
|
starting_dir = sys.argv[2] if len(sys.argv) > 2 else None
|
||||||
|
starting_file = sys.argv[3] if len(sys.argv) > 3 else None
|
||||||
|
file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4]
|
||||||
|
gui = TkGui()
|
||||||
|
file_path = gui.open_file_dialog(starting_dir, starting_file, file_class)
|
||||||
|
print(file_path) # Make sure to print the result
|
||||||
|
|
||||||
|
elif mode == 'msgbox':
|
||||||
|
message = sys.argv[2]
|
||||||
|
title = sys.argv[3] if len(sys.argv) > 3 else ""
|
||||||
|
gui = TkGui()
|
||||||
|
gui.show_message_box(message, title)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
@ -1,11 +1,10 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import (
|
import subprocess
|
||||||
get_saveasfilename_path,
|
|
||||||
get_any_file_path,
|
import gradio as gr
|
||||||
get_file_path,
|
|
||||||
|
from .common_gui_functions import (
|
||||||
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -25,20 +24,20 @@ def merge_lora(
|
|||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if lora_a_model == '':
|
if lora_a_model == '':
|
||||||
msgbox('Invalid model A file')
|
show_message_box('Invalid model A file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if lora_b_model == '':
|
if lora_b_model == '':
|
||||||
msgbox('Invalid model B file')
|
show_message_box('Invalid model B file')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if not os.path.isfile(lora_a_model):
|
if not os.path.isfile(lora_a_model):
|
||||||
msgbox('The provided model A is not a file')
|
show_message_box('The provided model A is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isfile(lora_b_model):
|
if not os.path.isfile(lora_b_model):
|
||||||
msgbox('The provided model B is not a file')
|
show_message_box('The provided model B is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
ratio_a = ratio
|
ratio_a = ratio
|
||||||
@ -82,7 +81,7 @@ def gradio_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_a_model_file.click(
|
button_lora_a_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_a_model,
|
outputs=lora_a_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -97,7 +96,7 @@ def gradio_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_b_model_file.click(
|
button_lora_b_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_b_model,
|
outputs=lora_b_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -122,7 +121,7 @@ def gradio_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_save_to.click(
|
button_save_to.click(
|
||||||
get_saveasfilename_path,
|
get_saveasfile_path,
|
||||||
inputs=[save_to, lora_ext, lora_ext_name],
|
inputs=[save_to, lora_ext, lora_ext_name],
|
||||||
outputs=save_to,
|
outputs=save_to,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import get_saveasfilename_path, get_file_path
|
import subprocess
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import get_file_path, get_saveasfile_path
|
||||||
|
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -23,24 +24,24 @@ def resize_lora(
|
|||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if model == '':
|
if model == '':
|
||||||
msgbox('Invalid model file')
|
show_message_box('Invalid model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if not os.path.isfile(model):
|
if not os.path.isfile(model):
|
||||||
msgbox('The provided model is not a file')
|
show_message_box('The provided model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if dynamic_method == 'sv_ratio':
|
if dynamic_method == 'sv_ratio':
|
||||||
if float(dynamic_param) < 2:
|
if float(dynamic_param) < 2:
|
||||||
msgbox(
|
show_message_box(
|
||||||
f'Dynamic parameter for {dynamic_method} need to be 2 or greater...'
|
f'Dynamic parameter for {dynamic_method} need to be 2 or greater...'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
|
if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
|
||||||
if float(dynamic_param) < 0 or float(dynamic_param) > 1:
|
if float(dynamic_param) < 0 or float(dynamic_param) > 1:
|
||||||
msgbox(
|
show_message_box(
|
||||||
f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...'
|
f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -95,7 +96,7 @@ def gradio_resize_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_a_model_file.click(
|
button_lora_a_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[model, lora_ext, lora_ext_name],
|
inputs=[model, lora_ext, lora_ext_name],
|
||||||
outputs=model,
|
outputs=model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -134,7 +135,7 @@ def gradio_resize_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_save_to.click(
|
button_save_to.click(
|
||||||
get_saveasfilename_path,
|
get_saveasfile_path,
|
||||||
inputs=[save_to, lora_ext, lora_ext_name],
|
inputs=[save_to, lora_ext, lora_ext_name],
|
||||||
outputs=save_to,
|
outputs=save_to,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
import os
|
import os
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from easygui import msgbox
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import (
|
import subprocess
|
||||||
get_saveasfilename_path,
|
|
||||||
get_any_file_path,
|
import gradio as gr
|
||||||
get_file_path,
|
|
||||||
|
from .common_gui_functions import (
|
||||||
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -28,20 +27,20 @@ def svd_merge_lora(
|
|||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if lora_a_model == '':
|
if lora_a_model == '':
|
||||||
msgbox('Invalid model A file')
|
show_message_box('Invalid model A file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if lora_b_model == '':
|
if lora_b_model == '':
|
||||||
msgbox('Invalid model B file')
|
show_message_box('Invalid model B file')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if not os.path.isfile(lora_a_model):
|
if not os.path.isfile(lora_a_model):
|
||||||
msgbox('The provided model A is not a file')
|
show_message_box('The provided model A is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isfile(lora_b_model):
|
if not os.path.isfile(lora_b_model):
|
||||||
msgbox('The provided model B is not a file')
|
show_message_box('The provided model B is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
ratio_a = ratio
|
ratio_a = ratio
|
||||||
@ -88,7 +87,7 @@ def gradio_svd_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_a_model_file.click(
|
button_lora_a_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_a_model,
|
outputs=lora_a_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -103,7 +102,7 @@ def gradio_svd_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_b_model_file.click(
|
button_lora_b_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_b_model,
|
outputs=lora_b_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -144,7 +143,7 @@ def gradio_svd_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_save_to.click(
|
button_save_to.click(
|
||||||
get_saveasfilename_path,
|
get_saveasfile_path,
|
||||||
inputs=[save_to, lora_ext, lora_ext_name],
|
inputs=[save_to, lora_ext, lora_ext_name],
|
||||||
outputs=save_to,
|
outputs=save_to,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
tensorboard_proc = None # I know... bad but heh
|
tensorboard_proc = None # I know... bad but heh
|
||||||
TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe'
|
TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe'
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ def start_tensorboard(logging_dir):
|
|||||||
|
|
||||||
if not os.listdir(logging_dir):
|
if not os.listdir(logging_dir):
|
||||||
print('Error: log folder is empty')
|
print('Error: log folder is empty')
|
||||||
msgbox(msg='Error: log folder is empty')
|
show_message_box(msg='Error: log folder is empty')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = [f'{TENSORBOARD}', '--logdir', f'{logging_dir}']
|
run_cmd = [f'{TENSORBOARD}', '--logdir', f'{logging_dir}']
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
from .common_gui import (
|
import subprocess
|
||||||
get_saveasfilename_path,
|
|
||||||
get_any_file_path,
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import (
|
||||||
get_file_path,
|
get_file_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,12 +19,12 @@ def verify_lora(
|
|||||||
):
|
):
|
||||||
# verify for caption_text_input
|
# verify for caption_text_input
|
||||||
if lora_model == '':
|
if lora_model == '':
|
||||||
msgbox('Invalid model A file')
|
show_message_box('Invalid model A file')
|
||||||
return
|
return
|
||||||
|
|
||||||
# verify if source model exist
|
# verify if source model exist
|
||||||
if not os.path.isfile(lora_model):
|
if not os.path.isfile(lora_model):
|
||||||
msgbox('The provided model A is not a file')
|
show_message_box('The provided model A is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = [
|
run_cmd = [
|
||||||
@ -69,7 +68,7 @@ def gradio_verify_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_model_file.click(
|
button_lora_model_file.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[lora_model, lora_ext, lora_ext_name],
|
inputs=[lora_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_model,
|
outputs=lora_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import gradio as gr
|
|
||||||
from easygui import msgbox
|
|
||||||
import subprocess
|
|
||||||
from .common_gui import get_folder_path
|
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from .common_gui_functions import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
def replace_underscore_with_space(folder_path, file_extension):
|
def replace_underscore_with_space(folder_path, file_extension):
|
||||||
@ -20,16 +21,16 @@ def caption_images(
|
|||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
# if caption_text_input == "":
|
# if caption_text_input == "":
|
||||||
# msgbox("Caption text is missing...")
|
# show_message_box("Caption text is missing...")
|
||||||
# return
|
# return
|
||||||
|
|
||||||
# Check for images_dir_input
|
# Check for images_dir_input
|
||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder is missing...')
|
show_message_box('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_extension == '':
|
if caption_extension == '':
|
||||||
msgbox('Please provide an extension for the caption files.')
|
show_message_box('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f'Captioning files in {train_data_dir}...')
|
print(f'Captioning files in {train_data_dir}...')
|
||||||
|
46
lora_gui.py
46
lora_gui.py
@ -3,15 +3,16 @@
|
|||||||
# v3: Add new Utilities tab for Dreambooth folder preparation
|
# v3: Add new Utilities tab for Dreambooth folder preparation
|
||||||
# v3.1: Adding captionning of images to utilities
|
# v3.1: Adding captionning of images to utilities
|
||||||
|
|
||||||
import gradio as gr
|
import argparse
|
||||||
import easygui
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import argparse
|
import subprocess
|
||||||
from library.common_gui import (
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
@ -27,24 +28,23 @@ from library.common_gui import (
|
|||||||
run_cmd_training,
|
run_cmd_training,
|
||||||
# set_legacy_8bitadam,
|
# set_legacy_8bitadam,
|
||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist,
|
check_if_model_exist, show_message_box,
|
||||||
)
|
)
|
||||||
|
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
)
|
)
|
||||||
|
from library.merge_lora_gui import gradio_merge_lora_tab
|
||||||
|
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||||
|
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||||
|
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
|
||||||
from library.tensorboard_gui import (
|
from library.tensorboard_gui import (
|
||||||
gradio_tensorboard,
|
gradio_tensorboard,
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
|
||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
from library.merge_lora_gui import gradio_merge_lora_tab
|
|
||||||
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
|
|
||||||
from library.verify_lora_gui import gradio_verify_lora_tab
|
from library.verify_lora_gui import gradio_verify_lora_tab
|
||||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
|
||||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
|
||||||
from easygui import msgbox
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -359,35 +359,35 @@ def train_model(
|
|||||||
print_only_bool = True if print_only.get('label') == 'True' else False
|
print_only_bool = True if print_only.get('label') == 'True' else False
|
||||||
|
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
show_message_box('Source model information is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder path is missing')
|
show_message_box('Image folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.exists(train_data_dir):
|
if not os.path.exists(train_data_dir):
|
||||||
msgbox('Image folder does not exist')
|
show_message_box('Image folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
if reg_data_dir != '':
|
if reg_data_dir != '':
|
||||||
if not os.path.exists(reg_data_dir):
|
if not os.path.exists(reg_data_dir):
|
||||||
msgbox('Regularisation folder does not exist')
|
show_message_box('Regularisation folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
if output_dir == '':
|
if output_dir == '':
|
||||||
msgbox('Output folder path is missing')
|
show_message_box('Output folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if int(bucket_reso_steps) < 1:
|
if int(bucket_reso_steps) < 1:
|
||||||
msgbox('Bucket resolution steps need to be greater than 0')
|
show_message_box('Bucket resolution steps need to be greater than 0')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
if stop_text_encoder_training_pct > 0:
|
if stop_text_encoder_training_pct > 0:
|
||||||
msgbox(
|
show_message_box(
|
||||||
'Output "stop text encoder training" is not yet supported. Ignoring'
|
'Output "stop text encoder training" is not yet supported. Ignoring'
|
||||||
)
|
)
|
||||||
stop_text_encoder_training_pct = 0
|
stop_text_encoder_training_pct = 0
|
||||||
@ -402,7 +402,7 @@ def train_model(
|
|||||||
unet_lr = 0
|
unet_lr = 0
|
||||||
|
|
||||||
# if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
|
# if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
|
||||||
# msgbox(
|
# show_message_box(
|
||||||
# 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
|
# 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
|
||||||
# )
|
# )
|
||||||
# return
|
# return
|
||||||
@ -540,7 +540,7 @@ def train_model(
|
|||||||
run_cmd += f' --network_train_unet_only'
|
run_cmd += f' --network_train_unet_only'
|
||||||
else:
|
else:
|
||||||
if float(text_encoder_lr) == 0:
|
if float(text_encoder_lr) == 0:
|
||||||
msgbox('Please input learning rate values.')
|
show_message_box('Please input learning rate values.')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd += f' --network_dim={network_dim}'
|
run_cmd += f' --network_dim={network_dim}'
|
||||||
@ -1031,14 +1031,14 @@ def lora_tab(
|
|||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_configuration,
|
lambda *args, **kwargs: open_configuration(),
|
||||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list + [LoCon_row],
|
outputs=[config_file_name] + settings_list + [LoCon_row],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_load_config.click(
|
button_load_config.click(
|
||||||
open_configuration,
|
lambda *args, **kwargs: open_configuration(),
|
||||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list + [LoCon_row],
|
outputs=[config_file_name] + settings_list + [LoCon_row],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -3,14 +3,16 @@
|
|||||||
# v3: Add new Utilities tab for Dreambooth folder preparation
|
# v3: Add new Utilities tab for Dreambooth folder preparation
|
||||||
# v3.1: Adding captionning of images to utilities
|
# v3.1: Adding captionning of images to utilities
|
||||||
|
|
||||||
import gradio as gr
|
import argparse
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import argparse
|
import subprocess
|
||||||
from library.common_gui import (
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
@ -28,17 +30,16 @@ from library.common_gui import (
|
|||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist,
|
check_if_model_exist,
|
||||||
)
|
)
|
||||||
|
from library.dreambooth_folder_creation_gui import (
|
||||||
|
gradio_dreambooth_folder_creation_tab,
|
||||||
|
)
|
||||||
|
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||||
from library.tensorboard_gui import (
|
from library.tensorboard_gui import (
|
||||||
gradio_tensorboard,
|
gradio_tensorboard,
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
|
||||||
gradio_dreambooth_folder_creation_tab,
|
|
||||||
)
|
|
||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
|
||||||
from easygui import msgbox
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -241,7 +242,7 @@ def open_configuration(
|
|||||||
if ask_for_file:
|
if ask_for_file:
|
||||||
file_path = get_file_path(file_path)
|
file_path = get_file_path(file_path)
|
||||||
|
|
||||||
if not file_path == '' and not file_path == None:
|
if not file_path == '' and file_path is not None:
|
||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data = json.load(f)
|
my_data = json.load(f)
|
||||||
@ -329,32 +330,32 @@ def train_model(
|
|||||||
min_snr_gamma,
|
min_snr_gamma,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
show_message_box('Source model information is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder path is missing')
|
show_message_box('Image folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.exists(train_data_dir):
|
if not os.path.exists(train_data_dir):
|
||||||
msgbox('Image folder does not exist')
|
show_message_box('Image folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
if reg_data_dir != '':
|
if reg_data_dir != '':
|
||||||
if not os.path.exists(reg_data_dir):
|
if not os.path.exists(reg_data_dir):
|
||||||
msgbox('Regularisation folder does not exist')
|
show_message_box('Regularisation folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
if output_dir == '':
|
if output_dir == '':
|
||||||
msgbox('Output folder path is missing')
|
show_message_box('Output folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if token_string == '':
|
if token_string == '':
|
||||||
msgbox('Token string is missing')
|
show_message_box('Token string is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if init_word == '':
|
if init_word == '':
|
||||||
msgbox('Init word is missing')
|
show_message_box('Init word is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
@ -672,7 +673,7 @@ def ti_tab(
|
|||||||
)
|
)
|
||||||
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
||||||
weights_file_input.click(
|
weights_file_input.click(
|
||||||
get_file_path,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
outputs=weights,
|
outputs=weights,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
@ -898,14 +899,14 @@ def ti_tab(
|
|||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_configuration,
|
lambda *args, **kwargs: open_configuration(),
|
||||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_load_config.click(
|
button_load_config.click(
|
||||||
open_configuration,
|
lambda *args, **kwargs: open_configuration(),
|
||||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
13
train_db.py
13
train_db.py
@ -1,28 +1,25 @@
|
|||||||
# DreamBooth training
|
# DreamBooth training
|
||||||
# XXX dropped option: fine_tune
|
# XXX dropped option: fine_tune
|
||||||
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,24 +1,21 @@
|
|||||||
import importlib
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
|
Loading…
Reference in New Issue
Block a user