2022-12-13 22:52:25 +00:00
# v1: initial release
2022-12-14 19:40:24 +00:00
# v2: add open and save folder icons
2022-12-15 23:19:35 +00:00
# v3: Add new Utilities tab for Dreambooth folder preparation
2022-12-13 22:52:25 +00:00
2022-12-13 14:20:25 +00:00
import gradio as gr
import json
import math
import os
import subprocess
import pathlib
import shutil
from glob import glob
from os . path import join
2022-12-14 19:40:24 +00:00
from easygui import fileopenbox , filesavebox , enterbox , diropenbox , msgbox
2022-12-13 14:20:25 +00:00
2022-12-15 12:48:29 +00:00
def save_configuration(
    save_as,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    convert_to_safetensors,
    convert_to_ckpt,
    cache_latent,
    caption_extention,
    use_safetensors,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    use_8bit_adam,
    xformers,
):
    """Write the current GUI settings to a JSON configuration file.

    Args:
        save_as: Payload of a hidden ``gr.Label``; when its ``"label"`` entry
            is the string ``"True"`` a "Save as..." dialog is always shown.
            ``None`` is treated as a plain "Save".
        file_path: Currently selected configuration path (may be empty/None,
            in which case a save dialog is shown).
        All remaining arguments are training settings, stored verbatim under
        JSON keys of the same name.

    Returns:
        The path the configuration was written to, or the original
        ``file_path`` when the user cancelled the save dialog.
    """
    original_file_path = file_path

    save_as_bool = save_as is not None and save_as.get("label") == "True"
    print("Save as..." if save_as_bool else "Save...")

    # A dialog is needed for "Save as..." or when no destination is known yet.
    if save_as_bool or not file_path:
        file_path = filesavebox(
            "Select the config file to save",
            default="finetune.json",
            filetypes="*.json",
        )

    if not file_path:
        # The user cancelled the save dialog: keep the previously shown path.
        return original_file_path

    # Settings to persist, keyed by parameter name.
    variables = {
        "pretrained_model_name_or_path": pretrained_model_name_or_path,
        "v2": v2,
        "v_parameterization": v_parameterization,
        "logging_dir": logging_dir,
        "train_data_dir": train_data_dir,
        "reg_data_dir": reg_data_dir,
        "output_dir": output_dir,
        "max_resolution": max_resolution,
        "learning_rate": learning_rate,
        "lr_scheduler": lr_scheduler,
        "lr_warmup": lr_warmup,
        "train_batch_size": train_batch_size,
        "epoch": epoch,
        "save_every_n_epochs": save_every_n_epochs,
        "mixed_precision": mixed_precision,
        "save_precision": save_precision,
        "seed": seed,
        "num_cpu_threads_per_process": num_cpu_threads_per_process,
        "convert_to_safetensors": convert_to_safetensors,
        "convert_to_ckpt": convert_to_ckpt,
        "cache_latent": cache_latent,
        "caption_extention": caption_extention,
        "use_safetensors": use_safetensors,
        "enable_bucket": enable_bucket,
        "gradient_checkpointing": gradient_checkpointing,
        "full_fp16": full_fp16,
        "no_token_padding": no_token_padding,
        "stop_text_encoder_training": stop_text_encoder_training,
        "use_8bit_adam": use_8bit_adam,
        "xformers": xformers,
    }

    # Save the data to the selected file
    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(variables, file)

    return file_path
2022-12-15 12:48:29 +00:00
def open_configuration(
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    convert_to_safetensors,
    convert_to_ckpt,
    cache_latent,
    caption_extention,
    use_safetensors,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    use_8bit_adam,
    xformers,
):
    """Load GUI settings from a JSON configuration file picked in a dialog.

    The dialog is opened through ``get_file_path`` seeded with ``file_path``.
    Any setting missing from the file keeps the value passed in. If the
    dialog is cancelled, or the file cannot be read or is not a JSON object,
    the previous path and every current value are kept.

    Returns:
        A tuple: the effective configuration path followed by the 30
        settings, in the same order as the parameters.
    """
    original_file_path = file_path
    file_path = get_file_path(file_path)

    my_data = None
    if file_path:
        print(file_path)
        # load variables from JSON file
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            print(f"Unable to load configuration file {file_path}: {err}")
        else:
            if isinstance(loaded, dict):
                my_data = loaded
            else:
                print(
                    f"Configuration file {file_path} does not contain a JSON object... ignoring it"
                )

    if my_data is None:
        # Dialog cancelled or file unusable: keep the previously shown path
        # and leave every setting untouched.
        file_path = original_file_path
        my_data = {}

    # Return the (possibly updated) values in the order of the Gradio outputs
    return (
        file_path,
        my_data.get("pretrained_model_name_or_path", pretrained_model_name_or_path),
        my_data.get("v2", v2),
        my_data.get("v_parameterization", v_parameterization),
        my_data.get("logging_dir", logging_dir),
        my_data.get("train_data_dir", train_data_dir),
        my_data.get("reg_data_dir", reg_data_dir),
        my_data.get("output_dir", output_dir),
        my_data.get("max_resolution", max_resolution),
        my_data.get("learning_rate", learning_rate),
        my_data.get("lr_scheduler", lr_scheduler),
        my_data.get("lr_warmup", lr_warmup),
        my_data.get("train_batch_size", train_batch_size),
        my_data.get("epoch", epoch),
        my_data.get("save_every_n_epochs", save_every_n_epochs),
        my_data.get("mixed_precision", mixed_precision),
        my_data.get("save_precision", save_precision),
        my_data.get("seed", seed),
        my_data.get("num_cpu_threads_per_process", num_cpu_threads_per_process),
        my_data.get("convert_to_safetensors", convert_to_safetensors),
        my_data.get("convert_to_ckpt", convert_to_ckpt),
        my_data.get("cache_latent", cache_latent),
        my_data.get("caption_extention", caption_extention),
        my_data.get("use_safetensors", use_safetensors),
        my_data.get("enable_bucket", enable_bucket),
        my_data.get("gradient_checkpointing", gradient_checkpointing),
        my_data.get("full_fp16", full_fp16),
        my_data.get("no_token_padding", no_token_padding),
        my_data.get("stop_text_encoder_training", stop_text_encoder_training),
        my_data.get("use_8bit_adam", use_8bit_adam),
        my_data.get("xformers", xformers),
    )
def train_model(
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    convert_to_safetensors,
    convert_to_ckpt,
    cache_latent,
    caption_extention,
    use_safetensors,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training_pct,
    use_8bit_adam,
    xformers,
):
    """Validate the GUI inputs, launch ``train_db_fixed.py`` and post-process.

    Steps:
      1. Abort with a message box when a required path is missing.
      2. Compute the total number of steps from the ``<repeats>_<name>``
         subfolders of ``train_data_dir``:
         ``ceil(total / batch_size * epochs * reg_factor)``, where
         ``reg_factor`` is 2 when a regularisation folder is given.
      3. Derive the text-encoder stop step and LR warmup steps from the
         percentage inputs.
      4. Run the training through ``accelerate launch``.
      5. If ``<output_dir>/last`` is a directory, optionally convert it to
         ``.ckpt`` / ``.safetensors``, and copy the matching v2 inference YAML.
    """

    def save_inference_file(output_dir, v2, v_parameterization):
        # Copy the v2 inference config next to the output; v1 needs nothing.
        if v2 and v_parameterization:
            print(f"Saving v2-inference-v.yaml as {output_dir}/last.yaml")
            shutil.copy(
                "./v2_inference/v2-inference-v.yaml",
                f"{output_dir}/last.yaml",
            )
        elif v2:
            print(f"Saving v2-inference.yaml as {output_dir}/last.yaml")
            shutil.copy(
                "./v2_inference/v2-inference.yaml",
                f"{output_dir}/last.yaml",
            )

    if pretrained_model_name_or_path == "":
        msgbox("Source model information is missing")
        return

    if train_data_dir == "":
        msgbox("Image folder path is missing")
        return

    if not os.path.exists(train_data_dir):
        msgbox("Image folder does not exist")
        return

    if reg_data_dir != "" and not os.path.exists(reg_data_dir):
        msgbox("Regularisation folder does not exist")
        return

    if output_dir == "":
        msgbox("Output folder path is missing")
        return

    image_extensions = (".jpg", ".jpeg", ".png", ".webp")

    # Get a list of all subfolders in train_data_dir
    subfolders = [
        f
        for f in os.listdir(train_data_dir)
        if os.path.isdir(os.path.join(train_data_dir, f))
    ]

    total_steps = 0

    # Each subfolder is expected to be named "<repeats>_<prompt>".
    for folder in subfolders:
        try:
            repeats = int(folder.split("_")[0])
        except ValueError:
            # Skip instead of letting an unrelated folder crash the callback.
            print(f"Folder {folder} does not start with '<repeats>_'... skipping it")
            continue

        folder_path = os.path.join(train_data_dir, folder)
        num_images = sum(
            1 for f in os.listdir(folder_path) if f.endswith(image_extensions)
        )

        steps = repeats * num_images
        total_steps += steps
        print(f"Folder {folder}: {steps} steps")

    if reg_data_dir == "":
        reg_factor = 1
    else:
        print(
            "Regularisation images are used... Will double the number of steps required..."
        )
        reg_factor = 2

    # calculate max_train_steps
    max_train_steps = int(
        math.ceil(
            float(total_steps) / int(train_batch_size) * int(epoch) * int(reg_factor)
        )
    )
    print(f"max_train_steps = {max_train_steps}")

    # calculate the step at which text encoder training stops (percentage input)
    if stop_text_encoder_training_pct is None:
        stop_text_encoder_training = 0
    else:
        stop_text_encoder_training = math.ceil(
            float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
        )
    print(f"stop_text_encoder_training = {stop_text_encoder_training}")

    lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
    print(f"lr_warmup_steps = {lr_warmup_steps}")

    run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db_fixed.py"'

    # Boolean switches are only passed when the matching checkbox is ticked
    # (previously --use_8bit_adam and --xformers were always appended).
    switches = (
        (v2, "--v2"),
        (v_parameterization, "--v_parameterization"),
        (cache_latent, "--cache_latents"),
        (use_safetensors, "--use_safetensors"),
        (enable_bucket, "--enable_bucket"),
        (gradient_checkpointing, "--gradient_checkpointing"),
        (full_fp16, "--full_fp16"),
        (no_token_padding, "--no_token_padding"),
        (use_8bit_adam, "--use_8bit_adam"),
        (xformers, "--xformers"),
    )
    for enabled, switch in switches:
        if enabled:
            run_cmd += f" {switch}"

    run_cmd += f" --pretrained_model_name_or_path={pretrained_model_name_or_path}"
    # Paths are quoted so that directories containing spaces survive.
    run_cmd += f' --train_data_dir="{train_data_dir}"'
    if reg_data_dir:
        run_cmd += f' --reg_data_dir="{reg_data_dir}"'
    run_cmd += f" --resolution={max_resolution}"
    run_cmd += f' --output_dir="{output_dir}"'
    run_cmd += f" --train_batch_size={train_batch_size}"
    run_cmd += f" --learning_rate={learning_rate}"
    run_cmd += f" --lr_scheduler={lr_scheduler}"
    run_cmd += f" --lr_warmup_steps={lr_warmup_steps}"
    run_cmd += f" --max_train_steps={max_train_steps}"
    run_cmd += f" --mixed_precision={mixed_precision}"
    run_cmd += f" --save_every_n_epochs={save_every_n_epochs}"
    run_cmd += f" --seed={seed}"
    run_cmd += f" --save_precision={save_precision}"
    # Optional settings: only pass them when set, otherwise an empty string
    # (rather than "option not given") would reach the training script.
    if logging_dir:
        run_cmd += f' --logging_dir="{logging_dir}"'
    if caption_extention:
        run_cmd += f" --caption_extention={caption_extention}"
    run_cmd += f" --stop_text_encoder_training={stop_text_encoder_training}"

    print(run_cmd)
    # Run the command
    subprocess.run(run_cmd)

    # check if output_dir/last is a directory... therefore it is a diffuser model
    last_dir = pathlib.Path(f"{output_dir}/last")
    print(last_dir)

    if last_dir.is_dir():
        if convert_to_ckpt:
            print(f"Converting diffuser model {last_dir} to {last_dir}.ckpt")
            os.system(
                f'python ./tools/convert_diffusers20_original_sd.py "{last_dir}" "{last_dir}.ckpt" --{save_precision}'
            )
            save_inference_file(output_dir, v2, v_parameterization)

        if convert_to_safetensors:
            print(f"Converting diffuser model {last_dir} to {last_dir}.safetensors")
            os.system(
                f'python ./tools/convert_diffusers20_original_sd.py "{last_dir}" "{last_dir}.safetensors" --{save_precision}'
            )
            save_inference_file(output_dir, v2, v_parameterization)
    else:
        # Copy inference model for v2 if required
        save_inference_file(output_dir, v2, v_parameterization)
2022-12-13 19:59:33 +00:00
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
    """Sync the model path and the v2 checkboxes with a "Quick Pick" choice.

    Args:
        value: Selected dropdown entry (a model id or ``"custom"``).
        v2: Current state of the v2 checkbox.
        v_parameterization: Current state of the v_parameterization checkbox.

    Returns:
        ``(model_path, v2, v_parameterization)`` for the three bound Gradio
        outputs. Unrecognised values (e.g. a cleared dropdown) are passed
        through with both checkboxes left unchanged.
    """
    # Exact model ids that need --v2 only
    v2_models = (
        "stabilityai/stable-diffusion-2-1-base",
        "stabilityai/stable-diffusion-2-base",
    )
    # Exact model ids that need --v2 and --v_parameterization
    v_parameterization_models = (
        "stabilityai/stable-diffusion-2-1",
        "stabilityai/stable-diffusion-2",
    )
    # Exact model ids of SD v1.x models
    v1_models = (
        "CompVis/stable-diffusion-v1-4",
        "runwayml/stable-diffusion-v1-5",
    )

    model_id = str(value)

    if model_id in v2_models:
        print("SD v2 model detected. Setting --v2 parameter")
        return value, True, False

    if model_id in v_parameterization_models:
        print(
            "SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization"
        )
        return value, True, True

    if model_id in v1_models:
        return value, False, False

    if value == "custom":
        # Clear the path so the user can type their own model location.
        return "", False, False

    # Previously the function fell through and returned None here, which the
    # 3-output Gradio callback cannot handle.
    return value, v2, v_parameterization
2022-12-13 14:20:25 +00:00
2022-12-14 19:40:24 +00:00
2022-12-14 02:21:59 +00:00
def remove_doublequote(file_path):
    """Strip every double-quote character from a pasted path.

    Useful when a path copied from a file manager arrives wrapped in quotes.
    ``None`` is returned unchanged.
    """
    if file_path is not None:
        file_path = file_path.replace('"', "")

    return file_path
2022-12-14 19:40:24 +00:00
def get_file_path(file_path):
    """Show a file-open dialog for choosing a JSON configuration file.

    Args:
        file_path: Path used as the dialog's initial selection.

    Returns:
        Whatever ``fileopenbox`` returns; callers treat ``None`` (or an
        empty string) as a cancelled dialog.
    """
    return fileopenbox(
        "Select the config file to load",
        default=file_path,
        filetypes="*.json",
    )
def get_folder_path():
    """Show a directory picker and return the selection from ``diropenbox``."""
    return diropenbox("Select the directory to use")
2022-12-15 23:19:35 +00:00
def dreambooth_folder_preparation(
    util_training_images_dir_input,
    util_training_images_repeat_input,
    util_instance_prompt_input,
    util_regularization_images_dir_input,
    util_regularization_images_repeat_input,
    util_class_prompt_input,
    util_training_dir_input,
):
    """Copy training (and optional regularisation) images into kohya_ss layout.

    The destination ends up as::

        <util_training_dir_input>/img/<repeats>_<instance prompt> <class prompt>/
        <util_training_dir_input>/reg/<repeats>_<class prompt>/

    Existing folders with those exact names are deleted before copying.
    Missing mandatory inputs abort with a console message and nothing is
    created. The regularisation copy is skipped unless its source
    directory, the class prompt and a positive repeat count are all given.
    """

    def _is_positive(number):
        # Repeats come from gr.Number inputs; treat a missing (None) value
        # like a non-positive one instead of crashing on the comparison.
        return number is not None and number > 0

    # Check if the input variables are empty
    if not util_training_dir_input:
        print(
            "Destination training directory is missing... can't perform the required task..."
        )
        return

    if (
        not util_training_images_dir_input
        or not util_instance_prompt_input
        or not _is_positive(util_training_images_repeat_input)
    ):
        print(
            "Training images directory, instance prompt or repeats is missing... can't perform the required task..."
        )
        return

    # Create the util_training_dir_input directory if it doesn't exist
    os.makedirs(util_training_dir_input, exist_ok=True)

    class_prompt = util_class_prompt_input or ""

    # "<repeats>_<instance> <class>"; strip() avoids a trailing space in the
    # folder name when no class prompt is given.
    training_folder_name = (
        f"{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {class_prompt}"
    ).strip()
    training_dir = os.path.join(util_training_dir_input, "img", training_folder_name)

    # Remove folders if they exist
    if os.path.exists(training_dir):
        print(f"Removing existing directory {training_dir}...")
        shutil.rmtree(training_dir)

    # Copy the training images to their respective directories
    print(f"Copy {util_training_images_dir_input} to {training_dir}...")
    shutil.copytree(util_training_images_dir_input, training_dir)

    # Regularisation images are optional. They need a source directory, a
    # class prompt (which names the folder) and a positive repeat count.
    # The original condition was inverted and skipped them whenever a class
    # prompt WAS provided.
    if (
        not util_regularization_images_dir_input
        or not class_prompt
        or not _is_positive(util_regularization_images_repeat_input)
    ):
        print(
            "Regularization images directory, class prompt or repeats is missing... not copying regularisation images..."
        )
    else:
        regularization_dir = os.path.join(
            util_training_dir_input,
            "reg",
            f"{int(util_regularization_images_repeat_input)}_{class_prompt}",
        )

        # Remove folders if they exist
        if os.path.exists(regularization_dir):
            print(f"Removing existing directory {regularization_dir}...")
            shutil.rmtree(regularization_dir)

        # Copy the regularisation images to their respective directories
        print(
            f"Copy {util_regularization_images_dir_input} to {regularization_dir}..."
        )
        shutil.copytree(util_regularization_images_dir_input, regularization_dir)

    print(
        f"Done creating kohya_ss training folder structure at {util_training_dir_input}..."
    )
2022-12-15 23:34:54 +00:00
2022-12-15 23:19:35 +00:00
def copy_info_to_Directories_tab(training_folder):
    """Derive the Directories-tab paths from a prepared training folder.

    Returns:
        ``(img, reg, model, log)`` paths under ``training_folder``. The
        ``reg`` entry is an empty string when no ``reg`` sub-folder exists.
    """

    def sub(name):
        return os.path.join(training_folder, name)

    reg_candidate = sub("reg")
    return (
        sub("img"),
        reg_candidate if os.path.exists(reg_candidate) else "",
        sub("model"),
        sub("log"),
    )
2022-12-15 23:34:54 +00:00
2022-12-14 19:40:24 +00:00
# Optional stylesheet for the Gradio UI: when ./style.css exists its contents
# (plus a trailing newline) are handed to gr.Blocks; otherwise css stays "".
css = ""
if os.path.exists("./style.css"):
    with open("./style.css", "r", encoding="utf8") as file:
        print("Load CSS...")
        css = css + file.read() + "\n"
interface = gr.Blocks(css=css)
2022-12-13 14:20:25 +00:00
with interface :
2022-12-15 12:48:29 +00:00
dummy_true = gr . Label ( value = True , visible = False )
dummy_false = gr . Label ( value = False , visible = False )
2022-12-13 14:20:25 +00:00
gr . Markdown ( " Enter kohya finetuner parameter using this interface. " )
with gr . Accordion ( " Configuration File Load/Save " , open = False ) :
with gr . Row ( ) :
2022-12-14 19:40:24 +00:00
button_open_config = gr . Button ( " Open 📂 " , elem_id = " open_folder " )
button_save_config = gr . Button ( " Save 💾 " , elem_id = " open_folder " )
2022-12-15 23:19:35 +00:00
button_save_as_config = gr . Button ( " Save as... 💾 " ,
elem_id = " open_folder " )
2022-12-14 19:40:24 +00:00
config_file_name = gr . Textbox (
2022-12-16 00:32:53 +00:00
label = " " , placeholder = " type the configuration file path or use the ' Open ' button above to select it... " )
2022-12-15 23:19:35 +00:00
config_file_name . change ( remove_doublequote ,
inputs = [ config_file_name ] ,
outputs = [ config_file_name ] )
2022-12-13 16:07:32 +00:00
with gr . Tab ( " Source model " ) :
2022-12-13 14:20:25 +00:00
# Define the input elements
with gr . Row ( ) :
2022-12-13 16:26:21 +00:00
pretrained_model_name_or_path_input = gr . Textbox (
2022-12-13 14:20:25 +00:00
label = " Pretrained model name or path " ,
2022-12-15 23:19:35 +00:00
placeholder =
" enter the path to custom model or name of pretrained model " ,
2022-12-13 14:20:25 +00:00
)
model_list = gr . Dropdown (
2022-12-13 16:26:21 +00:00
label = " (Optional) Model Quick Pick " ,
2022-12-13 14:20:25 +00:00
choices = [
" custom " ,
2022-12-13 16:07:32 +00:00
" stabilityai/stable-diffusion-2-1-base " ,
" stabilityai/stable-diffusion-2-base " ,
" stabilityai/stable-diffusion-2-1 " ,
" stabilityai/stable-diffusion-2 " ,
" runwayml/stable-diffusion-v1-5 " ,
2022-12-14 02:21:59 +00:00
" CompVis/stable-diffusion-v1-4 " ,
2022-12-13 14:20:25 +00:00
] ,
)
with gr . Row ( ) :
2022-12-13 16:26:21 +00:00
v2_input = gr . Checkbox ( label = " v2 " , value = True )
2022-12-15 23:19:35 +00:00
v_parameterization_input = gr . Checkbox ( label = " v_parameterization " ,
value = False )
2022-12-14 02:21:59 +00:00
pretrained_model_name_or_path_input . change (
remove_doublequote ,
inputs = [ pretrained_model_name_or_path_input ] ,
2022-12-14 19:40:24 +00:00
outputs = [ pretrained_model_name_or_path_input ] ,
2022-12-14 02:21:59 +00:00
)
2022-12-13 14:20:25 +00:00
model_list . change (
set_pretrained_model_name_or_path_input ,
2022-12-13 19:59:33 +00:00
inputs = [ model_list , v2_input , v_parameterization_input ] ,
2022-12-14 02:21:59 +00:00
outputs = [
pretrained_model_name_or_path_input ,
v2_input ,
v_parameterization_input ,
] ,
2022-12-13 14:20:25 +00:00
)
2022-12-14 19:40:24 +00:00
2022-12-13 16:07:32 +00:00
with gr . Tab ( " Directories " ) :
2022-12-13 14:20:25 +00:00
with gr . Row ( ) :
2022-12-13 16:26:21 +00:00
train_data_dir_input = gr . Textbox (
2022-12-14 02:21:59 +00:00
label = " Image folder " ,
2022-12-15 23:19:35 +00:00
placeholder =
" Directory where the training folders containing the images are located " ,
2022-12-15 12:48:29 +00:00
)
2022-12-15 23:19:35 +00:00
train_data_dir_input_folder = gr . Button (
" 📂 " , elem_id = " open_folder_small " )
train_data_dir_input_folder . click ( get_folder_path ,
outputs = train_data_dir_input )
2022-12-13 16:26:21 +00:00
reg_data_dir_input = gr . Textbox (
2022-12-14 02:21:59 +00:00
label = " Regularisation folder " ,
2022-12-15 23:19:35 +00:00
placeholder =
" (Optional) Directory where where the regularization folders containing the images are located " ,
2022-12-13 16:07:32 +00:00
)
2022-12-15 23:19:35 +00:00
reg_data_dir_input_folder = gr . Button ( " 📂 " ,
elem_id = " open_folder_small " )
reg_data_dir_input_folder . click ( get_folder_path ,
outputs = reg_data_dir_input )
2022-12-13 16:07:32 +00:00
with gr . Row ( ) :
2022-12-13 16:26:21 +00:00
output_dir_input = gr . Textbox (
2022-12-13 16:07:32 +00:00
label = " Output directory " ,
2022-12-13 22:52:25 +00:00
placeholder = " Directory to output trained model " ,
2022-12-13 16:07:32 +00:00
)
2022-12-15 23:19:35 +00:00
output_dir_input_folder = gr . Button ( " 📂 " ,
elem_id = " open_folder_small " )
output_dir_input_folder . click ( get_folder_path ,
outputs = output_dir_input )
2022-12-13 16:26:21 +00:00
logging_dir_input = gr . Textbox (
2022-12-14 02:21:59 +00:00
label = " Logging directory " ,
2022-12-15 23:19:35 +00:00
placeholder =
" Optional: enable logging and output TensorBoard log to this directory " ,
2022-12-13 16:07:32 +00:00
)
2022-12-15 23:19:35 +00:00
logging_dir_input_folder = gr . Button ( " 📂 " ,
elem_id = " open_folder_small " )
logging_dir_input_folder . click ( get_folder_path ,
outputs = logging_dir_input )
2022-12-14 02:21:59 +00:00
train_data_dir_input . change (
remove_doublequote ,
inputs = [ train_data_dir_input ] ,
2022-12-14 19:40:24 +00:00
outputs = [ train_data_dir_input ] ,
2022-12-14 02:21:59 +00:00
)
reg_data_dir_input . change (
remove_doublequote ,
inputs = [ reg_data_dir_input ] ,
2022-12-14 19:40:24 +00:00
outputs = [ reg_data_dir_input ] ,
2022-12-14 02:21:59 +00:00
)
2022-12-15 23:19:35 +00:00
output_dir_input . change ( remove_doublequote ,
inputs = [ output_dir_input ] ,
outputs = [ output_dir_input ] )
logging_dir_input . change ( remove_doublequote ,
inputs = [ logging_dir_input ] ,
outputs = [ logging_dir_input ] )
2022-12-13 16:07:32 +00:00
with gr . Tab ( " Training parameters " ) :
with gr . Row ( ) :
2022-12-14 02:21:59 +00:00
learning_rate_input = gr . Textbox ( label = " Learning rate " , value = 1e-6 )
2022-12-13 14:20:25 +00:00
lr_scheduler_input = gr . Dropdown (
label = " LR Scheduler " ,
choices = [
" constant " ,
" constant_with_warmup " ,
" cosine " ,
" cosine_with_restarts " ,
" linear " ,
" polynomial " ,
] ,
value = " constant " ,
)
2022-12-13 16:26:21 +00:00
lr_warmup_input = gr . Textbox ( label = " LR warmup " , value = 0 )
2022-12-13 14:20:25 +00:00
with gr . Row ( ) :
2022-12-15 23:19:35 +00:00
train_batch_size_input = gr . Slider ( minimum = 1 ,
maximum = 32 ,
label = " Train batch size " ,
value = 1 ,
step = 1 )
2022-12-13 16:26:21 +00:00
epoch_input = gr . Textbox ( label = " Epoch " , value = 1 )
2022-12-15 23:19:35 +00:00
save_every_n_epochs_input = gr . Textbox ( label = " Save every N epochs " ,
value = 1 )
2022-12-13 16:07:32 +00:00
with gr . Row ( ) :
2022-12-13 14:20:25 +00:00
mixed_precision_input = gr . Dropdown (
label = " Mixed precision " ,
choices = [
" no " ,
" fp16 " ,
" bf16 " ,
] ,
value = " fp16 " ,
)
save_precision_input = gr . Dropdown (
label = " Save precision " ,
choices = [
" float " ,
" fp16 " ,
" bf16 " ,
] ,
value = " fp16 " ,
)
2022-12-14 02:21:59 +00:00
num_cpu_threads_per_process_input = gr . Slider (
minimum = 1 ,
maximum = os . cpu_count ( ) ,
step = 1 ,
label = " Number of CPU threads per process " ,
value = os . cpu_count ( ) ,
2022-12-13 14:20:25 +00:00
)
2022-12-13 16:07:32 +00:00
with gr . Row ( ) :
2022-12-13 16:26:21 +00:00
seed_input = gr . Textbox ( label = " Seed " , value = 1234 )
2022-12-15 23:19:35 +00:00
max_resolution_input = gr . Textbox ( label = " Max resolution " ,
2022-12-16 00:32:53 +00:00
value = " 512,512 " ,
placeholder = " 512,512 " )
2022-12-14 02:21:59 +00:00
with gr . Row ( ) :
2022-12-13 16:26:21 +00:00
caption_extention_input = gr . Textbox (
2022-12-14 02:21:59 +00:00
label = " Caption Extension " ,
2022-12-15 23:19:35 +00:00
placeholder =
" (Optional) Extension for caption files. default: .caption " ,
2022-12-14 02:21:59 +00:00
)
2022-12-14 19:40:24 +00:00
stop_text_encoder_training_input = gr . Slider (
minimum = 0 ,
maximum = 100 ,
value = 0 ,
step = 1 ,
2022-12-14 02:21:59 +00:00
label = " Stop text encoder training " ,
)
2022-12-13 16:07:32 +00:00
with gr . Row ( ) :
2022-12-16 00:04:26 +00:00
full_fp16_input = gr . Checkbox (
label = " Full fp16 training (experimental) " , value = False )
no_token_padding_input = gr . Checkbox ( label = " No token padding " ,
value = False )
2022-12-13 16:26:21 +00:00
use_safetensors_input = gr . Checkbox (
2022-12-15 23:19:35 +00:00
label = " Use safetensor when saving " , value = False )
2022-12-16 00:04:26 +00:00
2022-12-13 18:49:14 +00:00
gradient_checkpointing_input = gr . Checkbox (
2022-12-15 23:19:35 +00:00
label = " Gradient checkpointing " , value = False )
2022-12-14 02:21:59 +00:00
with gr . Row ( ) :
2022-12-16 00:04:26 +00:00
enable_bucket_input = gr . Checkbox ( label = " Enable buckets " ,
value = True )
cache_latent_input = gr . Checkbox ( label = " Cache latent " , value = True )
2022-12-15 23:19:35 +00:00
use_8bit_adam_input = gr . Checkbox ( label = " Use 8bit adam " ,
value = True )
xformers_input = gr . Checkbox ( label = " Use xformers " , value = True )
2022-12-13 16:07:32 +00:00
with gr . Tab ( " Model conversion " ) :
2022-12-13 16:26:21 +00:00
convert_to_safetensors_input = gr . Checkbox (
2022-12-16 00:04:26 +00:00
label = " Convert to SafeTensors " , value = True )
2022-12-15 23:19:35 +00:00
convert_to_ckpt_input = gr . Checkbox ( label = " Convert to CKPT " ,
value = False )
with gr . Tab ( " Utilities " ) :
with gr . Tab ( " Dreambooth folder preparation " ) :
gr . Markdown (
2022-12-16 00:32:53 +00:00
" This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth method to function correctly. "
2022-12-15 23:19:35 +00:00
)
2022-12-15 23:31:36 +00:00
with gr . Row ( ) :
2022-12-16 00:32:53 +00:00
util_instance_prompt_input = gr . Textbox (
2022-12-15 23:31:36 +00:00
label = " Instance prompt " ,
placeholder = " Eg: asd " ,
interactive = True ,
)
2022-12-16 00:32:53 +00:00
util_class_prompt_input = gr . Textbox (
2022-12-15 23:31:36 +00:00
label = " Class prompt " ,
2022-12-15 23:34:54 +00:00
placeholder = " Eg: person " ,
2022-12-15 23:31:36 +00:00
interactive = True ,
)
2022-12-15 23:19:35 +00:00
with gr . Row ( ) :
util_training_images_dir_input = gr . Textbox (
label = " Training images " ,
placeholder = " Directory containing the training images " ,
interactive = True ,
)
button_util_training_images_dir_input = gr . Button (
" 📂 " , elem_id = " open_folder_small " )
button_util_training_images_dir_input . click (
get_folder_path , outputs = util_training_images_dir_input )
util_training_images_repeat_input = gr . Number (
label = " Repeats " ,
value = 40 ,
interactive = True ,
elem_id = " number_input " )
with gr . Row ( ) :
util_regularization_images_dir_input = gr . Textbox (
label = " Regularisation images " ,
placeholder =
2022-12-15 23:31:36 +00:00
" (Optional) Directory containing the regularisation images " ,
2022-12-15 23:19:35 +00:00
interactive = True ,
)
button_util_regularization_images_dir_input = gr . Button (
" 📂 " , elem_id = " open_folder_small " )
button_util_regularization_images_dir_input . click (
get_folder_path ,
outputs = util_regularization_images_dir_input )
util_regularization_images_repeat_input = gr . Number (
label = " Repeats " ,
value = 1 ,
interactive = True ,
elem_id = " number_input " )
with gr . Row ( ) :
util_training_dir_input = gr . Textbox (
label = " Destination training directory " ,
placeholder =
2022-12-16 00:32:53 +00:00
" Directory where formatted training and regularisation folders will be placed " ,
2022-12-15 23:19:35 +00:00
interactive = True ,
)
button_util_training_dir_input = gr . Button (
" 📂 " , elem_id = " open_folder_small " )
button_util_training_dir_input . click (
get_folder_path , outputs = util_training_dir_input )
button_prepare_training_data = gr . Button ( " Prepare training data " )
button_prepare_training_data . click (
dreambooth_folder_preparation ,
inputs = [
util_training_images_dir_input ,
util_training_images_repeat_input ,
2022-12-16 00:32:53 +00:00
util_instance_prompt_input ,
2022-12-15 23:19:35 +00:00
util_regularization_images_dir_input ,
util_regularization_images_repeat_input ,
2022-12-16 00:32:53 +00:00
util_class_prompt_input ,
2022-12-15 23:19:35 +00:00
util_training_dir_input ,
] ,
)
button_copy_info_to_Directories_tab = gr . Button (
" Copy info to Directories Tab " )
button_run = gr . Button ( " Train model " )
2022-12-13 16:07:32 +00:00
2022-12-15 23:34:54 +00:00
button_copy_info_to_Directories_tab . click ( copy_info_to_Directories_tab ,
inputs = [ util_training_dir_input ] ,
outputs = [
train_data_dir_input ,
reg_data_dir_input ,
output_dir_input ,
logging_dir_input
] )
2022-12-13 14:20:25 +00:00
2022-12-14 19:40:24 +00:00
button_open_config . click (
2022-12-15 12:48:29 +00:00
open_configuration ,
2022-12-14 19:40:24 +00:00
inputs = [
config_file_name ,
pretrained_model_name_or_path_input ,
v2_input ,
v_parameterization_input ,
logging_dir_input ,
train_data_dir_input ,
reg_data_dir_input ,
output_dir_input ,
max_resolution_input ,
learning_rate_input ,
lr_scheduler_input ,
lr_warmup_input ,
train_batch_size_input ,
epoch_input ,
save_every_n_epochs_input ,
mixed_precision_input ,
save_precision_input ,
seed_input ,
num_cpu_threads_per_process_input ,
convert_to_safetensors_input ,
convert_to_ckpt_input ,
cache_latent_input ,
caption_extention_input ,
use_safetensors_input ,
enable_bucket_input ,
gradient_checkpointing_input ,
full_fp16_input ,
no_token_padding_input ,
stop_text_encoder_training_input ,
use_8bit_adam_input ,
xformers_input ,
] ,
2022-12-13 14:20:25 +00:00
outputs = [
2022-12-14 19:40:24 +00:00
config_file_name ,
2022-12-13 14:20:25 +00:00
pretrained_model_name_or_path_input ,
v2_input ,
2022-12-13 19:59:33 +00:00
v_parameterization_input ,
2022-12-13 14:20:25 +00:00
logging_dir_input ,
train_data_dir_input ,
reg_data_dir_input ,
output_dir_input ,
max_resolution_input ,
learning_rate_input ,
lr_scheduler_input ,
lr_warmup_input ,
train_batch_size_input ,
epoch_input ,
save_every_n_epochs_input ,
mixed_precision_input ,
save_precision_input ,
seed_input ,
num_cpu_threads_per_process_input ,
convert_to_safetensors_input ,
2022-12-13 16:07:32 +00:00
convert_to_ckpt_input ,
cache_latent_input ,
caption_extention_input ,
use_safetensors_input ,
2022-12-13 18:49:14 +00:00
enable_bucket_input ,
gradient_checkpointing_input ,
2022-12-14 02:21:59 +00:00
full_fp16_input ,
no_token_padding_input ,
stop_text_encoder_training_input ,
use_8bit_adam_input ,
xformers_input ,
] ,
2022-12-13 14:20:25 +00:00
)
2022-12-13 16:07:32 +00:00
2022-12-15 12:48:29 +00:00
# save_configuration receives its "save as" flag through its first input:
# dummy_false is passed by the "Save" button and dummy_true by the
# "Save as" button (see the two .click() wirings below).
2022-12-13 18:49:14 +00:00
button_save_config . click (
2022-12-15 12:48:29 +00:00
save_configuration ,
2022-12-13 14:20:25 +00:00
inputs = [
2022-12-15 12:48:29 +00:00
dummy_false ,
2022-12-13 14:20:25 +00:00
config_file_name ,
pretrained_model_name_or_path_input ,
v2_input ,
2022-12-13 19:59:33 +00:00
v_parameterization_input ,
2022-12-13 14:20:25 +00:00
logging_dir_input ,
train_data_dir_input ,
reg_data_dir_input ,
output_dir_input ,
max_resolution_input ,
learning_rate_input ,
lr_scheduler_input ,
lr_warmup_input ,
train_batch_size_input ,
epoch_input ,
save_every_n_epochs_input ,
mixed_precision_input ,
save_precision_input ,
seed_input ,
num_cpu_threads_per_process_input ,
convert_to_safetensors_input ,
2022-12-13 16:07:32 +00:00
convert_to_ckpt_input ,
cache_latent_input ,
caption_extention_input ,
use_safetensors_input ,
2022-12-13 18:49:14 +00:00
enable_bucket_input ,
gradient_checkpointing_input ,
2022-12-14 02:21:59 +00:00
full_fp16_input ,
no_token_padding_input ,
stop_text_encoder_training_input ,
use_8bit_adam_input ,
xformers_input ,
] ,
2022-12-14 19:40:24 +00:00
outputs = [ config_file_name ] ,
)
button_save_as_config . click (
2022-12-15 12:48:29 +00:00
save_configuration ,
2022-12-14 19:40:24 +00:00
inputs = [
2022-12-15 12:48:29 +00:00
dummy_true ,
2022-12-14 19:40:24 +00:00
config_file_name ,
pretrained_model_name_or_path_input ,
v2_input ,
v_parameterization_input ,
logging_dir_input ,
train_data_dir_input ,
reg_data_dir_input ,
output_dir_input ,
max_resolution_input ,
learning_rate_input ,
lr_scheduler_input ,
lr_warmup_input ,
train_batch_size_input ,
epoch_input ,
save_every_n_epochs_input ,
mixed_precision_input ,
save_precision_input ,
seed_input ,
num_cpu_threads_per_process_input ,
convert_to_safetensors_input ,
convert_to_ckpt_input ,
cache_latent_input ,
caption_extention_input ,
use_safetensors_input ,
enable_bucket_input ,
gradient_checkpointing_input ,
full_fp16_input ,
no_token_padding_input ,
stop_text_encoder_training_input ,
use_8bit_adam_input ,
xformers_input ,
] ,
outputs = [ config_file_name ] ,
2022-12-13 14:20:25 +00:00
)
2022-12-14 19:40:24 +00:00
2022-12-13 18:49:14 +00:00
button_run . click (
2022-12-13 14:20:25 +00:00
train_model ,
inputs = [
pretrained_model_name_or_path_input ,
v2_input ,
2022-12-13 19:59:33 +00:00
v_parameterization_input ,
2022-12-13 14:20:25 +00:00
logging_dir_input ,
train_data_dir_input ,
reg_data_dir_input ,
output_dir_input ,
max_resolution_input ,
learning_rate_input ,
lr_scheduler_input ,
lr_warmup_input ,
train_batch_size_input ,
epoch_input ,
save_every_n_epochs_input ,
mixed_precision_input ,
save_precision_input ,
seed_input ,
num_cpu_threads_per_process_input ,
convert_to_safetensors_input ,
convert_to_ckpt_input ,
2022-12-13 16:07:32 +00:00
cache_latent_input ,
caption_extention_input ,
use_safetensors_input ,
2022-12-13 18:49:14 +00:00
enable_bucket_input ,
gradient_checkpointing_input ,
2022-12-14 02:21:59 +00:00
full_fp16_input ,
no_token_padding_input ,
stop_text_encoder_training_input ,
use_8bit_adam_input ,
xformers_input ,
] ,
2022-12-13 14:20:25 +00:00
)
# Show the interface: start serving the Gradio UI whose components and
# click handlers are wired above. `interface` is bound outside this view
# (presumably the `gr.Blocks(...) as interface` object — confirm).
interface . launch ( )