7b5639cff5
This is a massive WIP and should not be trusted or used right now. However, major milestones have been crossed. Both message boxes and file dialogs are now properly subprocessed and work on macOS. I think by extension, it may work on runpod environments as well, but that remains to be tested.
999 lines
28 KiB
Python
999 lines
28 KiB
Python
# v1: initial release
|
|
# v2: add open and save folder icons
|
|
# v3: Add new Utilities tab for Dreambooth folder preparation
|
|
# v3.1: Add captioning of images to utilities
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import os
|
|
import pathlib
|
|
import subprocess
|
|
|
|
import gradio as gr
|
|
|
|
from library.common_gui import (
|
|
get_folder_path,
|
|
remove_doublequote,
|
|
get_file_path,
|
|
get_any_file_path,
|
|
get_saveasfile_path,
|
|
color_aug_changed,
|
|
save_inference_file,
|
|
gradio_advanced_training,
|
|
run_cmd_advanced_training,
|
|
run_cmd_training,
|
|
gradio_training,
|
|
gradio_config,
|
|
gradio_source_model,
|
|
# set_legacy_8bitadam,
|
|
update_my_data,
|
|
check_if_model_exist,
|
|
)
|
|
from library.dreambooth_folder_creation_gui import (
|
|
gradio_dreambooth_folder_creation_tab,
|
|
)
|
|
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
|
from library.tensorboard_gui import (
|
|
gradio_tensorboard,
|
|
start_tensorboard,
|
|
stop_tensorboard,
|
|
)
|
|
from library.utilities import utilities_tab
|
|
|
|
# Emoji glyphs available for small icon buttons, spelled as code points so
# the source stays ASCII-only.
folder_symbol = chr(0x1F4C2)  # open file folder
refresh_symbol = chr(0x1F504)  # anticlockwise arrows
save_style_symbol = chr(0x1F4BE)  # floppy disk
document_symbol = chr(0x1F4C4)  # page facing up
|
|
|
|
|
|
def save_configuration(
    save_as,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    color_aug,
    flip_aug,
    clip_skip,
    vae,
    output_name,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    mem_eff_attn,
    gradient_accumulation_steps,
    model_list,
    token_string,
    init_word,
    num_vectors_per_token,
    max_train_steps,
    weights,
    template,
    keep_tokens,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
):
    """Save the current GUI settings to a JSON config file.

    Args:
        save_as: value of a hidden ``gr.Label``. When its ``'label'`` entry
            is the string ``'True'`` the user is always asked for a
            destination via ``get_saveasfile_path``; otherwise ``file_path``
            is reused and the user is only asked when it is empty.
        file_path: path currently shown in the config-file textbox.
        All remaining parameters: GUI setting values. They are written to
            the file keyed by parameter name, in signature order.

    Returns:
        The path that was written, or the original ``file_path`` when the
        user cancelled the file dialog.
    """
    # Must stay the first statement so that only the parameters are captured.
    parameters = list(locals().items())

    original_file_path = file_path

    save_as_bool = save_as.get('label') == 'True'

    if save_as_bool:
        print('Save as...')
        file_path = get_saveasfile_path(file_path)
    else:
        print('Save...')
        if not file_path:
            file_path = get_saveasfile_path(file_path)

    if not file_path:
        # The user cancelled the dialog: keep the previously shown path.
        return original_file_path

    variables = {
        name: value
        for name, value in parameters
        if name not in ('file_path', 'save_as')
    }

    # A bare file name has no directory component (dirname() returns ''),
    # and os.makedirs('') raises, so only create a directory when present.
    destination_directory = os.path.dirname(file_path)
    if destination_directory:
        os.makedirs(destination_directory, exist_ok=True)

    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(variables, file, indent=2)

    return file_path
|
|
|
|
|
|
def open_configuration(
    ask_for_file,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    color_aug,
    flip_aug,
    clip_skip,
    vae,
    output_name,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    mem_eff_attn,
    gradient_accumulation_steps,
    model_list,
    token_string,
    init_word,
    num_vectors_per_token,
    max_train_steps,
    weights,
    template,
    keep_tokens,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
):
    """Load GUI settings from a JSON config file.

    Args:
        ask_for_file: value of a hidden ``gr.Label``. When its ``'label'``
            entry is the string ``'True'`` the user picks a file via
            ``get_file_path``; otherwise ``file_path`` is loaded directly.
        file_path: path currently shown in the config-file textbox.
        All remaining parameters: current GUI setting values. Each one is
            the fallback for any key missing from the loaded file.

    Returns:
        A tuple ``(file_path, *settings)`` in parameter order, matching
        ``outputs=[config_file_name] + settings_list``. When nothing is
        loaded (dialog cancelled, empty path, or missing file), the original
        path and the current settings are returned unchanged.
    """
    # Must stay the first statement so that only the parameters are captured.
    parameters = list(locals().items())

    ask_for_file = ask_for_file.get('label') == 'True'

    original_file_path = file_path

    if ask_for_file:
        file_path = get_file_path(file_path)

    my_data = {}
    if file_path:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                my_data = json.load(f)
        except FileNotFoundError:
            # Previously this propagated out of the Gradio callback; keep the
            # current settings instead and tell the user why nothing loaded.
            print(f'Config file {file_path} not found, keeping current settings...')
            file_path = original_file_path
        else:
            print('Loading config...')
            # Fix the deprecated use_8bit_adam checkbox and set the matching
            # optimizer when it is True.
            my_data = update_my_data(my_data)
    else:
        # The user cancelled the open dialog: keep the previously shown path.
        file_path = original_file_path

    values = [file_path]
    for key, value in parameters:
        # Prefer the value from the file; fall back to the current GUI value.
        if key not in ('ask_for_file', 'file_path'):
            values.append(my_data.get(key, value))
    return tuple(values)
|
|
|
|
|
|
# Image file extensions counted when estimating the number of training steps.
# Matching is case-insensitive, so e.g. "photo.JPG" is counted as well.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')


def _show_error_message(message):
    """Report a validation error to the user, falling back to the console.

    ``show_message_box`` is not imported at the top of this module, so calling
    it by bare name raised ``NameError`` on every validation failure. It is
    looked up lazily here instead, and ``message`` is printed if it cannot be
    imported.

    NOTE(review): assumes ``library.common_gui`` provides ``show_message_box``
    in this revision (the commit notes mention subprocessed message boxes).
    Confirm this, then import it at module level.
    """
    try:
        from library.common_gui import show_message_box
    except ImportError:
        print(message)
        return
    show_message_box(message)


def _count_total_steps(train_data_dir):
    """Return the sum of ``repeats * image count`` over the subfolders of ``train_data_dir``.

    Subfolders are expected to be named ``<repeats>_<name>``. Folders whose
    prefix is not an integer are skipped with a message instead of aborting.
    """
    total_steps = 0
    for folder in os.listdir(train_data_dir):
        folder_path = os.path.join(train_data_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        try:
            repeats = int(folder.split('_')[0])
        except ValueError:
            print(
                f"Subfolder '{folder}' does not start with a repeat count "
                "(e.g. '10_name'), skipping..."
            )
            continue

        num_images = sum(
            1
            for f in os.listdir(folder_path)
            if f.lower().endswith(_IMAGE_EXTENSIONS)
        )

        steps = repeats * num_images
        total_steps += steps
        print(f'Folder {folder}: {steps} steps')
    return total_steps


def train_model(
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training_pct,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    color_aug,
    flip_aug,
    clip_skip,
    vae,
    output_name,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    mem_eff_attn,
    gradient_accumulation_steps,
    model_list,  # Keep this. Yes, it is unused here but required given the common list used
    token_string,
    init_word,
    num_vectors_per_token,
    max_train_steps,
    weights,
    template,
    keep_tokens,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
):
    """Validate the GUI settings, then build and run the ``train_textual_inversion.py`` command.

    Each parameter is the value of one GUI control, in ``settings_list`` order.
    ``stop_text_encoder_training_pct`` and ``lr_warmup`` are percentages of
    ``max_train_steps``. An empty ``max_train_steps`` is derived from the
    image counts, repeats, batch size, epochs, and regularisation factor.

    Validation failures are reported to the user and return without training.
    Returns ``None``.
    """
    if pretrained_model_name_or_path == '':
        _show_error_message('Source model information is missing')
        return

    if train_data_dir == '':
        _show_error_message('Image folder path is missing')
        return

    if not os.path.exists(train_data_dir):
        _show_error_message('Image folder does not exist')
        return

    if reg_data_dir != '' and not os.path.exists(reg_data_dir):
        _show_error_message('Regularisation folder does not exist')
        return

    if output_dir == '':
        _show_error_message('Output folder path is missing')
        return

    if token_string == '':
        _show_error_message('Token string is missing')
        return

    if init_word == '':
        _show_error_message('Init word is missing')
        return

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    total_steps = _count_total_steps(train_data_dir)

    if reg_data_dir == '':
        reg_factor = 1
    else:
        print(
            'Regularisation images are used... Will double the number of steps required...'
        )
        reg_factor = 2

    # Derive max_train_steps unless the user supplied it explicitly.
    if max_train_steps == '':
        max_train_steps = int(
            math.ceil(
                float(total_steps)
                / int(train_batch_size)
                * int(epoch)
                * int(reg_factor)
            )
        )
    else:
        max_train_steps = int(max_train_steps)

    print(f'max_train_steps = {max_train_steps}')

    # The slider value is a percentage of max_train_steps.
    if stop_text_encoder_training_pct is None:
        stop_text_encoder_training = 0
    else:
        stop_text_encoder_training = math.ceil(
            float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
        )
    print(f'stop_text_encoder_training = {stop_text_encoder_training}')

    lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
    print(f'lr_warmup_steps = {lr_warmup_steps}')

    run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_textual_inversion.py"'
    if v2:
        run_cmd += ' --v2'
    if v_parameterization:
        run_cmd += ' --v_parameterization'
    if enable_bucket:
        run_cmd += ' --enable_bucket'
    if no_token_padding:
        run_cmd += ' --no_token_padding'
    run_cmd += (
        f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
    )
    run_cmd += f' --train_data_dir="{train_data_dir}"'
    if reg_data_dir:
        run_cmd += f' --reg_data_dir="{reg_data_dir}"'
    run_cmd += f' --resolution={max_resolution}'
    run_cmd += f' --output_dir="{output_dir}"'
    run_cmd += f' --logging_dir="{logging_dir}"'
    if stop_text_encoder_training != 0:
        run_cmd += (
            f' --stop_text_encoder_training={stop_text_encoder_training}'
        )
    if save_model_as != 'same as source model':
        run_cmd += f' --save_model_as={save_model_as}'
    if float(prior_loss_weight) != 1.0:
        run_cmd += f' --prior_loss_weight={prior_loss_weight}'
    if vae != '':
        run_cmd += f' --vae="{vae}"'
    if output_name != '':
        run_cmd += f' --output_name="{output_name}"'
    # NOTE(review): max_token_length, max_train_epochs and
    # max_data_loader_n_workers are also passed to run_cmd_advanced_training
    # below. If that helper emits them too, the flags appear twice
    # (argparse keeps the last one). Verify before removing either copy.
    if int(max_token_length) > 75:
        run_cmd += f' --max_token_length={max_token_length}'
    if max_train_epochs != '':
        run_cmd += f' --max_train_epochs="{max_train_epochs}"'
    if max_data_loader_n_workers != '':
        run_cmd += (
            f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
        )
    if int(gradient_accumulation_steps) > 1:
        run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'

    run_cmd += run_cmd_training(
        learning_rate=learning_rate,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        train_batch_size=train_batch_size,
        max_train_steps=max_train_steps,
        save_every_n_epochs=save_every_n_epochs,
        mixed_precision=mixed_precision,
        save_precision=save_precision,
        seed=seed,
        caption_extension=caption_extension,
        cache_latents=cache_latents,
        optimizer=optimizer,
        optimizer_args=optimizer_args,
    )

    run_cmd += run_cmd_advanced_training(
        max_train_epochs=max_train_epochs,
        max_data_loader_n_workers=max_data_loader_n_workers,
        max_token_length=max_token_length,
        resume=resume,
        save_state=save_state,
        mem_eff_attn=mem_eff_attn,
        clip_skip=clip_skip,
        flip_aug=flip_aug,
        color_aug=color_aug,
        shuffle_caption=shuffle_caption,
        gradient_checkpointing=gradient_checkpointing,
        full_fp16=full_fp16,
        xformers=xformers,
        keep_tokens=keep_tokens,
        persistent_data_loader_workers=persistent_data_loader_workers,
        bucket_no_upscale=bucket_no_upscale,
        random_crop=random_crop,
        bucket_reso_steps=bucket_reso_steps,
        caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
        caption_dropout_rate=caption_dropout_rate,
        noise_offset=noise_offset,
        additional_parameters=additional_parameters,
        vae_batch_size=vae_batch_size,
    )

    # Textual-inversion specific options.
    run_cmd += f' --token_string="{token_string}"'
    run_cmd += f' --init_word="{init_word}"'
    run_cmd += f' --num_vectors_per_token={num_vectors_per_token}'
    if weights != '':
        run_cmd += f' --weights="{weights}"'
    if template == 'object template':
        run_cmd += ' --use_object_template'
    elif template == 'style template':
        run_cmd += ' --use_style_template'

    run_cmd += run_cmd_sample(
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        output_dir,
    )

    print(run_cmd)

    # Run the command
    if os.name == 'posix':
        os.system(run_cmd)
    else:
        subprocess.run(run_cmd)

    # If output_dir/<output_name> is a folder, the output is a diffusers model.
    last_dir = pathlib.Path(f'{output_dir}/{output_name}')

    if not last_dir.is_dir():
        # Copy inference model for v2 if required
        save_inference_file(output_dir, v2, v_parameterization, output_name)
|
|
|
|
|
|
def ti_tab(
    train_data_dir=None,
    reg_data_dir=None,
    output_dir=None,
    logging_dir=None,
):
    """Build the textual-inversion training tab and wire up its callbacks.

    The four folder parameters are kept only for signature compatibility.
    They are immediately replaced by the textboxes created in the 'Folders'
    tab. They used to default to ``gr.Textbox()`` instances, which were
    created once at import time and never used.

    Returns:
        The ``(train_data_dir, reg_data_dir, output_dir, logging_dir)``
        textboxes, so other tabs can read and update them.
    """
    # Hidden labels passed as the first input of open/save_configuration.
    # Those functions compare the label's value with 'True' to decide whether
    # to show a file dialog or reuse the path in config_file_name.
    dummy_db_true = gr.Label(value=True, visible=False)
    dummy_db_false = gr.Label(value=False, visible=False)
    gr.Markdown('Train a TI using kohya textual inversion python code...')
    (
        button_open_config,
        button_save_config,
        button_save_as_config,
        config_file_name,
        button_load_config,
    ) = gradio_config()

    (
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        save_model_as,
        model_list,
    ) = gradio_source_model(
        save_model_as_choices=[
            'ckpt',
            'safetensors',
        ]
    )

    with gr.Tab('Folders'):
        with gr.Row():
            train_data_dir = gr.Textbox(
                label='Image folder',
                placeholder='Folder where the training folders containing the images are located',
            )
            train_data_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            train_data_dir_input_folder.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )
            reg_data_dir = gr.Textbox(
                label='Regularisation folder',
                placeholder='(Optional) Folder where the regularization folders containing the images are located',
            )
            reg_data_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            reg_data_dir_input_folder.click(
                get_folder_path,
                outputs=reg_data_dir,
                show_progress=False,
            )
        with gr.Row():
            output_dir = gr.Textbox(
                label='Model output folder',
                placeholder='Folder to output trained model',
            )
            output_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            output_dir_input_folder.click(
                get_folder_path,
                outputs=output_dir,
                show_progress=False,
            )
            logging_dir = gr.Textbox(
                label='Logging folder',
                placeholder='Optional: enable logging and output TensorBoard log to this folder',
            )
            logging_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            logging_dir_input_folder.click(
                get_folder_path,
                outputs=logging_dir,
                show_progress=False,
            )
        with gr.Row():
            output_name = gr.Textbox(
                label='Model output name',
                placeholder='Name of the model to output',
                value='last',
                interactive=True,
            )
        # Strip surrounding double quotes from pasted paths.
        train_data_dir.change(
            remove_doublequote,
            inputs=[train_data_dir],
            outputs=[train_data_dir],
        )
        reg_data_dir.change(
            remove_doublequote,
            inputs=[reg_data_dir],
            outputs=[reg_data_dir],
        )
        output_dir.change(
            remove_doublequote,
            inputs=[output_dir],
            outputs=[output_dir],
        )
        logging_dir.change(
            remove_doublequote,
            inputs=[logging_dir],
            outputs=[logging_dir],
        )
    with gr.Tab('Training parameters'):
        with gr.Row():
            weights = gr.Textbox(
                label='Resume TI training',
                placeholder='(Optional) Path to existing TI embedding file to keep training',
            )
            weights_file_input = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            weights_file_input.click(
                get_file_path,
                outputs=weights,
                show_progress=False,
            )
        with gr.Row():
            token_string = gr.Textbox(
                label='Token string',
                placeholder='eg: cat',
            )
            init_word = gr.Textbox(
                label='Init word',
                value='*',
            )
            num_vectors_per_token = gr.Slider(
                minimum=1,
                maximum=75,
                value=1,
                step=1,
                label='Vectors',
            )
            max_train_steps = gr.Textbox(
                label='Max train steps',
                placeholder='(Optional) Maximum number of steps',
            )
            template = gr.Dropdown(
                label='Template',
                choices=[
                    'caption',
                    'object template',
                    'style template',
                ],
                value='caption',
            )
        (
            learning_rate,
            lr_scheduler,
            lr_warmup,
            train_batch_size,
            epoch,
            save_every_n_epochs,
            mixed_precision,
            save_precision,
            num_cpu_threads_per_process,
            seed,
            caption_extension,
            cache_latents,
            optimizer,
            optimizer_args,
        ) = gradio_training(
            learning_rate_value='1e-5',
            lr_scheduler_value='cosine',
            lr_warmup_value='10',
        )
        with gr.Row():
            max_resolution = gr.Textbox(
                label='Max resolution',
                value='512,512',
                placeholder='512,512',
            )
            # train_model treats this value as a percentage of max_train_steps.
            stop_text_encoder_training = gr.Slider(
                minimum=0,
                maximum=100,
                value=0,
                step=1,
                label='Stop text encoder training',
            )
            enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
        with gr.Accordion('Advanced Configuration', open=False):
            with gr.Row():
                no_token_padding = gr.Checkbox(
                    label='No token padding', value=False
                )
                gradient_accumulation_steps = gr.Number(
                    label='Gradient accumulate steps', value='1'
                )
            with gr.Row():
                prior_loss_weight = gr.Number(
                    label='Prior loss weight', value=1.0
                )
                vae = gr.Textbox(
                    label='VAE',
                    placeholder='(Optional) path to checkpoint of vae to replace for training',
                )
                vae_button = gr.Button(
                    folder_symbol, elem_id='open_folder_small'
                )
                vae_button.click(
                    get_any_file_path,
                    outputs=vae,
                    show_progress=False,
                )
            (
                xformers,
                full_fp16,
                gradient_checkpointing,
                shuffle_caption,
                color_aug,
                flip_aug,
                clip_skip,
                mem_eff_attn,
                save_state,
                resume,
                max_token_length,
                max_train_epochs,
                max_data_loader_n_workers,
                keep_tokens,
                persistent_data_loader_workers,
                bucket_no_upscale,
                random_crop,
                bucket_reso_steps,
                caption_dropout_every_n_epochs,
                caption_dropout_rate,
                noise_offset,
                additional_parameters,
                vae_batch_size,
            ) = gradio_advanced_training()
            # Toggling color augmentation updates the cache_latents checkbox.
            color_aug.change(
                color_aug_changed,
                inputs=[color_aug],
                outputs=[cache_latents],
            )

        (
            sample_every_n_steps,
            sample_every_n_epochs,
            sample_sampler,
            sample_prompts,
        ) = sample_gradio_config()

    with gr.Tab('Tools'):
        gr.Markdown(
            'This section provides Dreambooth tools to help set up your dataset...'
        )
        gradio_dreambooth_folder_creation_tab(
            train_data_dir_input=train_data_dir,
            reg_data_dir_input=reg_data_dir,
            output_dir_input=output_dir,
            logging_dir_input=logging_dir,
        )

    button_run = gr.Button('Train model', variant='primary')

    # Setup gradio tensorboard buttons
    button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()

    button_start_tensorboard.click(
        start_tensorboard,
        inputs=logging_dir,
        show_progress=False,
    )

    button_stop_tensorboard.click(
        stop_tensorboard,
        show_progress=False,
    )

    # Order must match the settings parameters of open_configuration,
    # save_configuration and train_model.
    settings_list = [
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        logging_dir,
        train_data_dir,
        reg_data_dir,
        output_dir,
        max_resolution,
        learning_rate,
        lr_scheduler,
        lr_warmup,
        train_batch_size,
        epoch,
        save_every_n_epochs,
        mixed_precision,
        save_precision,
        seed,
        num_cpu_threads_per_process,
        cache_latents,
        caption_extension,
        enable_bucket,
        gradient_checkpointing,
        full_fp16,
        no_token_padding,
        stop_text_encoder_training,
        xformers,
        save_model_as,
        shuffle_caption,
        save_state,
        resume,
        prior_loss_weight,
        color_aug,
        flip_aug,
        clip_skip,
        vae,
        output_name,
        max_token_length,
        max_train_epochs,
        max_data_loader_n_workers,
        mem_eff_attn,
        gradient_accumulation_steps,
        model_list,
        token_string,
        init_word,
        num_vectors_per_token,
        max_train_steps,
        weights,
        template,
        keep_tokens,
        persistent_data_loader_workers,
        bucket_no_upscale,
        random_crop,
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        optimizer,
        optimizer_args,
        noise_offset,
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
        vae_batch_size,
    ]

    button_open_config.click(
        open_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_load_config.click(
        open_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_save_config.click(
        save_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )

    button_save_as_config.click(
        save_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )

    button_run.click(
        train_model,
        inputs=settings_list,
        show_progress=False,
    )

    return (
        train_data_dir,
        reg_data_dir,
        output_dir,
        logging_dir,
    )
|
|
|
|
|
|
def UI(**kwargs):
    """Build the standalone textual-inversion GUI and launch the Gradio server.

    Keyword Args:
        username: enables authentication when non-empty.
        password: paired with ``username`` for authentication.
        server_port: port to listen on when > 0; otherwise Gradio picks one.
        inbrowser: when truthy, passed through to ``launch`` to open a browser.
    """
    css = ''

    # Optional custom stylesheet next to the working directory.
    if os.path.exists('./style.css'):
        with open('./style.css', 'r', encoding='utf8') as file:
            print('Load CSS...')
            css += file.read() + '\n'

    interface = gr.Blocks(css=css)

    with interface:
        with gr.Tab('Dreambooth TI'):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = ti_tab()
        with gr.Tab('Utilities'):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
            )

    # Show the interface
    launch_kwargs = {}
    username = kwargs.get('username')
    # A missing username previously produced auth=(None, None); require a
    # non-empty value before enabling authentication.
    if username:
        launch_kwargs['auth'] = (username, kwargs.get('password'))
    server_port = kwargs.get('server_port') or 0
    if server_port > 0:
        launch_kwargs['server_port'] = server_port
    if kwargs.get('inbrowser', False):
        launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)

    # Log the launch options without echoing the password to the console.
    printable_kwargs = dict(launch_kwargs)
    if 'auth' in printable_kwargs:
        printable_kwargs['auth'] = (printable_kwargs['auth'][0], '********')
    print(printable_kwargs)

    interface.launch(**launch_kwargs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the launch options and start the GUI.
    cli = argparse.ArgumentParser()
    cli.add_argument('--username', type=str, default='', help='Username for authentication')
    cli.add_argument('--password', type=str, default='', help='Password for authentication')
    cli.add_argument('--server_port', type=int, default=0, help='Port to run the server listener on')
    cli.add_argument('--inbrowser', action='store_true', help='Open in browser')

    # The namespace holds exactly username/password/server_port/inbrowser,
    # which are forwarded to UI() as keyword arguments.
    UI(**vars(cli.parse_args()))
|