diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 2694577..67ee686 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -1,4 +1,5 @@ # v1: initial release +# v2: add open and save folder icons import gradio as gr import json @@ -9,6 +10,7 @@ import pathlib import shutil from glob import glob from os.path import join +from easygui import fileopenbox, filesavebox, enterbox, diropenbox, msgbox def save_variables( @@ -44,6 +46,103 @@ def save_variables( use_8bit_adam, xformers, ): + original_file_path = file_path + + if file_path == None or file_path == "": + file_path = filesavebox( + "Select the config file to save", + default="finetune.json", + filetypes="*.json", + ) + + if file_path == None: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + return file_path + + # Return the values of the variables as a dictionary + variables = { + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "v2": v2, + "v_parameterization": v_parameterization, + "logging_dir": logging_dir, + "train_data_dir": train_data_dir, + "reg_data_dir": reg_data_dir, + "output_dir": output_dir, + "max_resolution": max_resolution, + "learning_rate": learning_rate, + "lr_scheduler": lr_scheduler, + "lr_warmup": lr_warmup, + "train_batch_size": train_batch_size, + "epoch": epoch, + "save_every_n_epochs": save_every_n_epochs, + "mixed_precision": mixed_precision, + "save_precision": save_precision, + "seed": seed, + "num_cpu_threads_per_process": num_cpu_threads_per_process, + "convert_to_safetensors": convert_to_safetensors, + "convert_to_ckpt": convert_to_ckpt, + "cache_latent": cache_latent, + "caption_extention": caption_extention, + "use_safetensors": use_safetensors, + "enable_bucket": enable_bucket, + "gradient_checkpointing": gradient_checkpointing, + "full_fp16": full_fp16, + "no_token_padding": no_token_padding, + "stop_text_encoder_training": stop_text_encoder_training, + "use_8bit_adam": use_8bit_adam, + "xformers": xformers, + } + + # Save the data to the selected file + with open(file_path, "w") as file: + json.dump(variables, file) + + return file_path + + +def save_as_variables( + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + convert_to_safetensors, + convert_to_ckpt, + cache_latent, + caption_extention, + use_safetensors, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + use_8bit_adam, + xformers, +): + original_file_path = file_path + + file_path = filesavebox( + "Select the config file to save", default="finetune.json", filetypes="*.json" + ) + + if file_path == None: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + return file_path + # Return the values of the variables as a dictionary variables = { "pretrained_model_name_or_path": pretrained_model_name_or_path, @@ -82,44 +181,88 @@ def save_variables( with open(file_path, "w") as file: json.dump(variables, file) + return file_path -def load_variables(file_path): - # load variables from JSON file - with open(file_path, "r") as f: - my_data = json.load(f) + +def open_config_file( + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + convert_to_safetensors, + convert_to_ckpt, + cache_latent, + caption_extention, + use_safetensors, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + use_8bit_adam, + xformers, +): + + original_file_path = file_path + file_path = get_file_path(file_path) + + if file_path != "" and file_path != None: + print(file_path) + # load variables from JSON file + with open(file_path, "r") as f: + my_data = json.load(f) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} # Return the values of the variables as a dictionary return ( - my_data.get("pretrained_model_name_or_path", None), - my_data.get("v2", None), - my_data.get("v_parameterization", None), - my_data.get("logging_dir", None), - my_data.get("train_data_dir", None), - my_data.get("reg_data_dir", None), - my_data.get("output_dir", None), - my_data.get("max_resolution", None), - my_data.get("learning_rate", None), - my_data.get("lr_scheduler", None), - my_data.get("lr_warmup", None), - my_data.get("train_batch_size", None), - my_data.get("epoch", None), - my_data.get("save_every_n_epochs", None), - my_data.get("mixed_precision", None), - my_data.get("save_precision", None), - my_data.get("seed", None), - my_data.get("num_cpu_threads_per_process", None), - my_data.get("convert_to_safetensors", None), - my_data.get("convert_to_ckpt", None), - my_data.get("cache_latent", None), - my_data.get("caption_extention", None), - my_data.get("use_safetensors", None), - my_data.get("enable_bucket", None), - my_data.get("gradient_checkpointing", None), - my_data.get("full_fp16", None), - my_data.get("no_token_padding", None), - my_data.get("stop_text_encoder_training", None), - my_data.get("use_8bit_adam", None), - my_data.get("xformers", None), + file_path, + my_data.get("pretrained_model_name_or_path", pretrained_model_name_or_path), + my_data.get("v2", v2), + my_data.get("v_parameterization", v_parameterization), + my_data.get("logging_dir", logging_dir), + my_data.get("train_data_dir", train_data_dir), + my_data.get("reg_data_dir", reg_data_dir), + my_data.get("output_dir", output_dir), + my_data.get("max_resolution", max_resolution), + my_data.get("learning_rate", learning_rate), + my_data.get("lr_scheduler", lr_scheduler), + my_data.get("lr_warmup", lr_warmup), + my_data.get("train_batch_size", train_batch_size), + my_data.get("epoch", epoch), + my_data.get("save_every_n_epochs", save_every_n_epochs), + my_data.get("mixed_precision", mixed_precision), + my_data.get("save_precision", save_precision), + my_data.get("seed", seed), + my_data.get("num_cpu_threads_per_process", num_cpu_threads_per_process), + my_data.get("convert_to_safetensors", convert_to_safetensors), + my_data.get("convert_to_ckpt", convert_to_ckpt), + my_data.get("cache_latent", cache_latent), + my_data.get("caption_extention", caption_extention), + my_data.get("use_safetensors", use_safetensors), + my_data.get("enable_bucket", enable_bucket), + my_data.get("gradient_checkpointing", gradient_checkpointing), + my_data.get("full_fp16", full_fp16), + my_data.get("no_token_padding", no_token_padding), + my_data.get("stop_text_encoder_training", stop_text_encoder_training), + my_data.get("use_8bit_adam", use_8bit_adam), + my_data.get("xformers", xformers), ) @@ -356,28 +499,49 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): return value, v2, v_parameterization + def remove_doublequote(file_path): if file_path != None: - file_path = file_path.replace('"', '') + file_path = file_path.replace('"', "") return file_path -interface = gr.Blocks() +def get_file_path(file_path): + file_path = fileopenbox( + "Select the config file to load", default=file_path, filetypes="*.json" + ) + + return file_path + + +def get_folder_path(): + folder_path = diropenbox("Select the directory to use") + + return folder_path + + +css = "" + +if os.path.exists("./style.css"): + with open(os.path.join("./style.css"), "r", encoding="utf8") as file: + print("Load CSS...") + css += file.read() + "\n" + +interface = gr.Blocks(css=css) with interface: gr.Markdown("Enter kohya finetuner parameter using this interface.") with gr.Accordion("Configuration File Load/Save", open=False): with gr.Row(): - config_file_name = gr.Textbox(label="Config file name") - button_load_config = gr.Button("Load config") - button_save_config = gr.Button("Save config") + button_open_config = gr.Button("Open 📂", elem_id="open_folder") + button_save_config = gr.Button("Save 💾", elem_id="open_folder") + button_save_as_config = gr.Button("Save as... 💾", elem_id="open_folder") + config_file_name = gr.Textbox( + label="", placeholder="type config file path or use buttons..." + ) config_file_name.change( - remove_doublequote, - inputs=[config_file_name], - outputs=[ - config_file_name - ] + remove_doublequote, inputs=[config_file_name], outputs=[config_file_name] ) with gr.Tab("Source model"): # Define the input elements @@ -406,9 +570,7 @@ with interface: pretrained_model_name_or_path_input.change( remove_doublequote, inputs=[pretrained_model_name_or_path_input], - outputs=[ - pretrained_model_name_or_path_input - ] + outputs=[pretrained_model_name_or_path_input], ) model_list.change( set_pretrained_model_name_or_path_input, @@ -419,53 +581,49 @@ with interface: v_parameterization_input, ], ) - + with gr.Tab("Directories"): with gr.Row(): train_data_dir_input = gr.Textbox( label="Image folder", placeholder="Directory where the training folders containing the images are located", ) + train_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") + train_data_dir_input_folder.click(get_folder_path, outputs=train_data_dir_input) reg_data_dir_input = gr.Textbox( label="Regularisation folder", placeholder="(Optional) Directory where where the regularization folders containing the images are located", ) + reg_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") + reg_data_dir_input_folder.click(get_folder_path, outputs=reg_data_dir_input) with gr.Row(): output_dir_input = gr.Textbox( label="Output directory", placeholder="Directory to output trained model", ) + output_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") + output_dir_input_folder.click(get_folder_path, outputs=output_dir_input) logging_dir_input = gr.Textbox( label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory", ) + logging_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") + logging_dir_input_folder.click(get_folder_path, outputs=logging_dir_input) train_data_dir_input.change( remove_doublequote, inputs=[train_data_dir_input], - outputs=[ - train_data_dir_input - ] + outputs=[train_data_dir_input], ) reg_data_dir_input.change( remove_doublequote, inputs=[reg_data_dir_input], - outputs=[ - reg_data_dir_input - ] + outputs=[reg_data_dir_input], ) output_dir_input.change( - remove_doublequote, - inputs=[output_dir_input], - outputs=[ - output_dir_input - ] + remove_doublequote, inputs=[output_dir_input], outputs=[output_dir_input] ) logging_dir_input.change( - remove_doublequote, - inputs=[logging_dir_input], - outputs=[ - logging_dir_input - ] + remove_doublequote, inputs=[logging_dir_input], outputs=[logging_dir_input] ) with gr.Tab("Training parameters"): with gr.Row(): @@ -523,7 +681,11 @@ with interface: label="Caption Extension", placeholder="(Optional) Extension for caption files. default: .caption", ) - stop_text_encoder_training_input = gr.Slider(minimum=0, maximum=100, value=0, step=1, + stop_text_encoder_training_input = gr.Slider( + minimum=0, + maximum=100, + value=0, + step=1, label="Stop text encoder training", ) with gr.Row(): @@ -551,10 +713,43 @@ with interface: button_run = gr.Button("Run") - button_load_config.click( - load_variables, - inputs=[config_file_name], + button_open_config.click( + open_config_file, + inputs=[ + config_file_name, + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + logging_dir_input, + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + max_resolution_input, + learning_rate_input, + lr_scheduler_input, + lr_warmup_input, + train_batch_size_input, + epoch_input, + save_every_n_epochs_input, + mixed_precision_input, + save_precision_input, + seed_input, + num_cpu_threads_per_process_input, + convert_to_safetensors_input, + convert_to_ckpt_input, + cache_latent_input, + caption_extention_input, + use_safetensors_input, + enable_bucket_input, + gradient_checkpointing_input, + full_fp16_input, + no_token_padding_input, + stop_text_encoder_training_input, + use_8bit_adam_input, + xformers_input, + ], outputs=[ + config_file_name, pretrained_model_name_or_path_input, v2_input, v_parameterization_input, @@ -623,7 +818,47 @@ with interface: use_8bit_adam_input, xformers_input, ], + outputs=[config_file_name], ) + + button_save_as_config.click( + save_as_variables, + inputs=[ + config_file_name, + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + logging_dir_input, + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + max_resolution_input, + learning_rate_input, + lr_scheduler_input, + lr_warmup_input, + train_batch_size_input, + epoch_input, + save_every_n_epochs_input, + mixed_precision_input, + save_precision_input, + seed_input, + num_cpu_threads_per_process_input, + convert_to_safetensors_input, + convert_to_ckpt_input, + cache_latent_input, + caption_extention_input, + use_safetensors_input, + enable_bucket_input, + gradient_checkpointing_input, + full_fp16_input, + no_token_padding_input, + stop_text_encoder_training_input, + use_8bit_adam_input, + xformers_input, + ], + outputs=[config_file_name], + ) + button_run.click( train_model, inputs=[ diff --git a/requirements.txt b/requirements.txt index cfb2bdb..4e63bbb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ bitsandbytes==0.35.0 tensorboard safetensors==0.2.6 gradio -altair \ No newline at end of file +altair +easygui \ No newline at end of file diff --git a/style.css b/style.css new file mode 100644 index 0000000..6545eec --- /dev/null +++ b/style.css @@ -0,0 +1,14 @@ +#open_folder_small{ + height: fit-content; + min-width: auto; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; +} + +#open_folder{ + height: fit-content; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; +} \ No newline at end of file