All sorts of broken, but I need to commit this for now so I don't lose it. WIP: Using some OOP to reduce imports and centralize some code. No need to remake Tk Windows everywhere for example. Renamed common_gui to common_gui_functions.py to make some of the new code separation more obvious.
This commit is contained in:
parent
e5b83df675
commit
eef5becab8
@ -9,10 +9,11 @@ import math
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from library.common_gui import (
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
@ -28,9 +29,9 @@ from library.common_gui import (
|
|||||||
gradio_source_model,
|
gradio_source_model,
|
||||||
# set_legacy_8bitadam,
|
# set_legacy_8bitadam,
|
||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
|
check_if_model_exist, show_message_box,
|
||||||
)
|
)
|
||||||
from library.common_utilities import is_valid_config
|
from library.common_utilities import CommonUtilities
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
)
|
)
|
||||||
@ -230,12 +231,12 @@ def open_configuration(
|
|||||||
|
|
||||||
if ask_for_file:
|
if ask_for_file:
|
||||||
print(f"File path: {file_path}")
|
print(f"File path: {file_path}")
|
||||||
file_path = get_file_path_gradio_wrapper(file_path)
|
file_path = get_file_path(file_path, filedialog_type="json")
|
||||||
|
|
||||||
if not file_path == '' and file_path is not None:
|
if not file_path == '' and file_path is not None:
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data = json.load(f)
|
my_data = json.load(f)
|
||||||
if is_valid_config(my_data):
|
if CommonUtilities.is_valid_config(my_data):
|
||||||
print('Loading config...')
|
print('Loading config...')
|
||||||
my_data = update_my_data(my_data)
|
my_data = update_my_data(my_data)
|
||||||
else:
|
else:
|
||||||
@ -838,14 +839,14 @@ def dreambooth_tab(
|
|||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
lambda *args, **kwargs: open_configuration(*args),
|
lambda *_args, **kwargs: open_configuration(*_args, **kwargs),
|
||||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_load_config.click(
|
button_load_config.click(
|
||||||
lambda *args, **kwargs: open_configuration(*args),
|
lambda *args, **kwargs: open_configuration(*args, **kwargs),
|
||||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
outputs=[config_file_name] + settings_list,
|
outputs=[config_file_name] + settings_list,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
10
fine_tune.py
10
fine_tune.py
@ -5,22 +5,20 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,17 +1,18 @@
|
|||||||
import gradio as gr
|
import argparse
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import argparse
|
import subprocess
|
||||||
from library.common_gui import (
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
save_inference_file,
|
save_inference_file,
|
||||||
gradio_advanced_training,
|
gradio_advanced_training,
|
||||||
run_cmd_advanced_training,
|
|
||||||
gradio_training,
|
gradio_training,
|
||||||
run_cmd_advanced_training,
|
run_cmd_advanced_training,
|
||||||
gradio_config,
|
gradio_config,
|
||||||
@ -20,15 +21,15 @@ from library.common_gui import (
|
|||||||
run_cmd_training,
|
run_cmd_training,
|
||||||
# set_legacy_8bitadam,
|
# set_legacy_8bitadam,
|
||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist, get_file_path_gradio_wrapper,
|
check_if_model_exist,
|
||||||
)
|
)
|
||||||
|
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||||
from library.tensorboard_gui import (
|
from library.tensorboard_gui import (
|
||||||
gradio_tensorboard,
|
gradio_tensorboard,
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -231,7 +232,7 @@ def open_configuration(
|
|||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
if ask_for_file:
|
if ask_for_file:
|
||||||
file_path = get_file_path_gradio_wrapper(file_path)
|
file_path = get_file_path(file_path)
|
||||||
|
|
||||||
if not file_path == '' and file_path is not None:
|
if not file_path == '' and file_path is not None:
|
||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
|
13
kohya_gui.py
13
kohya_gui.py
@ -1,17 +1,18 @@
|
|||||||
import gradio as gr
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from dreambooth_gui import dreambooth_tab
|
from dreambooth_gui import dreambooth_tab
|
||||||
from finetune_gui import finetune_tab
|
from finetune_gui import finetune_tab
|
||||||
from textual_inversion_gui import ti_tab
|
|
||||||
from library.utilities import utilities_tab
|
|
||||||
from library.extract_lora_gui import gradio_extract_lora_tab
|
from library.extract_lora_gui import gradio_extract_lora_tab
|
||||||
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
||||||
from library.merge_lora_gui import gradio_merge_lora_tab
|
from library.merge_lora_gui import gradio_merge_lora_tab
|
||||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||||
|
from library.utilities import utilities_tab
|
||||||
from lora_gui import lora_tab
|
from lora_gui import lora_tab
|
||||||
|
from textual_inversion_gui import ti_tab
|
||||||
|
|
||||||
def UI(**kwargs):
|
def UI(**kwargs):
|
||||||
css = ''
|
css = ''
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import get_folder_path, add_pre_postfix, find_replace
|
from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace
|
||||||
|
|
||||||
|
|
||||||
def caption_images(
|
def caption_images(
|
||||||
|
@ -3,7 +3,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import get_folder_path, add_pre_postfix
|
from .common_gui_functions import get_folder_path, add_pre_postfix
|
||||||
|
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
|
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import tkinter as tk
|
||||||
from tkinter import filedialog, Tk
|
from tkinter import filedialog, Tk
|
||||||
|
|
||||||
import easygui
|
import easygui
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from library.gui_subprocesses import save_file_dialog
|
from library.common_utilities import CommonUtilities
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -37,6 +39,16 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
|
|||||||
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def tk_context():
|
||||||
|
root = tk.Tk()
|
||||||
|
root.withdraw()
|
||||||
|
try:
|
||||||
|
yield root
|
||||||
|
finally:
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
|
|
||||||
def open_file_dialog(initial_dir, initial_file, file_types="all"):
|
def open_file_dialog(initial_dir, initial_file, file_types="all"):
|
||||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
@ -89,6 +101,7 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def update_my_data(my_data):
|
def update_my_data(my_data):
|
||||||
# Update the optimizer based on the use_8bit_adam flag
|
# Update the optimizer based on the use_8bit_adam flag
|
||||||
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
||||||
@ -138,28 +151,37 @@ def update_my_data(my_data):
|
|||||||
# # If no extension files were found, return False
|
# # If no extension files were found, return False
|
||||||
# return False
|
# return False
|
||||||
|
|
||||||
def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
|
# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
|
||||||
|
# file_extension = os.path.splitext(file_path)[-1].lower()
|
||||||
|
#
|
||||||
|
# filetype_filters = {
|
||||||
|
# 'db': ['.db'],
|
||||||
|
# 'json': ['.json'],
|
||||||
|
# 'lora': ['.pt', '.ckpt', '.safetensors'],
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# # Find the appropriate filedialog_type based on the file extension
|
||||||
|
# filedialog_type = 'all'
|
||||||
|
# for key, extensions in filetype_filters.items():
|
||||||
|
# if file_extension in extensions:
|
||||||
|
# filedialog_type = key
|
||||||
|
# break
|
||||||
|
#
|
||||||
|
# return get_file_path(file_path, filedialog_type)
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_path(file_path='', filedialog_type="lora"):
|
||||||
file_extension = os.path.splitext(file_path)[-1].lower()
|
file_extension = os.path.splitext(file_path)[-1].lower()
|
||||||
|
|
||||||
filetype_filters = {
|
|
||||||
'db': ['.db'],
|
|
||||||
'json': ['.json'],
|
|
||||||
'lora': ['.pt', '.ckpt', '.safetensors'],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Find the appropriate filedialog_type based on the file extension
|
# Find the appropriate filedialog_type based on the file extension
|
||||||
filedialog_type = 'all'
|
for key, extensions in CommonUtilities.file_filters.items():
|
||||||
for key, extensions in filetype_filters.items():
|
|
||||||
if file_extension in extensions:
|
if file_extension in extensions:
|
||||||
filedialog_type = key
|
filedialog_type = key
|
||||||
break
|
break
|
||||||
|
|
||||||
return get_file_path(file_path, filedialog_type)
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_path(file_path='', filedialog_type="lora"):
|
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
|
|
||||||
|
print(f"File type: {filedialog_type}")
|
||||||
initial_dir, initial_file = os.path.split(file_path)
|
initial_dir, initial_file = os.path.split(file_path)
|
||||||
file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type)
|
file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type)
|
||||||
|
|
||||||
@ -171,7 +193,6 @@ def get_file_path(file_path='', filedialog_type="lora"):
|
|||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_any_file_path(file_path=''):
|
def get_any_file_path(file_path=''):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
# print(f'current file path: {current_file_path}')
|
# print(f'current file path: {current_file_path}')
|
||||||
@ -823,7 +844,7 @@ def gradio_advanced_training():
|
|||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1)
|
min_snr_gamma = gr.Slider(label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
bucket_no_upscale = gr.Checkbox(
|
bucket_no_upscale = gr.Checkbox(
|
||||||
label="Don't upscale bucket resolution", value=True
|
label="Don't upscale bucket resolution", value=True
|
@ -1,14 +1,24 @@
|
|||||||
def is_valid_config(data):
|
class CommonUtilities:
|
||||||
# Check if the data is a dictionary
|
file_filters = {
|
||||||
if not isinstance(data, dict):
|
"all": [("All files", "*.*")],
|
||||||
return False
|
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
|
||||||
|
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
|
||||||
|
"json": [("JSON files", "*.json")],
|
||||||
|
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
|
||||||
|
"directory": [],
|
||||||
|
}
|
||||||
|
|
||||||
# Add checks for expected keys and valid values
|
def is_valid_config(self, data):
|
||||||
# For example, check if 'use_8bit_adam' is a boolean
|
# Check if the data is a dictionary
|
||||||
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
|
if not isinstance(data, dict):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Add more checks for other keys as needed
|
# Add checks for expected keys and valid values
|
||||||
|
# For example, check if 'use_8bit_adam' is a boolean
|
||||||
|
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
|
||||||
|
return False
|
||||||
|
|
||||||
# If all checks pass, return True
|
# Add more checks for other keys as needed
|
||||||
return True
|
|
||||||
|
# If all checks pass, return True
|
||||||
|
return True
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
import random
|
||||||
from dataclasses import (
|
from dataclasses import (
|
||||||
asdict,
|
asdict,
|
||||||
dataclass,
|
dataclass,
|
||||||
)
|
)
|
||||||
import functools
|
|
||||||
import random
|
|
||||||
from textwrap import dedent, indent
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from textwrap import dedent, indent
|
||||||
# from toolz import curry
|
# from toolz import curry
|
||||||
from typing import (
|
from typing import (
|
||||||
List,
|
List,
|
||||||
@ -19,6 +19,7 @@ from typing import (
|
|||||||
|
|
||||||
import toml
|
import toml
|
||||||
import voluptuous
|
import voluptuous
|
||||||
|
from transformers import CLIPTokenizer
|
||||||
from voluptuous import (
|
from voluptuous import (
|
||||||
Any,
|
Any,
|
||||||
ExactSequence,
|
ExactSequence,
|
||||||
@ -27,7 +28,6 @@ from voluptuous import (
|
|||||||
Required,
|
Required,
|
||||||
Schema,
|
Schema,
|
||||||
)
|
)
|
||||||
from transformers import CLIPTokenizer
|
|
||||||
|
|
||||||
from . import train_util
|
from . import train_util
|
||||||
from .train_util import (
|
from .train_util import (
|
@ -4,7 +4,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper
|
from .common_gui_functions import get_folder_path, get_file_path
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -180,8 +180,7 @@ def gradio_convert_model_tab():
|
|||||||
document_symbol, elem_id='open_folder_small'
|
document_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_source_model_file.click(
|
button_source_model_file.click(
|
||||||
lambda input1, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)),
|
|
||||||
inputs=[source_model_input],
|
inputs=[source_model_input],
|
||||||
outputs=source_model_input,
|
outputs=source_model_input,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -4,7 +4,7 @@ import re
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from easygui import boolbox
|
from easygui import boolbox
|
||||||
|
|
||||||
from .common_gui import get_folder_path
|
from .common_gui_functions import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
# def select_folder():
|
# def select_folder():
|
||||||
|
@ -3,7 +3,7 @@ import shutil
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import get_folder_path
|
from .common_gui_functions import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
def copy_info_to_Folders_tab(training_folder):
|
def copy_info_to_Folders_tab(training_folder):
|
||||||
|
@ -3,8 +3,8 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import (
|
from .common_gui_functions import (
|
||||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -90,8 +90,7 @@ def gradio_extract_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_model_tuned_file.click(
|
button_model_tuned_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[model_tuned, model_ext, model_ext_name],
|
inputs=[model_tuned, model_ext, model_ext_name],
|
||||||
outputs=model_tuned,
|
outputs=model_tuned,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -107,7 +106,7 @@ def gradio_extract_lora_tab():
|
|||||||
)
|
)
|
||||||
button_model_org_file.click(
|
button_model_org_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda input1, input2, input3, *args, **kwargs:
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[model_org, model_ext, model_ext_name],
|
inputs=[model_org, model_ext, model_ext_name],
|
||||||
outputs=model_org,
|
outputs=model_org,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -3,8 +3,8 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import (
|
from .common_gui_functions import (
|
||||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -137,7 +137,7 @@ def gradio_extract_lycoris_locon_tab():
|
|||||||
)
|
)
|
||||||
button_db_model_file.click(
|
button_db_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda input1, input2, input3, *args, **kwargs:
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
inputs=[db_model, model_ext, model_ext_name],
|
inputs=[db_model, model_ext, model_ext_name],
|
||||||
outputs=db_model,
|
outputs=db_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -152,8 +152,7 @@ def gradio_extract_lycoris_locon_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_base_model_file.click(
|
button_base_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[base_model, model_ext, model_ext_name],
|
inputs=[base_model, model_ext, model_ext_name],
|
||||||
outputs=base_model,
|
outputs=base_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -3,7 +3,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import get_folder_path, add_pre_postfix
|
from .common_gui_functions import get_folder_path, add_pre_postfix
|
||||||
|
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
|
|
||||||
|
@ -1,84 +1,82 @@
|
|||||||
|
import os
|
||||||
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
import tkinter as tk
|
import tkinter as tk
|
||||||
from tkinter import filedialog, messagebox
|
from tkinter import filedialog, messagebox
|
||||||
|
|
||||||
|
from library.common_gui_functions import tk_context
|
||||||
def open_file_dialog(initial_dir=None, initial_file=None, file_types="all"):
|
from library.common_utilities import CommonUtilities
|
||||||
file_type_filters = {
|
|
||||||
"all": [("All files", "*.*")],
|
|
||||||
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")],
|
|
||||||
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")],
|
|
||||||
"json": [("JSON files", "*.json")],
|
|
||||||
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
|
|
||||||
"directory": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
if file_types in file_type_filters:
|
|
||||||
filters = file_type_filters[file_types]
|
|
||||||
else:
|
|
||||||
filters = file_type_filters["all"]
|
|
||||||
|
|
||||||
if file_types == "directory":
|
|
||||||
return filedialog.askdirectory(initialdir=initial_dir)
|
|
||||||
else:
|
|
||||||
return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
|
|
||||||
|
|
||||||
|
|
||||||
def save_file_dialog(initial_dir, initial_file, files_type="all"):
|
class TkGui:
|
||||||
root = tk.Tk()
|
def __init__(self):
|
||||||
root.withdraw()
|
self.file_types = None
|
||||||
|
|
||||||
filetypes_switch = {
|
def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
|
||||||
"all": [("All files", "*.*")],
|
print(f"File types: {self.file_types}")
|
||||||
"video": [("Video files", "*.mp4;*.avi;*.mkv;*.webm;*.flv;*.mov;*.wmv")],
|
with tk_context():
|
||||||
"images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff;*.ico")],
|
self.file_types = file_types
|
||||||
"json": [("JSON files", "*.json")],
|
if self.file_types in CommonUtilities.file_filters:
|
||||||
"lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")],
|
filters = CommonUtilities.file_filters[self.file_types]
|
||||||
}
|
else:
|
||||||
|
filters = CommonUtilities.file_filters["all"]
|
||||||
|
|
||||||
filetypes = filetypes_switch.get(files_type, filetypes_switch["all"])
|
if self.file_types == "directory":
|
||||||
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filetypes,
|
return filedialog.askdirectory(initialdir=initial_dir)
|
||||||
defaultextension=filetypes)
|
else:
|
||||||
|
return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
|
||||||
|
|
||||||
root.destroy()
|
def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
|
||||||
|
self.file_types = file_types
|
||||||
|
|
||||||
return save_file_path
|
# Use the tk_context function with the 'with' statement
|
||||||
|
with tk_context():
|
||||||
|
if self.file_types in CommonUtilities.file_filters:
|
||||||
|
filters = CommonUtilities.file_filters[self.file_types]
|
||||||
|
else:
|
||||||
|
filters = CommonUtilities.file_filters["all"]
|
||||||
|
|
||||||
|
save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file,
|
||||||
|
filetypes=filters, defaultextension=".safetensors")
|
||||||
|
|
||||||
def show_message_box(_message, _title="Message", _level="info"):
|
return save_file_path
|
||||||
root = tk.Tk()
|
|
||||||
root.withdraw()
|
|
||||||
|
|
||||||
message_type = {
|
def show_message_box(_message, _title="Message", _level="info"):
|
||||||
"warning": messagebox.showwarning,
|
with tk_context():
|
||||||
"error": messagebox.showerror,
|
message_type = {
|
||||||
"info": messagebox.showinfo,
|
"warning": messagebox.showwarning,
|
||||||
"question": messagebox.askquestion,
|
"error": messagebox.showerror,
|
||||||
"okcancel": messagebox.askokcancel,
|
"info": messagebox.showinfo,
|
||||||
"retrycancel": messagebox.askretrycancel,
|
"question": messagebox.askquestion,
|
||||||
"yesno": messagebox.askyesno,
|
"okcancel": messagebox.askokcancel,
|
||||||
"yesnocancel": messagebox.askyesnocancel
|
"retrycancel": messagebox.askretrycancel,
|
||||||
}
|
"yesno": messagebox.askyesno,
|
||||||
|
"yesnocancel": messagebox.askyesnocancel
|
||||||
|
}
|
||||||
|
|
||||||
if _level in message_type:
|
if _level in message_type:
|
||||||
message_type[_level](_title, _message)
|
message_type[_level](_title, _message)
|
||||||
else:
|
else:
|
||||||
messagebox.showinfo(_title, _message)
|
messagebox.showinfo(_title, _message)
|
||||||
|
|
||||||
root.destroy()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
mode = sys.argv[1]
|
try:
|
||||||
|
mode = sys.argv[1]
|
||||||
|
|
||||||
if mode == 'file_dialog':
|
if mode == 'file_dialog':
|
||||||
starting_dir = sys.argv[2] if len(sys.argv) > 2 else None
|
starting_dir = sys.argv[2] if len(sys.argv) > 2 else None
|
||||||
starting_file = sys.argv[3] if len(sys.argv) > 3 else None
|
starting_file = sys.argv[3] if len(sys.argv) > 3 else None
|
||||||
file_class = sys.argv[2] if len(sys.argv) > 2 else None
|
file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4]
|
||||||
file_path = open_file_dialog(starting_dir, starting_file, file_class)
|
gui = TkGui()
|
||||||
print(file_path)
|
file_path = gui.open_file_dialog(starting_dir, starting_file, file_class)
|
||||||
|
print(file_path) # Make sure to print the result
|
||||||
|
|
||||||
elif mode == 'msgbox':
|
elif mode == 'msgbox':
|
||||||
message = sys.argv[2]
|
message = sys.argv[2]
|
||||||
title = sys.argv[3] if len(sys.argv) > 3 else ""
|
title = sys.argv[3] if len(sys.argv) > 3 else ""
|
||||||
show_message_box(message, title)
|
gui = TkGui()
|
||||||
|
gui.show_message_box(message, title)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
@ -3,8 +3,8 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import (
|
from .common_gui_functions import (
|
||||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -81,8 +81,7 @@ def gradio_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_a_model_file.click(
|
button_lora_a_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_a_model,
|
outputs=lora_a_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -97,8 +96,7 @@ def gradio_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_b_model_file.click(
|
button_lora_b_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_b_model,
|
outputs=lora_b_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -3,7 +3,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper
|
from .common_gui_functions import get_file_path, get_saveasfile_path
|
||||||
|
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -96,8 +96,7 @@ def gradio_resize_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_a_model_file.click(
|
button_lora_a_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[model, lora_ext, lora_ext_name],
|
inputs=[model, lora_ext, lora_ext_name],
|
||||||
outputs=model,
|
outputs=model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -3,8 +3,8 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import (
|
from .common_gui_functions import (
|
||||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
get_file_path, get_saveasfile_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -87,8 +87,7 @@ def gradio_svd_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_a_model_file.click(
|
button_lora_a_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_a_model,
|
outputs=lora_a_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
@ -103,8 +102,7 @@ def gradio_svd_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_b_model_file.click(
|
button_lora_b_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_b_model,
|
outputs=lora_b_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -3,8 +3,8 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import (
|
from .common_gui_functions import (
|
||||||
get_file_path, get_file_path_gradio_wrapper,
|
get_file_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||||
@ -68,8 +68,7 @@ def gradio_verify_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_lora_model_file.click(
|
button_lora_model_file.click(
|
||||||
lambda input1, input2, input3, *args, **kwargs:
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
|
||||||
inputs=[lora_model, lora_ext, lora_ext_name],
|
inputs=[lora_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_model,
|
outputs=lora_model,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
|
@ -3,7 +3,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from .common_gui import get_folder_path
|
from .common_gui_functions import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
def replace_underscore_with_space(folder_path, file_extension):
|
def replace_underscore_with_space(folder_path, file_extension):
|
||||||
|
@ -12,7 +12,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from library.common_gui import (
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
@ -28,7 +28,7 @@ from library.common_gui import (
|
|||||||
run_cmd_training,
|
run_cmd_training,
|
||||||
# set_legacy_8bitadam,
|
# set_legacy_8bitadam,
|
||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
|
check_if_model_exist, show_message_box,
|
||||||
)
|
)
|
||||||
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
@ -254,7 +254,7 @@ def open_configuration(
|
|||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
if ask_for_file:
|
if ask_for_file:
|
||||||
file_path = get_file_path_gradio_wrapper(file_path)
|
file_path = get_file_path(file_path)
|
||||||
|
|
||||||
if not file_path == '' and not file_path == None:
|
if not file_path == '' and not file_path == None:
|
||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
|
@ -27,6 +27,6 @@ huggingface-hub==0.12.0; sys_platform != 'darwin'
|
|||||||
huggingface-hub==0.13.0; sys_platform == 'darwin'
|
huggingface-hub==0.13.0; sys_platform == 'darwin'
|
||||||
tensorflow==2.10.1; sys_platform != 'darwin'
|
tensorflow==2.10.1; sys_platform != 'darwin'
|
||||||
# For locon support
|
# For locon support
|
||||||
lycoris_lora==0.1.2
|
lycoris_lora==0.1.4
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
.
|
.
|
@ -12,7 +12,7 @@ import subprocess
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from library.common_gui import (
|
from library.common_gui_functions import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
@ -28,7 +28,7 @@ from library.common_gui import (
|
|||||||
gradio_source_model,
|
gradio_source_model,
|
||||||
# set_legacy_8bitadam,
|
# set_legacy_8bitadam,
|
||||||
update_my_data,
|
update_my_data,
|
||||||
check_if_model_exist, get_file_path_gradio_wrapper,
|
check_if_model_exist,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -240,9 +240,9 @@ def open_configuration(
|
|||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
if ask_for_file:
|
if ask_for_file:
|
||||||
file_path = get_file_path_gradio_wrapper(file_path)
|
file_path = get_file_path(file_path)
|
||||||
|
|
||||||
if not file_path == '' and not file_path == None:
|
if not file_path == '' and file_path is not None:
|
||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data = json.load(f)
|
my_data = json.load(f)
|
||||||
@ -673,7 +673,7 @@ def ti_tab(
|
|||||||
)
|
)
|
||||||
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
||||||
weights_file_input.click(
|
weights_file_input.click(
|
||||||
lambda *args, **kwargs: get_file_path_gradio_wrapper,
|
lambda *args, **kwargs: get_file_path(*args),
|
||||||
outputs=weights,
|
outputs=weights,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
@ -1,29 +1,27 @@
|
|||||||
# DreamBooth training
|
# DreamBooth training
|
||||||
# XXX dropped option: fine_tune
|
# XXX dropped option: fine_tune
|
||||||
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
from library.custom_train_functions import apply_snr_weight
|
||||||
from library.custom_train_functions import apply_snr_weight
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
|
13
train_db.py
13
train_db.py
@ -1,28 +1,25 @@
|
|||||||
# DreamBooth training
|
# DreamBooth training
|
||||||
# XXX dropped option: fine_tune
|
# XXX dropped option: fine_tune
|
||||||
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,31 +1,30 @@
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
import importlib
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.train_util import (
|
from library.config_ml_util import (
|
||||||
DreamBoothDataset,
|
|
||||||
)
|
|
||||||
import library.config_util as config_util
|
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
from library.custom_train_functions import apply_snr_weight
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.train_util import (
|
||||||
|
DreamBoothDataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
|
@ -1,31 +1,30 @@
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
import importlib
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.train_util import (
|
from library.config_ml_util import (
|
||||||
DreamBoothDataset,
|
|
||||||
)
|
|
||||||
import library.config_util as config_util
|
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
from library.custom_train_functions import apply_snr_weight
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.train_util import (
|
||||||
|
DreamBoothDataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
|
@ -1,24 +1,21 @@
|
|||||||
import importlib
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
|
@ -1,24 +1,21 @@
|
|||||||
import importlib
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import library.config_ml_util as config_util
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
from library.config_ml_util import (
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
|
Loading…
Reference in New Issue
Block a user