Adding improved elements to GUI

This commit is contained in:
bmaltais 2022-12-14 14:40:24 -05:00
parent 469b15b579
commit 3834e5dbab
3 changed files with 319 additions and 69 deletions

View File

@ -1,4 +1,5 @@
# v1: initial release # v1: initial release
# v2: add open and save folder icons
import gradio as gr import gradio as gr
import json import json
@ -9,6 +10,7 @@ import pathlib
import shutil import shutil
from glob import glob from glob import glob
from os.path import join from os.path import join
from easygui import fileopenbox, filesavebox, enterbox, diropenbox, msgbox
def save_variables( def save_variables(
@ -44,6 +46,19 @@ def save_variables(
use_8bit_adam, use_8bit_adam,
xformers, xformers,
): ):
original_file_path = file_path
if file_path == None or file_path == "":
file_path = filesavebox(
"Select the config file to save",
default="finetune.json",
filetypes="*.json",
)
if file_path == None:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
return file_path
# Return the values of the variables as a dictionary # Return the values of the variables as a dictionary
variables = { variables = {
"pretrained_model_name_or_path": pretrained_model_name_or_path, "pretrained_model_name_or_path": pretrained_model_name_or_path,
@ -82,44 +97,172 @@ def save_variables(
with open(file_path, "w") as file: with open(file_path, "w") as file:
json.dump(variables, file) json.dump(variables, file)
return file_path
def load_variables(file_path):
def save_as_variables(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization,
logging_dir,
train_data_dir,
reg_data_dir,
output_dir,
max_resolution,
learning_rate,
lr_scheduler,
lr_warmup,
train_batch_size,
epoch,
save_every_n_epochs,
mixed_precision,
save_precision,
seed,
num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent,
caption_extention,
use_safetensors,
enable_bucket,
gradient_checkpointing,
full_fp16,
no_token_padding,
stop_text_encoder_training,
use_8bit_adam,
xformers,
):
original_file_path = file_path
file_path = filesavebox(
"Select the config file to save", default="finetune.json", filetypes="*.json"
)
if file_path == None:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
return file_path
# Return the values of the variables as a dictionary
variables = {
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"v2": v2,
"v_parameterization": v_parameterization,
"logging_dir": logging_dir,
"train_data_dir": train_data_dir,
"reg_data_dir": reg_data_dir,
"output_dir": output_dir,
"max_resolution": max_resolution,
"learning_rate": learning_rate,
"lr_scheduler": lr_scheduler,
"lr_warmup": lr_warmup,
"train_batch_size": train_batch_size,
"epoch": epoch,
"save_every_n_epochs": save_every_n_epochs,
"mixed_precision": mixed_precision,
"save_precision": save_precision,
"seed": seed,
"num_cpu_threads_per_process": num_cpu_threads_per_process,
"convert_to_safetensors": convert_to_safetensors,
"convert_to_ckpt": convert_to_ckpt,
"cache_latent": cache_latent,
"caption_extention": caption_extention,
"use_safetensors": use_safetensors,
"enable_bucket": enable_bucket,
"gradient_checkpointing": gradient_checkpointing,
"full_fp16": full_fp16,
"no_token_padding": no_token_padding,
"stop_text_encoder_training": stop_text_encoder_training,
"use_8bit_adam": use_8bit_adam,
"xformers": xformers,
}
# Save the data to the selected file
with open(file_path, "w") as file:
json.dump(variables, file)
return file_path
def open_config_file(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization,
logging_dir,
train_data_dir,
reg_data_dir,
output_dir,
max_resolution,
learning_rate,
lr_scheduler,
lr_warmup,
train_batch_size,
epoch,
save_every_n_epochs,
mixed_precision,
save_precision,
seed,
num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent,
caption_extention,
use_safetensors,
enable_bucket,
gradient_checkpointing,
full_fp16,
no_token_padding,
stop_text_encoder_training,
use_8bit_adam,
xformers,
):
original_file_path = file_path
file_path = get_file_path(file_path)
if file_path != "" and file_path != None:
print(file_path)
# load variables from JSON file # load variables from JSON file
with open(file_path, "r") as f: with open(file_path, "r") as f:
my_data = json.load(f) my_data = json.load(f)
else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
my_data = {}
# Return the values of the variables as a dictionary # Return the values of the variables as a dictionary
return ( return (
my_data.get("pretrained_model_name_or_path", None), file_path,
my_data.get("v2", None), my_data.get("pretrained_model_name_or_path", pretrained_model_name_or_path),
my_data.get("v_parameterization", None), my_data.get("v2", v2),
my_data.get("logging_dir", None), my_data.get("v_parameterization", v_parameterization),
my_data.get("train_data_dir", None), my_data.get("logging_dir", logging_dir),
my_data.get("reg_data_dir", None), my_data.get("train_data_dir", train_data_dir),
my_data.get("output_dir", None), my_data.get("reg_data_dir", reg_data_dir),
my_data.get("max_resolution", None), my_data.get("output_dir", output_dir),
my_data.get("learning_rate", None), my_data.get("max_resolution", max_resolution),
my_data.get("lr_scheduler", None), my_data.get("learning_rate", learning_rate),
my_data.get("lr_warmup", None), my_data.get("lr_scheduler", lr_scheduler),
my_data.get("train_batch_size", None), my_data.get("lr_warmup", lr_warmup),
my_data.get("epoch", None), my_data.get("train_batch_size", train_batch_size),
my_data.get("save_every_n_epochs", None), my_data.get("epoch", epoch),
my_data.get("mixed_precision", None), my_data.get("save_every_n_epochs", save_every_n_epochs),
my_data.get("save_precision", None), my_data.get("mixed_precision", mixed_precision),
my_data.get("seed", None), my_data.get("save_precision", save_precision),
my_data.get("num_cpu_threads_per_process", None), my_data.get("seed", seed),
my_data.get("convert_to_safetensors", None), my_data.get("num_cpu_threads_per_process", num_cpu_threads_per_process),
my_data.get("convert_to_ckpt", None), my_data.get("convert_to_safetensors", convert_to_safetensors),
my_data.get("cache_latent", None), my_data.get("convert_to_ckpt", convert_to_ckpt),
my_data.get("caption_extention", None), my_data.get("cache_latent", cache_latent),
my_data.get("use_safetensors", None), my_data.get("caption_extention", caption_extention),
my_data.get("enable_bucket", None), my_data.get("use_safetensors", use_safetensors),
my_data.get("gradient_checkpointing", None), my_data.get("enable_bucket", enable_bucket),
my_data.get("full_fp16", None), my_data.get("gradient_checkpointing", gradient_checkpointing),
my_data.get("no_token_padding", None), my_data.get("full_fp16", full_fp16),
my_data.get("stop_text_encoder_training", None), my_data.get("no_token_padding", no_token_padding),
my_data.get("use_8bit_adam", None), my_data.get("stop_text_encoder_training", stop_text_encoder_training),
my_data.get("xformers", None), my_data.get("use_8bit_adam", use_8bit_adam),
my_data.get("xformers", xformers),
) )
@ -356,28 +499,49 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
return value, v2, v_parameterization return value, v2, v_parameterization
def remove_doublequote(file_path): def remove_doublequote(file_path):
if file_path != None: if file_path != None:
file_path = file_path.replace('"', '') file_path = file_path.replace('"', "")
return file_path return file_path
interface = gr.Blocks() def get_file_path(file_path):
file_path = fileopenbox(
"Select the config file to load", default=file_path, filetypes="*.json"
)
return file_path
def get_folder_path():
folder_path = diropenbox("Select the directory to use")
return folder_path
css = ""
if os.path.exists("./style.css"):
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
print("Load CSS...")
css += file.read() + "\n"
interface = gr.Blocks(css=css)
with interface: with interface:
gr.Markdown("Enter kohya finetuner parameter using this interface.") gr.Markdown("Enter kohya finetuner parameter using this interface.")
with gr.Accordion("Configuration File Load/Save", open=False): with gr.Accordion("Configuration File Load/Save", open=False):
with gr.Row(): with gr.Row():
config_file_name = gr.Textbox(label="Config file name") button_open_config = gr.Button("Open 📂", elem_id="open_folder")
button_load_config = gr.Button("Load config") button_save_config = gr.Button("Save 💾", elem_id="open_folder")
button_save_config = gr.Button("Save config") button_save_as_config = gr.Button("Save as... 💾", elem_id="open_folder")
config_file_name = gr.Textbox(
label="", placeholder="type config file path or use buttons..."
)
config_file_name.change( config_file_name.change(
remove_doublequote, remove_doublequote, inputs=[config_file_name], outputs=[config_file_name]
inputs=[config_file_name],
outputs=[
config_file_name
]
) )
with gr.Tab("Source model"): with gr.Tab("Source model"):
# Define the input elements # Define the input elements
@ -406,9 +570,7 @@ with interface:
pretrained_model_name_or_path_input.change( pretrained_model_name_or_path_input.change(
remove_doublequote, remove_doublequote,
inputs=[pretrained_model_name_or_path_input], inputs=[pretrained_model_name_or_path_input],
outputs=[ outputs=[pretrained_model_name_or_path_input],
pretrained_model_name_or_path_input
]
) )
model_list.change( model_list.change(
set_pretrained_model_name_or_path_input, set_pretrained_model_name_or_path_input,
@ -426,46 +588,42 @@ with interface:
label="Image folder", label="Image folder",
placeholder="Directory where the training folders containing the images are located", placeholder="Directory where the training folders containing the images are located",
) )
train_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small")
train_data_dir_input_folder.click(get_folder_path, outputs=train_data_dir_input)
reg_data_dir_input = gr.Textbox( reg_data_dir_input = gr.Textbox(
label="Regularisation folder", label="Regularisation folder",
placeholder="(Optional) Directory where where the regularization folders containing the images are located", placeholder="(Optional) Directory where where the regularization folders containing the images are located",
) )
reg_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small")
reg_data_dir_input_folder.click(get_folder_path, outputs=reg_data_dir_input)
with gr.Row(): with gr.Row():
output_dir_input = gr.Textbox( output_dir_input = gr.Textbox(
label="Output directory", label="Output directory",
placeholder="Directory to output trained model", placeholder="Directory to output trained model",
) )
output_dir_input_folder = gr.Button("📂", elem_id="open_folder_small")
output_dir_input_folder.click(get_folder_path, outputs=output_dir_input)
logging_dir_input = gr.Textbox( logging_dir_input = gr.Textbox(
label="Logging directory", label="Logging directory",
placeholder="Optional: enable logging and output TensorBoard log to this directory", placeholder="Optional: enable logging and output TensorBoard log to this directory",
) )
logging_dir_input_folder = gr.Button("📂", elem_id="open_folder_small")
logging_dir_input_folder.click(get_folder_path, outputs=logging_dir_input)
train_data_dir_input.change( train_data_dir_input.change(
remove_doublequote, remove_doublequote,
inputs=[train_data_dir_input], inputs=[train_data_dir_input],
outputs=[ outputs=[train_data_dir_input],
train_data_dir_input
]
) )
reg_data_dir_input.change( reg_data_dir_input.change(
remove_doublequote, remove_doublequote,
inputs=[reg_data_dir_input], inputs=[reg_data_dir_input],
outputs=[ outputs=[reg_data_dir_input],
reg_data_dir_input
]
) )
output_dir_input.change( output_dir_input.change(
remove_doublequote, remove_doublequote, inputs=[output_dir_input], outputs=[output_dir_input]
inputs=[output_dir_input],
outputs=[
output_dir_input
]
) )
logging_dir_input.change( logging_dir_input.change(
remove_doublequote, remove_doublequote, inputs=[logging_dir_input], outputs=[logging_dir_input]
inputs=[logging_dir_input],
outputs=[
logging_dir_input
]
) )
with gr.Tab("Training parameters"): with gr.Tab("Training parameters"):
with gr.Row(): with gr.Row():
@ -523,7 +681,11 @@ with interface:
label="Caption Extension", label="Caption Extension",
placeholder="(Optional) Extension for caption files. default: .caption", placeholder="(Optional) Extension for caption files. default: .caption",
) )
stop_text_encoder_training_input = gr.Slider(minimum=0, maximum=100, value=0, step=1, stop_text_encoder_training_input = gr.Slider(
minimum=0,
maximum=100,
value=0,
step=1,
label="Stop text encoder training", label="Stop text encoder training",
) )
with gr.Row(): with gr.Row():
@ -551,10 +713,43 @@ with interface:
button_run = gr.Button("Run") button_run = gr.Button("Run")
button_load_config.click( button_open_config.click(
load_variables, open_config_file,
inputs=[config_file_name], inputs=[
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input,
caption_extention_input,
use_safetensors_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
],
outputs=[ outputs=[
config_file_name,
pretrained_model_name_or_path_input, pretrained_model_name_or_path_input,
v2_input, v2_input,
v_parameterization_input, v_parameterization_input,
@ -623,7 +818,47 @@ with interface:
use_8bit_adam_input, use_8bit_adam_input,
xformers_input, xformers_input,
], ],
outputs=[config_file_name],
) )
button_save_as_config.click(
save_as_variables,
inputs=[
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input,
caption_extention_input,
use_safetensors_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
],
outputs=[config_file_name],
)
button_run.click( button_run.click(
train_model, train_model,
inputs=[ inputs=[

View File

@ -11,3 +11,4 @@ tensorboard
safetensors==0.2.6 safetensors==0.2.6
gradio gradio
altair altair
easygui

14
style.css Normal file
View File

@ -0,0 +1,14 @@
#open_folder_small{
height: fit-content;
min-width: auto;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
}
#open_folder{
height: fit-content;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
}