979 lines
33 KiB
Python
979 lines
33 KiB
Python
from tkinter import filedialog, Tk
|
|
from easygui import msgbox
|
|
import os
|
|
import gradio as gr
|
|
import easygui
|
|
import shutil
|
|
|
|
# Glyphs shown on the small icon buttons.
folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'  # 📄

# Preset model ids for SD 2.x without v-parameterization; the quick-pick
# handlers below set v2=True, v_parameterization=False for these.
V2_BASE_MODELS = [
    'stabilityai/stable-diffusion-2-1-base',
    'stabilityai/stable-diffusion-2-base',
]

# Preset model ids for SD 2.x with v-parameterization
# (v2=True, v_parameterization=True).
V_PARAMETERIZATION_MODELS = [
    'stabilityai/stable-diffusion-2-1',
    'stabilityai/stable-diffusion-2',
]

# Preset SD 1.x model ids (v2=False, v_parameterization=False).
V1_MODELS = [
    'CompVis/stable-diffusion-v1-4',
    'runwayml/stable-diffusion-v1-5',
]

# Every preset id; any other model path is treated as 'custom'.
ALL_PRESET_MODELS = [*V2_BASE_MODELS, *V_PARAMETERIZATION_MODELS, *V1_MODELS]

# When any of these environment variables is set, the Tk file/folder dialogs
# are skipped and the current path is returned unchanged.
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_ENVIRONMENT']
|
|
|
|
|
|
def check_if_model_exist(output_name, output_dir, save_model_as):
    """Ask the user before a training run overwrites an existing model.

    Args:
        output_name: Base name the trained model will be saved under.
        output_dir: Directory the trained model will be saved to.
        save_model_as: Save format. 'diffusers' and 'diffusers_safetensors'
            are checked as a folder named ``output_name``; 'ckpt' and
            'safetensors' are checked as a single
            ``<output_name>.<save_model_as>`` file. Any other value (e.g.
            'same as source model') cannot be checked.

    Returns:
        True if a model with the same name exists and the user declined to
        overwrite it (the caller should abort), otherwise False.
    """
    # Must match the 'diffusers_safetensors' choice offered by the GUI;
    # a misspelling here silently skipped the check for that format.
    if save_model_as in ['diffusers', 'diffusers_safetensors']:
        ckpt_folder = os.path.join(output_dir, output_name)
        if os.path.isdir(ckpt_folder):
            msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print(
                    'Aborting training due to existing model with same name...'
                )
                return True
    elif save_model_as in ['ckpt', 'safetensors']:
        ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
        if os.path.isfile(ckpt_file):
            msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print(
                    'Aborting training due to existing model with same name...'
                )
                return True
    else:
        print(
            'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...'
        )
        return False

    return False
|
|
|
|
|
|
def update_my_data(my_data):
    """Upgrade a loaded configuration dict to the current GUI schema.

    The dict is modified in place and also returned. Migrations:

    * ``optimizer`` is derived from the legacy ``use_8bit_adam`` flag when the
      key is absent ('AdamW8bit' if the flag is truthy, else 'AdamW').
    * ``model_list`` becomes 'custom' when it is missing/empty, or when
      ``pretrained_model_name_or_path`` is not in ALL_PRESET_MODELS.
    * ``epoch`` / ``save_every_n_epochs``: digit-only strings become ints;
      other falsy values become -1; missing keys stay missing.
    * ``LoRA_type`` 'LoCon' is renamed to 'LyCORIS/LoCon'.
    * For LoRA or TI configs, ``save_model_as`` is forced to 'safetensors'
      unless it is already 'safetensors' or 'ckpt'.
    """
    if 'optimizer' not in my_data:
        my_data['optimizer'] = (
            'AdamW8bit' if my_data.get('use_8bit_adam', False) else 'AdamW'
        )

    # The preset lookup is only reached when a model_list value exists.
    has_model_list = bool(my_data.get('model_list', []))
    if not has_model_list or (
        my_data.get('pretrained_model_name_or_path', '')
        not in ALL_PRESET_MODELS
    ):
        my_data['model_list'] = 'custom'

    for epoch_key in ('epoch', 'save_every_n_epochs'):
        raw = my_data.get(epoch_key, -1)
        is_numeric_text = isinstance(raw, str) and raw.isdigit()
        if is_numeric_text or not raw:
            my_data[epoch_key] = int(raw) if is_numeric_text else -1

    if my_data.get('LoRA_type', 'Standard') == 'LoCon':
        my_data['LoRA_type'] = 'LyCORIS/LoCon'

    is_lora = my_data.get('LoRA_type')
    is_ti = my_data.get('num_vectors_per_token')
    needs_format_fix = my_data.get('save_model_as') not in ['safetensors', 'ckpt']
    if (is_lora or is_ti) and needs_format_fix:
        template = (
            'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
        )
        for enabled, kind in ((is_lora, 'LoRA'), (is_ti, 'TI')):
            if enabled:
                print(template.format(kind))
        my_data['save_model_as'] = 'safetensors'

    return my_data
|
|
|
|
|
|
def get_dir_and_file(file_path):
    """Return ``(directory, final_component)`` for *file_path* via os.path.split."""
    head_and_tail = os.path.split(file_path)
    return head_and_tail
|
|
|
|
|
|
# def has_ext_files(directory, extension):
|
|
# # Iterate through all the files in the directory
|
|
# for file in os.listdir(directory):
|
|
# # If the file name ends with extension, return True
|
|
# if file.endswith(extension):
|
|
# return True
|
|
# # If no extension files were found, return False
|
|
# return False
|
|
|
|
|
|
def get_file_path(
    file_path='', default_extension='.json', extension_name='Config files'
):
    """Ask the user to pick an existing file with a native Tk dialog.

    The dialog is skipped when any environment variable listed in
    FILE_ENV_EXCLUSION is set; ``file_path`` is then returned unchanged.

    Args:
        file_path: Current path; its directory and file name seed the dialog.
        default_extension: Extension for the primary file-type filter, also
            passed to the dialog as ``defaultextension``.
        extension_name: Label shown for that filter.

    Returns:
        The selected path, or ``file_path`` if the dialog was cancelled.
    """
    if any(var in os.environ for var in FILE_ENV_EXCLUSION):
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    # Hidden, topmost root window so the dialog is not opened behind others.
    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        selected = filedialog.askopenfilename(
            filetypes=(
                (extension_name, f'*{default_extension}'),
                ('All files', '*.*'),
            ),
            defaultextension=default_extension,
            initialfile=initial_file,
            initialdir=initial_dir,
        )
    finally:
        # Release the Tk root even if the dialog raises.
        root.destroy()

    # A cancelled dialog yields an empty value; keep the previous path.
    return selected or file_path
|
|
|
|
|
|
def get_any_file_path(file_path=''):
    """Ask the user to pick any existing file with a native Tk dialog.

    No file-type filter is applied. The dialog is skipped when any
    environment variable listed in FILE_ENV_EXCLUSION is set; ``file_path``
    is then returned unchanged.

    Args:
        file_path: Current path; its directory and file name seed the dialog.

    Returns:
        The selected path, or ``file_path`` if the dialog was cancelled.
    """
    if any(var in os.environ for var in FILE_ENV_EXCLUSION):
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    # Hidden, topmost root window so the dialog is not opened behind others.
    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        selected = filedialog.askopenfilename(
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        # Release the Tk root even if the dialog raises.
        root.destroy()

    # Cancelling returns an empty value ('' — and on some Tk builds an empty
    # tuple); a truthiness test covers both instead of only ''.
    return selected or file_path
|
|
|
|
|
|
def remove_doublequote(file_path):
    """Return *file_path* with every double-quote character removed.

    ``None`` is passed through unchanged so an empty Gradio textbox value
    does not raise.
    """
    if file_path is not None:
        file_path = file_path.replace('"', '')

    return file_path
|
|
|
|
|
|
# def set_legacy_8bitadam(optimizer, use_8bit_adam):
|
|
# if optimizer == 'AdamW8bit':
|
|
# # use_8bit_adam = True
|
|
# return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
|
|
# value=True, interactive=False, visible=True
|
|
# )
|
|
# else:
|
|
# # use_8bit_adam = False
|
|
# return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
|
|
# value=False, interactive=False, visible=True
|
|
# )
|
|
|
|
|
|
def get_folder_path(folder_path=''):
    """Ask the user to pick a directory with a native Tk dialog.

    The dialog opens in the parent directory of ``folder_path`` (its
    os.path.split head). It is skipped when any environment variable listed
    in FILE_ENV_EXCLUSION is set; ``folder_path`` is then returned unchanged.

    Args:
        folder_path: Current folder path used to seed the dialog.

    Returns:
        The selected directory, or ``folder_path`` if the dialog was cancelled.
    """
    if any(var in os.environ for var in FILE_ENV_EXCLUSION):
        return folder_path

    initial_dir, _ = get_dir_and_file(folder_path)

    # Hidden, topmost root window so the dialog is not opened behind others.
    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        selected = filedialog.askdirectory(initialdir=initial_dir)
    finally:
        # Release the Tk root even if the dialog raises.
        root.destroy()

    # Cancelling returns an empty value ('' or, on some Tk builds, an empty
    # tuple); keep the previous folder in either case.
    return selected or folder_path
|
|
|
|
|
|
def get_saveasfile_path(
    file_path='', defaultextension='.json', extension_name='Config files'
):
    """Ask the user for a path to save to, using a native Tk "Save as" dialog.

    The dialog is skipped when any environment variable listed in
    FILE_ENV_EXCLUSION is set; ``file_path`` is then returned unchanged.

    Args:
        file_path: Current path; its directory and file name seed the dialog.
        defaultextension: Extension for the primary filter, also appended by
            the dialog when the user omits one.
        extension_name: Label shown for that filter.

    Returns:
        The chosen path (also printed), or ``file_path`` if cancelled.
    """
    if any(var in os.environ for var in FILE_ENV_EXCLUSION):
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        # asksaveasfilename only returns the path. asksaveasfile would open
        # the file in 'w' mode (truncating it before the caller has written
        # anything) and hand back a handle that nothing here closes.
        selected = filedialog.asksaveasfilename(
            filetypes=(
                (f'{extension_name}', f'{defaultextension}'),
                ('All files', '*'),
            ),
            defaultextension=defaultextension,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        # Release the Tk root even if the dialog raises.
        root.destroy()

    if not selected:
        # Cancelled: keep the previous path.
        return file_path

    print(selected)
    return selected
|
|
|
|
|
|
def get_saveasfilename_path(
    file_path='', extensions='*', extension_name='Config files'
):
    """Ask the user for a file name to save to, using a native Tk dialog.

    The dialog is skipped when any environment variable listed in
    FILE_ENV_EXCLUSION is set; ``file_path`` is then returned unchanged.

    Args:
        file_path: Current path; its directory and file name seed the dialog.
        extensions: Pattern for the primary filter; also passed as the
            dialog's ``defaultextension``.
        extension_name: Label shown for that filter.

    Returns:
        The chosen path, or ``file_path`` if the dialog was cancelled.
    """
    if any(var in os.environ for var in FILE_ENV_EXCLUSION):
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    # Hidden, topmost root window so the dialog is not opened behind others.
    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        selected = filedialog.asksaveasfilename(
            filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
            defaultextension=extensions,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        # Release the Tk root even if the dialog raises.
        root.destroy()

    # Cancelling returns an empty value ('' or, on some Tk builds, an empty
    # tuple); keep the previous path in either case.
    return selected or file_path
|
|
|
|
|
|
def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption',
) -> None:
    """
    Add a prefix and/or postfix to the caption of every image in a folder.

    For each image (.jpg, .jpeg, .png, .webp; extension matched
    case-insensitively) directly inside ``folder``, the caption file with the
    same stem and ``caption_file_ext`` is updated in place: trailing
    whitespace is stripped and the prefix/postfix are joined with single
    spaces. Images without a caption file get a new one containing only the
    prefix and/or postfix. Nothing happens when both are empty.

    Args:
        folder (str): Path to the folder containing the images and captions.
        prefix (str, optional): Text to put before the caption.
        postfix (str, optional): Text to put after the caption.
        caption_file_ext (str, optional): Extension of the caption files.
    """

    if prefix == '' and postfix == '':
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
        caption_file_path = os.path.join(folder, caption_file_name)

        if not os.path.exists(caption_file_path):
            with open(caption_file_path, 'w') as f:
                separator = ' ' if prefix and postfix else ''
                f.write(f'{prefix}{separator}{postfix}')
        else:
            with open(caption_file_path, 'r+') as f:
                content = f.read().rstrip()
                prefix_separator = ' ' if prefix else ''
                postfix_separator = ' ' if postfix else ''

                f.seek(0, 0)
                f.write(
                    f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
                )
                # The rewritten text can be shorter than the old file (trailing
                # whitespace was stripped); cut off any leftover old bytes.
                f.truncate()
|
|
|
|
|
|
def has_ext_files(folder_path: str, file_extension: str) -> bool:
    """
    Tell whether any entry directly inside a folder ends with an extension.

    The test is a case-sensitive suffix match on the names returned by
    os.listdir (not recursive; directory entries are included).

    Args:
        folder_path (str): Folder to inspect.
        file_extension (str): Suffix to look for, e.g. '.caption'.

    Returns:
        bool: True as soon as one matching name is found, otherwise False.
    """
    entries = os.listdir(folder_path)
    return any(entry.endswith(file_extension) for entry in entries)
|
|
|
|
|
|
def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
) -> None:
    """
    Find and replace text in the caption files directly inside a folder.

    Shows a message box and returns if the folder has no file with
    ``caption_file_ext``; does nothing if ``search_text`` is empty. Files
    whose content does not change are left untouched on disk.

    Args:
        folder_path (str, optional): Path to the folder containing caption files.
        caption_file_ext (str, optional): Extension of the caption files.
        search_text (str, optional): Text to search for in the caption files.
        replace_text (str, optional): Text to replace the search text with.
    """
    print('Running caption find/replace')

    if not has_ext_files(folder_path, caption_file_ext):
        msgbox(
            f'No files with extension {caption_file_ext} were found in {folder_path}...'
        )
        return

    if search_text == '':
        return

    caption_paths = [
        os.path.join(folder_path, f)
        for f in os.listdir(folder_path)
        if f.endswith(caption_file_ext)
        and os.path.isfile(os.path.join(folder_path, f))
    ]

    for caption_path in caption_paths:
        # surrogateescape round-trips bytes that don't decode in the default
        # encoding, so they are written back unchanged instead of being
        # silently dropped (which errors='ignore' would do).
        with open(caption_path, 'r', errors='surrogateescape') as f:
            content = f.read()

        new_content = content.replace(search_text, replace_text)
        if new_content == content:
            continue  # no match: don't rewrite the file

        with open(caption_path, 'w', errors='surrogateescape') as f:
            f.write(new_content)
|
|
|
|
|
|
def color_aug_changed(color_aug):
    """Return a checkbox update reacting to the color-augmentation toggle.

    When ``color_aug`` is falsy the update re-checks and unlocks the box.
    Otherwise a notice is shown and the update unchecks and locks it
    (per the notice text, the target is presumably the "Cache latent" box).
    """
    if not color_aug:
        return gr.Checkbox.update(value=True, interactive=True)

    msgbox(
        'Disabling "Cache latent" because "Color augmentation" has been selected...'
    )
    return gr.Checkbox.update(value=False, interactive=False)
|
|
|
|
|
|
def save_inference_file(output_dir, v2, v_parameterization, output_name):
    """Copy the matching SD v2 inference YAML next to each saved model file.

    Nothing is done unless ``v2`` is truthy. Otherwise, for every regular
    file in ``output_dir`` whose name starts with ``output_name`` (a plain
    prefix match), ``<stem>.yaml`` is written beside it. The template comes
    from ``./v2_inference`` relative to the current working directory:
    ``v2-inference-v.yaml`` when ``v_parameterization`` is also truthy,
    ``v2-inference.yaml`` otherwise. Existing ``.yaml`` files are skipped, since
    copying onto their own stem would only repeat another file's copy.
    """
    # Only v2 models need a config file; return before touching output_dir.
    if not v2:
        return

    template_name = (
        'v2-inference-v.yaml' if v_parameterization else 'v2-inference.yaml'
    )
    template_path = os.path.join('.', 'v2_inference', template_name)

    for file in os.listdir(output_dir):
        if not file.startswith(output_name):
            continue
        if not os.path.isfile(os.path.join(output_dir, file)):
            continue

        file_name, ext = os.path.splitext(file)
        if ext == '.yaml':
            continue

        print(f'Saving {template_name} as {output_dir}/{file_name}.yaml')
        shutil.copy(template_path, os.path.join(output_dir, f'{file_name}.yaml'))
|
|
|
|
|
|
def set_pretrained_model_name_or_path_input(
    model_list, pretrained_model_name_or_path, v2, v_parameterization
):
    """React to a "Model Quick Pick" change.

    Choosing a preset id copies it into the model path and sets the v2 /
    v_parameterization flags for that preset family (printing a notice for
    the v2 families). Choosing 'custom' while the path still holds a preset
    id clears the path and both flags.

    Returns:
        (model_list, pretrained_model_name_or_path, v2, v_parameterization)
    """
    choice = str(model_list)
    preset_rules = (
        (
            V2_BASE_MODELS,
            True,
            False,
            'SD v2 model detected. Setting --v2 parameter',
        ),
        (
            V_PARAMETERIZATION_MODELS,
            True,
            True,
            'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization',
        ),
        (V1_MODELS, False, False, None),
    )
    for family, want_v2, want_v_param, notice in preset_rules:
        if choice not in family:
            continue
        if notice is not None:
            print(notice)
        v2, v_parameterization = want_v2, want_v_param
        pretrained_model_name_or_path = choice

    every_preset = V1_MODELS + V2_BASE_MODELS + V_PARAMETERIZATION_MODELS
    if model_list == 'custom' and str(pretrained_model_name_or_path) in every_preset:
        pretrained_model_name_or_path = ''
        v2 = False
        v_parameterization = False

    return model_list, pretrained_model_name_or_path, v2, v_parameterization
|
|
|
|
|
|
def set_v2_checkbox(model_list, v2, v_parameterization):
    """Snap the v2 / v_parameterization flags back to the selected preset.

    When ``model_list`` names a preset, the flags are forced to that preset's
    values; any other selection (e.g. 'custom') returns them unchanged.

    Returns:
        (v2, v_parameterization)
    """
    chosen = str(model_list)
    family_flags = (
        (V2_BASE_MODELS, (True, False)),
        (V_PARAMETERIZATION_MODELS, (True, True)),
        (V1_MODELS, (False, False)),
    )
    for family, flags in family_flags:
        if chosen in family:
            v2, v_parameterization = flags
    return v2, v_parameterization
|
|
|
|
|
|
def set_model_list(
    model_list,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
):
    """Sync the "Model Quick Pick" value with a typed model path.

    The incoming ``model_list`` value is ignored. The returned value is the
    path itself when it is one of ALL_PRESET_MODELS, otherwise 'custom'.
    ``v2`` and ``v_parameterization`` are passed through unchanged.

    Returns:
        (model_list, v2, v_parameterization)
    """
    if pretrained_model_name_or_path not in ALL_PRESET_MODELS:
        model_list = 'custom'
    else:
        model_list = pretrained_model_name_or_path

    return model_list, v2, v_parameterization
|
|
|
|
|
|
###
|
|
### Gradio common GUI section
|
|
###
|
|
|
|
|
|
def gradio_config():
    """Build the collapsed "Configuration file" accordion.

    A change listener strips double quotes from the path textbox.

    Returns:
        (open button, save button, save-as button, config path textbox,
         load button)
    """
    shared = dict(elem_id='open_folder')
    with gr.Accordion('Configuration file', open=False):
        with gr.Row():
            open_btn = gr.Button('Open 📂', **shared)
            save_btn = gr.Button('Save 💾', **shared)
            save_as_btn = gr.Button('Save as... 💾', **shared)
            config_path_box = gr.Textbox(
                label='',
                placeholder="type the configuration file path or use the 'Open' button above to select it...",
                interactive=True,
            )
            load_btn = gr.Button('Load 💾', **shared)
            config_path_box.change(
                remove_doublequote,
                inputs=[config_path_box],
                outputs=[config_path_box],
            )
    return open_btn, save_btn, save_as_btn, config_path_box, load_btn
|
|
|
|
|
|
def get_pretrained_model_name_or_path_file(
    model_list, pretrained_model_name_or_path
):
    """Let the user pick a model file, then sync the Quick Pick value with it.

    Args:
        model_list: Current "Model Quick Pick" value.
        pretrained_model_name_or_path: Current model path; seeds the dialog.

    Returns:
        (model_list, pretrained_model_name_or_path) after the selection.
    """
    pretrained_model_name_or_path = get_any_file_path(
        pretrained_model_name_or_path
    )
    # set_model_list requires the v2 / v_parameterization flags but only
    # passes them through; they are not tracked here, so use placeholders.
    model_list, _, _ = set_model_list(
        model_list, pretrained_model_name_or_path, None, None
    )
    return model_list, pretrained_model_name_or_path
|
|
|
|
|
|
def gradio_source_model(
    save_model_as_choices=(
        'same as source model',
        'ckpt',
        'diffusers',
        'diffusers_safetensors',
        'safetensors',
    ),
):
    """Build the "Source model" tab and wire its event handlers.

    Args:
        save_model_as_choices: Options for the "Save trained model as"
            dropdown. The default is an immutable tuple, and a fresh list is
            handed to the dropdown on every call, so no state is shared
            between calls.

    Returns:
        (pretrained_model_name_or_path, v2, v_parameterization,
         save_model_as, model_list) components, in that order.
    """
    with gr.Tab('Source model'):
        # Define the input elements
        with gr.Row():
            pretrained_model_name_or_path = gr.Textbox(
                label='Pretrained model name or path',
                placeholder='enter the path to custom model or name of pretrained model',
                value='runwayml/stable-diffusion-v1-5',
            )
            pretrained_model_name_or_path_file = gr.Button(
                document_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_file.click(
                get_any_file_path,
                inputs=pretrained_model_name_or_path,
                outputs=pretrained_model_name_or_path,
                show_progress=False,
            )
            pretrained_model_name_or_path_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_folder.click(
                get_folder_path,
                inputs=pretrained_model_name_or_path,
                outputs=pretrained_model_name_or_path,
                show_progress=False,
            )
            model_list = gr.Dropdown(
                label='Model Quick Pick',
                choices=[
                    'custom',
                    'stabilityai/stable-diffusion-2-1-base',
                    'stabilityai/stable-diffusion-2-base',
                    'stabilityai/stable-diffusion-2-1',
                    'stabilityai/stable-diffusion-2',
                    'runwayml/stable-diffusion-v1-5',
                    'CompVis/stable-diffusion-v1-4',
                ],
                value='runwayml/stable-diffusion-v1-5',
            )
            save_model_as = gr.Dropdown(
                label='Save trained model as',
                choices=list(save_model_as_choices),
                value='safetensors',
            )

        with gr.Row():
            v2 = gr.Checkbox(label='v2', value=False)
            v_parameterization = gr.Checkbox(
                label='v_parameterization', value=False
            )
            # Keep the flags consistent with the selected preset.
            v2.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
            v_parameterization.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
        model_list.change(
            set_pretrained_model_name_or_path_input,
            inputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            outputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            show_progress=False,
        )
        # Update the model list and parameters when user click outside the button or field
        pretrained_model_name_or_path.change(
            set_model_list,
            inputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            outputs=[
                model_list,
                v2,
                v_parameterization,
            ],
            show_progress=False,
        )
    return (
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        save_model_as,
        model_list,
    )
|
|
|
|
|
|
def gradio_training(
    learning_rate_value='1e-6',
    lr_scheduler_value='constant',
    lr_warmup_value='0',
):
    """Build the basic training-parameter rows.

    Args:
        learning_rate_value: Initial text of the "Learning rate" box.
        lr_scheduler_value: Initially selected LR scheduler.
        lr_warmup_value: Initial text of the "LR warmup (% of steps)" box.

    Returns:
        (learning_rate, lr_scheduler, lr_warmup, train_batch_size, epoch,
         save_every_n_epochs, mixed_precision, save_precision,
         num_cpu_threads_per_process, seed, caption_extension, cache_latents,
         optimizer, optimizer_args) components, in that order.
    """
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 1 so the slider always has a valid maximum, and keep the
    # default inside that range on single-core machines.
    max_cpu_threads = os.cpu_count() or 1

    with gr.Row():
        train_batch_size = gr.Slider(
            minimum=1,
            maximum=64,
            label='Train batch size',
            value=1,
            step=1,
        )
        epoch = gr.Number(label='Epoch', value=1, precision=0)
        save_every_n_epochs = gr.Number(
            label='Save every N epochs', value=1, precision=0
        )
        caption_extension = gr.Textbox(
            label='Caption Extension',
            placeholder='(Optional) Extension for caption files. default: .caption',
        )
    with gr.Row():
        mixed_precision = gr.Dropdown(
            label='Mixed precision',
            choices=[
                'no',
                'fp16',
                'bf16',
            ],
            value='fp16',
        )
        save_precision = gr.Dropdown(
            label='Save precision',
            choices=[
                'float',
                'fp16',
                'bf16',
            ],
            value='fp16',
        )
        num_cpu_threads_per_process = gr.Slider(
            minimum=1,
            maximum=max_cpu_threads,
            step=1,
            label='Number of CPU threads per core',
            value=min(2, max_cpu_threads),
        )
        seed = gr.Textbox(label='Seed', placeholder='(Optional) eg:1234')
        cache_latents = gr.Checkbox(label='Cache latent', value=True)
    with gr.Row():
        learning_rate = gr.Textbox(
            label='Learning rate', value=learning_rate_value
        )
        lr_scheduler = gr.Dropdown(
            label='LR Scheduler',
            choices=[
                'adafactor',
                'constant',
                'constant_with_warmup',
                'cosine',
                'cosine_with_restarts',
                'linear',
                'polynomial',
            ],
            value=lr_scheduler_value,
        )
        lr_warmup = gr.Textbox(
            label='LR warmup (% of steps)', value=lr_warmup_value
        )
        optimizer = gr.Dropdown(
            label='Optimizer',
            choices=[
                'AdamW',
                'AdamW8bit',
                'Adafactor',
                'DAdaptation',
                'Lion',
                'SGDNesterov',
                'SGDNesterov8bit',
            ],
            value='AdamW8bit',
            interactive=True,
        )
    with gr.Row():
        optimizer_args = gr.Textbox(
            label='Optimizer extra arguments',
            placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
        )
    return (
        learning_rate,
        lr_scheduler,
        lr_warmup,
        train_batch_size,
        epoch,
        save_every_n_epochs,
        mixed_precision,
        save_precision,
        num_cpu_threads_per_process,
        seed,
        caption_extension,
        cache_latents,
        optimizer,
        optimizer_args,
    )
|
|
|
|
|
|
def run_cmd_training(**kwargs):
    """Build the basic-training part of the training command line.

    Each recognised keyword becomes a ``--flag="value"`` fragment, each with
    a leading space, concatenated in a fixed order. Most flags are emitted
    only for truthy values. ``seed`` and ``optimizer_args`` are emitted for
    any value other than None or '' (so a seed of 0 is kept).
    ``save_every_n_epochs`` is emitted when it converts to a non-zero int.
    ``--optimizer_type`` is always emitted (default 'AdamW').

    Returns:
        The concatenated option string.
    """

    def _is_set(value):
        # '' (empty textbox) and None (missing/cleared) both mean "not set".
        return value is not None and value != ''

    save_every_n_epochs = kwargs.get('save_every_n_epochs')
    seed = kwargs.get('seed')
    optimizer_args = kwargs.get('optimizer_args')

    options = [
        f' --learning_rate="{kwargs.get("learning_rate", "")}"'
        if kwargs.get('learning_rate')
        else '',
        f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
        if kwargs.get('lr_scheduler')
        else '',
        f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
        if kwargs.get('lr_warmup_steps')
        else '',
        f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
        if kwargs.get('train_batch_size')
        else '',
        f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
        if kwargs.get('max_train_steps')
        else '',
        # Missing or empty values used to crash in int(); now they are omitted.
        f' --save_every_n_epochs="{int(save_every_n_epochs)}"'
        if _is_set(save_every_n_epochs) and int(save_every_n_epochs)
        else '',
        f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
        if kwargs.get('mixed_precision')
        else '',
        f' --save_precision="{kwargs.get("save_precision", "")}"'
        if kwargs.get('save_precision')
        else '',
        f' --seed="{seed}"' if _is_set(seed) else '',
        f' --caption_extension="{kwargs.get("caption_extension", "")}"'
        if kwargs.get('caption_extension')
        else '',
        ' --cache_latents' if kwargs.get('cache_latents') else '',
        f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
        f' --optimizer_args {optimizer_args}' if _is_set(optimizer_args) else '',
    ]

    return ''.join(options)
|
|
|
|
|
|
def gradio_advanced_training():
    """Build the "advanced training" option rows.

    Returns:
        (xformers, full_fp16, gradient_checkpointing, shuffle_caption,
         color_aug, flip_aug, clip_skip, mem_eff_attn, save_state, resume,
         max_token_length, max_train_epochs, max_data_loader_n_workers,
         keep_tokens, persistent_data_loader_workers, bucket_no_upscale,
         random_crop, bucket_reso_steps, caption_dropout_every_n_epochs,
         caption_dropout_rate, noise_offset, additional_parameters,
         vae_batch_size, min_snr_gamma) components, in that order.
    """
    with gr.Row():
        additional_parameters = gr.Textbox(
            label='Additional parameters',
            placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"',
        )
    with gr.Row():
        keep_tokens = gr.Slider(
            label='Keep n tokens', value=0, minimum=0, maximum=32, step=1
        )
        clip_skip = gr.Slider(
            label='Clip skip', value=1, minimum=1, maximum=12, step=1
        )
        max_token_length = gr.Dropdown(
            label='Max Token Length',
            choices=[
                '75',
                '150',
                '225',
            ],
            value='75',
        )
        full_fp16 = gr.Checkbox(
            label='Full fp16 training (experimental)', value=False
        )
    with gr.Row():
        gradient_checkpointing = gr.Checkbox(
            label='Gradient checkpointing', value=False
        )
        shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
        persistent_data_loader_workers = gr.Checkbox(
            label='Persistent data loader', value=False
        )
        mem_eff_attn = gr.Checkbox(
            label='Memory efficient attention', value=False
        )
    with gr.Row():
        # The former use_8bit_adam checkbox is gone; 8-bit Adam is selected
        # through the optimizer dropdown instead.
        xformers = gr.Checkbox(label='Use xformers', value=True)
        color_aug = gr.Checkbox(label='Color augmentation', value=False)
        flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
        min_snr_gamma = gr.Slider(
            label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
        )
    with gr.Row():
        bucket_no_upscale = gr.Checkbox(
            label="Don't upscale bucket resolution", value=True
        )
        bucket_reso_steps = gr.Number(
            label='Bucket resolution steps', value=64
        )
        random_crop = gr.Checkbox(
            label='Random crop instead of center crop', value=False
        )
        noise_offset = gr.Textbox(
            label='Noise offset (0 - 1)', placeholder='(Optional) eg: 0.1'
        )

    with gr.Row():
        caption_dropout_every_n_epochs = gr.Number(
            label='Dropout caption every n epochs', value=0
        )
        caption_dropout_rate = gr.Slider(
            label='Rate of caption dropout', value=0, minimum=0, maximum=1
        )
        # `step` sets the slider increment; `every` is Gradio's periodic
        # refresh interval for callable values and made this slider fractional.
        vae_batch_size = gr.Slider(
            label='VAE batch size',
            minimum=0,
            maximum=32,
            value=0,
            step=1,
        )
    with gr.Row():
        save_state = gr.Checkbox(label='Save training state', value=False)
        resume = gr.Textbox(
            label='Resume from saved training state',
            placeholder='path to "last-state" state folder to resume from',
        )
        resume_button = gr.Button('📂', elem_id='open_folder_small')
        resume_button.click(
            get_folder_path,
            outputs=resume,
            show_progress=False,
        )
        max_train_epochs = gr.Textbox(
            label='Max train epoch',
            placeholder='(Optional) Override number of epoch',
        )
        max_data_loader_n_workers = gr.Textbox(
            label='Max num workers for DataLoader',
            placeholder='(Optional) Override number of DataLoader workers. Default: 8',
            value="0",
        )
    return (
        xformers,
        full_fp16,
        gradient_checkpointing,
        shuffle_caption,
        color_aug,
        flip_aug,
        clip_skip,
        mem_eff_attn,
        save_state,
        resume,
        max_token_length,
        max_train_epochs,
        max_data_loader_n_workers,
        keep_tokens,
        persistent_data_loader_workers,
        bucket_no_upscale,
        random_crop,
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        noise_offset,
        additional_parameters,
        vae_batch_size,
        min_snr_gamma,
    )
|
|
|
|
|
|
def run_cmd_advanced_training(**kwargs):
    """Build the advanced-training part of the training command line.

    Each recognised keyword becomes a ``--flag`` fragment with a leading
    space, concatenated in a fixed order; flags whose value is at its
    "off" default are omitted. ``additional_parameters`` is appended
    verbatim when non-empty.

    Returns:
        The concatenated option string.
    """

    def _is_set(value):
        # '' (empty textbox) and None (missing/cleared) both mean "not set".
        return value is not None and value != ''

    caption_dropout_every = int(kwargs.get('caption_dropout_every_n_epochs', 0))
    caption_dropout_rate = float(kwargs.get('caption_dropout_rate', 0))
    vae_batch_size = int(kwargs.get('vae_batch_size', 0))
    # One fallback for both the test and the emitted value (previously 64 vs 1).
    bucket_reso_steps = int(kwargs.get('bucket_reso_steps', 64))
    min_snr_gamma = int(kwargs.get('min_snr_gamma', 0))
    noise_offset = kwargs.get('noise_offset', '')
    additional_parameters = kwargs.get('additional_parameters', '')

    options = [
        f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
        if kwargs.get('max_train_epochs')
        else '',
        f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
        if kwargs.get('max_data_loader_n_workers')
        else '',
        f' --max_token_length={kwargs.get("max_token_length", "")}'
        if int(kwargs.get('max_token_length', 75)) > 75
        else '',
        f' --clip_skip={kwargs.get("clip_skip", "")}'
        if int(kwargs.get('clip_skip', 1)) > 1
        else '',
        f' --resume="{kwargs.get("resume", "")}"'
        if kwargs.get('resume')
        else '',
        f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
        if int(kwargs.get('keep_tokens', 0)) > 0
        else '',
        f' --caption_dropout_every_n_epochs="{caption_dropout_every}"'
        if caption_dropout_every > 0
        else '',
        # The rate was previously never passed (the epochs flag was duplicated).
        f' --caption_dropout_rate="{caption_dropout_rate}"'
        if caption_dropout_rate > 0
        else '',
        f' --vae_batch_size="{vae_batch_size}"'
        if vae_batch_size > 0
        else '',
        f' --bucket_reso_steps={bucket_reso_steps}'
        if bucket_reso_steps >= 1
        else '',
        f' --min_snr_gamma={min_snr_gamma}'
        if min_snr_gamma >= 1
        else '',
        ' --save_state' if kwargs.get('save_state') else '',
        ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
        ' --color_aug' if kwargs.get('color_aug') else '',
        ' --flip_aug' if kwargs.get('flip_aug') else '',
        ' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
        ' --gradient_checkpointing'
        if kwargs.get('gradient_checkpointing')
        else '',
        ' --full_fp16' if kwargs.get('full_fp16') else '',
        ' --xformers' if kwargs.get('xformers') else '',
        ' --persistent_data_loader_workers'
        if kwargs.get('persistent_data_loader_workers')
        else '',
        ' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
        ' --random_crop' if kwargs.get('random_crop') else '',
        f' --noise_offset={float(noise_offset)}'
        if _is_set(noise_offset)
        else '',
        f' {additional_parameters}' if _is_set(additional_parameters) else '',
    ]

    return ''.join(options)
|