7b5639cff5
This is a massive WIP and should not be trusted or used right now. However, major milestones have been crossed: both message boxes and file dialogs are now properly subprocessed and work on macOS. By extension, I think it may also work in RunPod environments, but that remains to be tested.
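For context on the "subprocessed" dialogs: the general approach is to run each native dialog in a short-lived child Python process instead of inside the Gradio server process, since Tk/Cocoa dialogs opened outside the main thread are what tend to crash on macOS. The sketch below only illustrates that idea and is not the implementation in this file; the helper name show_message_box_subprocess and the tkinter child script are illustrative placeholders (the real helpers used below, show_message_box and the get_*_path functions, come from library.common_gui).

# Illustrative sketch only: run a tkinter message box in a child process so the
# Tk event loop never runs inside the Gradio server process.
import subprocess
import sys

_CHILD_SCRIPT = """
import sys
import tkinter as tk
from tkinter import messagebox

root = tk.Tk()
root.withdraw()  # hide the empty root window
messagebox.showinfo('Notice', sys.argv[1])
"""


def show_message_box_subprocess(message: str) -> None:
    # Spawn a fresh interpreter, let it own the Tk main loop, and block until
    # the user dismisses the dialog.
    subprocess.run([sys.executable, '-c', _CHILD_SCRIPT, message], check=False)

The same pattern can be applied to file dialogs (e.g. tkinter.filedialog.askopenfilename in the child, printing the chosen path to stdout and reading it back with capture_output=True).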
927 lines
27 KiB
Python
# v1: initial release
# v2: add open and save folder icons
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captioning of images to utilities

import argparse
import json
import math
import os
import pathlib
import subprocess

import gradio as gr

from library.common_gui import (
    get_folder_path,
    remove_doublequote,
    get_file_path,
    get_any_file_path,
    get_saveasfile_path,
    color_aug_changed,
    save_inference_file,
    gradio_advanced_training,
    run_cmd_advanced_training,
    run_cmd_training,
    gradio_training,
    gradio_config,
    gradio_source_model,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
    is_valid_config,
    show_message_box,
)
from library.dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
)
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.tensorboard_gui import (
    gradio_tensorboard,
    start_tensorboard,
    stop_tensorboard,
)
from library.utilities import utilities_tab

folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'  # 📄

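# Save the current GUI settings to a JSON configuration file. 'Save as...' (or a
# missing path) opens a save-file dialog; returns the path written, or the
# original path if the user cancels.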
def save_configuration(
    save_as,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    # use_8bit_adam,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    color_aug,
    flip_aug,
    clip_skip,
    vae,
    output_name,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    mem_eff_attn,
    gradient_accumulation_steps,
    model_list,
    keep_tokens,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
):
    # Get list of function parameters and values
    parameters = list(locals().items())

    original_file_path = file_path

    save_as_bool = True if save_as.get('label') == 'True' else False

    if save_as_bool:
        print('Save as...')
        file_path = get_saveasfile_path(file_path)
    else:
        print('Save...')
        if file_path is None or file_path == '':
            file_path = get_saveasfile_path(file_path)

    # print(file_path)

    if file_path is None or file_path == '':
        return original_file_path  # In case a file_path was provided and the user decided to cancel the save action

    # Return the values of the variables as a dictionary
    variables = {
        name: value
        for name, value in parameters  # locals().items()
        if name
        not in [
            'file_path',
            'save_as',
        ]
    }

    # Extract the destination directory from the file path
    destination_directory = os.path.dirname(file_path)

    # Create the destination directory if it doesn't exist
    if not os.path.exists(destination_directory):
        os.makedirs(destination_directory)

    # Save the data to the selected file
    with open(file_path, 'w') as file:
        json.dump(variables, file, indent=2)

    return file_path

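# Load settings from a JSON configuration file (optionally via a file dialog)
# and return them as a tuple in the same order as the GUI components, falling
# back to the current GUI values for any missing keys.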
def open_configuration(
    ask_for_file,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    # use_8bit_adam,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    color_aug,
    flip_aug,
    clip_skip,
    vae,
    output_name,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    mem_eff_attn,
    gradient_accumulation_steps,
    model_list,
    keep_tokens,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
):
    # Get list of function parameters and values
    parameters = list(locals().items())

    ask_for_file = True if ask_for_file.get('label') == 'True' else False

    original_file_path = file_path

    if ask_for_file:
        file_path = get_file_path(file_path, filedialog_type='json')

    if not file_path == '' and file_path is not None:
        with open(file_path, 'r') as f:
            my_data = json.load(f)
        if is_valid_config(my_data):
            print('Loading config...')
            my_data = update_my_data(my_data)
        else:
            print('Invalid configuration file.')
            my_data = {}
            show_message_box('Invalid configuration file.')
    else:
        file_path = original_file_path
        my_data = {}

    values = [file_path]
    for key, value in parameters:
        # Set the value to the corresponding value in `my_data`, or the default value if not found
        if key not in ['ask_for_file', 'file_path']:
            values.append(my_data.get(key, value))
    return tuple(values)

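# Validate the folder inputs, derive the step count from the '<repeats>_<name>'
# training subfolders, then build and run the accelerate / train_db.py command line.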
def train_model(
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training_pct,
    # use_8bit_adam,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    color_aug,
    flip_aug,
    clip_skip,
    vae,
    output_name,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    mem_eff_attn,
    gradient_accumulation_steps,
    model_list,  # Keep this. Yes, it is unused here but required given the common list used
    keep_tokens,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
):
    if pretrained_model_name_or_path == '':
        show_message_box('Source model information is missing')
        return

    if train_data_dir == '':
        show_message_box('Image folder path is missing')
        return

    if not os.path.exists(train_data_dir):
        show_message_box('Image folder does not exist')
        return

    if reg_data_dir != '':
        if not os.path.exists(reg_data_dir):
            show_message_box('Regularisation folder does not exist')
            return

    if output_dir == '':
        show_message_box('Output folder path is missing')
        return

    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # Get a list of all subfolders in train_data_dir, excluding hidden folders
    subfolders = [
        f
        for f in os.listdir(train_data_dir)
        if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.')
    ]

    # Check if subfolders are present. If not, let the user know and return
    if not subfolders:
        print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train...\033[0m')
        return

    total_steps = 0

    # Loop through each subfolder and extract the number of repeats
    for folder in subfolders:
        # Extract the number of repeats from the folder name
        try:
            repeats = int(folder.split('_')[0])
        except ValueError:
            print(
                '\033[33mSubfolder', folder,
                'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m',
            )
            continue

        # Count the number of images in the folder
        num_images = len(
            [
                f
                for f in os.listdir(os.path.join(train_data_dir, folder))
                if f.endswith('.jpg')
                or f.endswith('.jpeg')
                or f.endswith('.png')
                or f.endswith('.webp')
            ]
        )

        if num_images == 0:
            print(f'{folder} folder contains no images, skipping...')
        else:
            # Calculate the total number of steps for this folder
            steps = repeats * num_images
            total_steps += steps

            # Print the result
            print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')

    if total_steps == 0:
        print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m')
        return

    # Print the result
    # print(f"{total_steps} total steps")

    if reg_data_dir == '':
        reg_factor = 1
    else:
        print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m')
        reg_factor = 2

    # calculate max_train_steps
    max_train_steps = int(
        math.ceil(
            float(total_steps)
            / int(train_batch_size)
            * int(epoch)
            * int(reg_factor)
        )
    )
    print(f'max_train_steps = {max_train_steps}')

    # calculate stop encoder training
    # Note: check for None before converting to int to avoid a TypeError
    if stop_text_encoder_training_pct is None:
        stop_text_encoder_training = 0
    elif int(stop_text_encoder_training_pct) == -1:
        stop_text_encoder_training = -1
    else:
        stop_text_encoder_training = math.ceil(
            float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
        )
    print(f'stop_text_encoder_training = {stop_text_encoder_training}')

    lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
    print(f'lr_warmup_steps = {lr_warmup_steps}')

    run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
    if v2:
        run_cmd += ' --v2'
    if v_parameterization:
        run_cmd += ' --v_parameterization'
    if enable_bucket:
        run_cmd += ' --enable_bucket'
    if no_token_padding:
        run_cmd += ' --no_token_padding'
    run_cmd += (
        f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
    )
    run_cmd += f' --train_data_dir="{train_data_dir}"'
    if len(reg_data_dir):
        run_cmd += f' --reg_data_dir="{reg_data_dir}"'
    run_cmd += f' --resolution={max_resolution}'
    run_cmd += f' --output_dir="{output_dir}"'
    run_cmd += f' --logging_dir="{logging_dir}"'
    if not stop_text_encoder_training == 0:
        run_cmd += (
            f' --stop_text_encoder_training={stop_text_encoder_training}'
        )
    if not save_model_as == 'same as source model':
        run_cmd += f' --save_model_as={save_model_as}'
    # if not resume == '':
    #     run_cmd += f' --resume={resume}'
    if not float(prior_loss_weight) == 1.0:
        run_cmd += f' --prior_loss_weight={prior_loss_weight}'
    if not vae == '':
        run_cmd += f' --vae="{vae}"'
    if not output_name == '':
        run_cmd += f' --output_name="{output_name}"'
    if int(max_token_length) > 75:
        run_cmd += f' --max_token_length={max_token_length}'
    if not max_train_epochs == '':
        run_cmd += f' --max_train_epochs="{max_train_epochs}"'
    if not max_data_loader_n_workers == '':
        run_cmd += (
            f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
        )
    if int(gradient_accumulation_steps) > 1:
        run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'

    run_cmd += run_cmd_training(
        learning_rate=learning_rate,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        train_batch_size=train_batch_size,
        max_train_steps=max_train_steps,
        save_every_n_epochs=save_every_n_epochs,
        mixed_precision=mixed_precision,
        save_precision=save_precision,
        seed=seed,
        caption_extension=caption_extension,
        cache_latents=cache_latents,
        optimizer=optimizer,
        optimizer_args=optimizer_args,
    )

    run_cmd += run_cmd_advanced_training(
        max_train_epochs=max_train_epochs,
        max_data_loader_n_workers=max_data_loader_n_workers,
        max_token_length=max_token_length,
        resume=resume,
        save_state=save_state,
        mem_eff_attn=mem_eff_attn,
        clip_skip=clip_skip,
        flip_aug=flip_aug,
        color_aug=color_aug,
        shuffle_caption=shuffle_caption,
        gradient_checkpointing=gradient_checkpointing,
        full_fp16=full_fp16,
        xformers=xformers,
        # use_8bit_adam=use_8bit_adam,
        keep_tokens=keep_tokens,
        persistent_data_loader_workers=persistent_data_loader_workers,
        bucket_no_upscale=bucket_no_upscale,
        random_crop=random_crop,
        bucket_reso_steps=bucket_reso_steps,
        caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
        caption_dropout_rate=caption_dropout_rate,
        noise_offset=noise_offset,
        additional_parameters=additional_parameters,
        vae_batch_size=vae_batch_size,
    )

    run_cmd += run_cmd_sample(
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        output_dir,
    )

    print(run_cmd)

    # Run the command
    if os.name == 'posix':
        os.system(run_cmd)
    else:
        subprocess.run(run_cmd)

    # check if output_dir/last is a folder... therefore it is a diffusers model
    last_dir = pathlib.Path(f'{output_dir}/{output_name}')

    if not last_dir.is_dir():
        # Copy inference model for v2 if required
        save_inference_file(output_dir, v2, v_parameterization, output_name)

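# Build the Dreambooth tab: folder pickers, training parameters, dataset tools,
# and the buttons that wire up config load/save and training.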
def dreambooth_tab(
    train_data_dir=gr.Textbox(),
    reg_data_dir=gr.Textbox(),
    output_dir=gr.Textbox(),
    logging_dir=gr.Textbox(),
):
    dummy_db_true = gr.Label(value=True, visible=False)
    dummy_db_false = gr.Label(value=False, visible=False)
    gr.Markdown('Train a custom model using the kohya Dreambooth Python code...')
    (
        button_open_config,
        button_save_config,
        button_save_as_config,
        config_file_name,
        button_load_config,
    ) = gradio_config()

    (
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        save_model_as,
        model_list,
    ) = gradio_source_model()

    with gr.Tab('Folders'):
        with gr.Row():
            train_data_dir = gr.Textbox(
                label='Image folder',
                placeholder='Folder where the training folders containing the images are located',
            )
            train_data_dir_input_folder = gr.Button(
                '📂', elem_id='open_folder_small'
            )
            train_data_dir_input_folder.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )
            reg_data_dir = gr.Textbox(
                label='Regularisation folder',
                placeholder='(Optional) Folder where the regularization folders containing the images are located',
            )
            reg_data_dir_input_folder = gr.Button(
                '📂', elem_id='open_folder_small'
            )
            reg_data_dir_input_folder.click(
                get_folder_path,
                outputs=reg_data_dir,
                show_progress=False,
            )
        with gr.Row():
            output_dir = gr.Textbox(
                label='Model output folder',
                placeholder='Folder to output trained model',
            )
            output_dir_input_folder = gr.Button(
                '📂', elem_id='open_folder_small'
            )
            output_dir_input_folder.click(get_folder_path, outputs=output_dir)
            logging_dir = gr.Textbox(
                label='Logging folder',
                placeholder='Optional: enable logging and output TensorBoard log to this folder',
            )
            logging_dir_input_folder = gr.Button(
                '📂', elem_id='open_folder_small'
            )
            logging_dir_input_folder.click(
                get_folder_path,
                outputs=logging_dir,
                show_progress=False,
            )
        with gr.Row():
            output_name = gr.Textbox(
                label='Model output name',
                placeholder='Name of the model to output',
                value='last',
                interactive=True,
            )
        train_data_dir.change(
            remove_doublequote,
            inputs=[train_data_dir],
            outputs=[train_data_dir],
        )
        reg_data_dir.change(
            remove_doublequote,
            inputs=[reg_data_dir],
            outputs=[reg_data_dir],
        )
        output_dir.change(
            remove_doublequote,
            inputs=[output_dir],
            outputs=[output_dir],
        )
        logging_dir.change(
            remove_doublequote,
            inputs=[logging_dir],
            outputs=[logging_dir],
        )
    with gr.Tab('Training parameters'):
        (
            learning_rate,
            lr_scheduler,
            lr_warmup,
            train_batch_size,
            epoch,
            save_every_n_epochs,
            mixed_precision,
            save_precision,
            num_cpu_threads_per_process,
            seed,
            caption_extension,
            cache_latents,
            optimizer,
            optimizer_args,
        ) = gradio_training(
            learning_rate_value='1e-5',
            lr_scheduler_value='cosine',
            lr_warmup_value='10',
        )
        with gr.Row():
            max_resolution = gr.Textbox(
                label='Max resolution',
                value='512,512',
                placeholder='512,512',
            )
            stop_text_encoder_training = gr.Slider(
                minimum=-1,
                maximum=100,
                value=0,
                step=1,
                label='Stop text encoder training',
            )
            enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
        with gr.Accordion('Advanced Configuration', open=False):
            with gr.Row():
                no_token_padding = gr.Checkbox(
                    label='No token padding', value=False
                )
                gradient_accumulation_steps = gr.Number(
                    label='Gradient accumulation steps', value='1'
                )
            with gr.Row():
                prior_loss_weight = gr.Number(
                    label='Prior loss weight', value=1.0
                )
                vae = gr.Textbox(
                    label='VAE',
                    placeholder='(Optional) path to a VAE checkpoint to use in place of the model VAE for training',
                )
                vae_button = gr.Button('📂', elem_id='open_folder_small')
                vae_button.click(
                    get_any_file_path,
                    outputs=vae,
                    show_progress=False,
                )
            (
                # use_8bit_adam,
                xformers,
                full_fp16,
                gradient_checkpointing,
                shuffle_caption,
                color_aug,
                flip_aug,
                clip_skip,
                mem_eff_attn,
                save_state,
                resume,
                max_token_length,
                max_train_epochs,
                max_data_loader_n_workers,
                keep_tokens,
                persistent_data_loader_workers,
                bucket_no_upscale,
                random_crop,
                bucket_reso_steps,
                caption_dropout_every_n_epochs,
                caption_dropout_rate,
                noise_offset,
                additional_parameters,
                vae_batch_size,
            ) = gradio_advanced_training()
            color_aug.change(
                color_aug_changed,
                inputs=[color_aug],
                outputs=[cache_latents],
            )

        (
            sample_every_n_steps,
            sample_every_n_epochs,
            sample_sampler,
            sample_prompts,
        ) = sample_gradio_config()

    with gr.Tab('Tools'):
        gr.Markdown(
            'This section provides Dreambooth tools to help set up your dataset...'
        )
        gradio_dreambooth_folder_creation_tab(
            train_data_dir_input=train_data_dir,
            reg_data_dir_input=reg_data_dir,
            output_dir_input=output_dir,
            logging_dir_input=logging_dir,
        )

    button_run = gr.Button('Train model', variant='primary')

    # Setup gradio tensorboard buttons
    button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()

    button_start_tensorboard.click(
        start_tensorboard,
        inputs=logging_dir,
        show_progress=False,
    )

    button_stop_tensorboard.click(
        stop_tensorboard,
        show_progress=False,
    )

    settings_list = [
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        logging_dir,
        train_data_dir,
        reg_data_dir,
        output_dir,
        max_resolution,
        learning_rate,
        lr_scheduler,
        lr_warmup,
        train_batch_size,
        epoch,
        save_every_n_epochs,
        mixed_precision,
        save_precision,
        seed,
        num_cpu_threads_per_process,
        cache_latents,
        caption_extension,
        enable_bucket,
        gradient_checkpointing,
        full_fp16,
        no_token_padding,
        stop_text_encoder_training,
        # use_8bit_adam,
        xformers,
        save_model_as,
        shuffle_caption,
        save_state,
        resume,
        prior_loss_weight,
        color_aug,
        flip_aug,
        clip_skip,
        vae,
        output_name,
        max_token_length,
        max_train_epochs,
        max_data_loader_n_workers,
        mem_eff_attn,
        gradient_accumulation_steps,
        model_list,
        keep_tokens,
        persistent_data_loader_workers,
        bucket_no_upscale,
        random_crop,
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        optimizer,
        optimizer_args,
        noise_offset,
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
        vae_batch_size,
    ]

    button_open_config.click(
        open_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_load_config.click(
        open_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_save_config.click(
        save_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )

    button_save_as_config.click(
        save_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )

    button_run.click(
        train_model,
        inputs=settings_list,
        show_progress=False,
    )

    return (
        train_data_dir,
        reg_data_dir,
        output_dir,
        logging_dir,
    )

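# Assemble the top-level Gradio Blocks interface (Dreambooth + Utilities tabs)
# and launch it with optional auth, port, and browser options.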
def UI(**kwargs):
    css = ''

    if os.path.exists('./style.css'):
        with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
            print('Load CSS...')
            css += file.read() + '\n'

    interface = gr.Blocks(css=css)

    with interface:
        with gr.Tab('Dreambooth'):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab()
        with gr.Tab('Utilities'):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
            )

    # Show the interface
    launch_kwargs = {}
    if not kwargs.get('username', None) == '':
        launch_kwargs['auth'] = (
            kwargs.get('username', None),
            kwargs.get('password', None),
        )
    if kwargs.get('server_port', 0) > 0:
        launch_kwargs['server_port'] = kwargs.get('server_port', 0)
    if kwargs.get('inbrowser', False):
        launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
    print(launch_kwargs)
    interface.launch(**launch_kwargs)


if __name__ == '__main__':
    # torch.cuda.set_per_process_memory_fraction(0.48)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )

    args = parser.parse_args()

    UI(
        username=args.username,
        password=args.password,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
    )