From 0d52ff4842b9e84485bed2379b66c4482da1d745 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 13 Dec 2022 11:07:32 -0500 Subject: [PATCH] Add support for more options, rework UI --- finetune_gui.py | 149 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 103 insertions(+), 46 deletions(-) diff --git a/finetune_gui.py b/finetune_gui.py index b24d0cb..914905b 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -30,7 +30,11 @@ def save_variables( seed, num_cpu_threads_per_process, convert_to_safetensors, - convert_to_ckpt + convert_to_ckpt, + cache_latent, + caption_extention, + use_safetensors, + enable_bucket ): # Return the values of the variables as a dictionary variables = { @@ -54,7 +58,11 @@ def save_variables( "seed": seed, "num_cpu_threads_per_process": num_cpu_threads_per_process, "convert_to_safetensors": convert_to_safetensors, - "convert_to_ckpt": convert_to_ckpt + "convert_to_ckpt": convert_to_ckpt, + "cache_latent": cache_latent, + "caption_extention": caption_extention, + "use_safetensors": use_safetensors, + "enable_bucket": enable_bucket } # Save the data to the selected file @@ -73,7 +81,6 @@ def load_variables(file_path): my_data.get("v2", None), my_data.get("v_model", None), my_data.get("logging_dir", None), - # my_data.get("model_list", None), my_data.get("train_data_dir", None), my_data.get("reg_data_dir", None), my_data.get("output_dir", None), @@ -89,7 +96,11 @@ def load_variables(file_path): my_data.get("seed", None), my_data.get("num_cpu_threads_per_process", None), my_data.get("convert_to_safetensors", None), - my_data.get("convert_to_ckpt", None) + my_data.get("convert_to_ckpt", None), + my_data.get("cache_latent", None), + my_data.get("caption_extention", None), + my_data.get("use_safetensors", None), + my_data.get("enable_bucket", None), ) @@ -114,7 +125,10 @@ def train_model( num_cpu_threads_per_process, convert_to_safetensors, convert_to_ckpt, - cache_latent_input + cache_latent, + caption_extention, + use_safetensors, + enable_bucket ): def save_inference_file(output_dir, v2, v_model): # Copy inference model for v2 if required @@ -170,8 +184,12 @@ def train_model( run_cmd += " --v2" if v_model: run_cmd += " --v_parameterization" - if cache_latent_input: + if cache_latent: run_cmd += " --cache_latents" + if use_safetensors: + run_cmd += " --use_safetensors" + if enable_bucket: + run_cmd += " --enable_bucket" run_cmd += f" --pretrained_model_name_or_path={pretrained_model_name_or_path}" run_cmd += f" --train_data_dir={train_data_dir}" run_cmd += f" --reg_data_dir={reg_data_dir}" @@ -189,6 +207,7 @@ def train_model( run_cmd += f" --seed={seed}" run_cmd += f" --save_precision={save_precision}" run_cmd += f" --logging_dir={logging_dir}" + run_cmd += f" --caption_extention={caption_extention}" print(run_cmd) # Run the command @@ -207,7 +226,8 @@ def train_model( save_inference_file(output_dir, v2, v_model) if convert_to_safetensors: - print(f"Converting diffuser model {last_dir} to {last_dir}.safetensors") + print( + f"Converting diffuser model {last_dir} to {last_dir}.safetensors") os.system( f"python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}" ) @@ -223,26 +243,33 @@ def train_model( def set_pretrained_model_name_or_path_input(value, v2, v_model): # define a list of substrings to search for - substrings_v2 = ["stable-diffusion-2-1-base", "stable-diffusion-2-base"] + substrings_v2 = ["stabilityai/stable-diffusion-2-1-base", "stabilityai/stable-diffusion-2-base"] # check if $v2 and $v_model are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list if str(value) in substrings_v2: print("SD v2 model detected. Setting --v2 parameter") v2 = True v_model = False - value = "stabilityai/{}".format(value) return value, v2, v_model # define a list of substrings to search for v-objective - substrings_v_model = ["stable-diffusion-2-1", "stable-diffusion-2"] + substrings_v_model = ["stabilityai/stable-diffusion-2-1", "stabilityai/stable-diffusion-2"] # check if $v2 and $v_model are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_model list if str(value) in substrings_v_model: print("SD v2 v_model detected. Setting --v2 parameter and --v_parameterization") v2 = True v_model = True - value = "stabilityai/{}".format(value) + + return value, v2, v_model + + # define a list of substrings to v1.x + substrings_v1_model = ["CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"] + + if str(value) in substrings_v1_model: + v2 = False + v_model = False return value, v2, v_model @@ -263,10 +290,11 @@ with interface: gr.Markdown("Enter kohya finetuner parameter using this interface.") with gr.Accordion("Configuration File Load/Save", open=False): with gr.Row(): - config_file_name = gr.inputs.Textbox(label="Config file name", default="") + config_file_name = gr.inputs.Textbox( + label="Config file name", default="") b1 = gr.Button("Load config") b2 = gr.Button("Save config") - with gr.Tab("model"): + with gr.Tab("Source model"): # Define the input elements with gr.Row(): pretrained_model_name_or_path_input = gr.inputs.Textbox( @@ -277,10 +305,12 @@ with interface: label="Model Quick Pick", choices=[ "custom", - "stable-diffusion-2-1-base", - "stable-diffusion-2-base", - "stable-diffusion-2-1", - "stable-diffusion-2", + "stabilityai/stable-diffusion-2-1-base", + "stabilityai/stable-diffusion-2-base", + "stabilityai/stable-diffusion-2-1", + "stabilityai/stable-diffusion-2", + "runwayml/stable-diffusion-v1-5", + "CompVis/stable-diffusion-v1-4" ], value="custom", ) @@ -290,28 +320,30 @@ with interface: model_list.change( set_pretrained_model_name_or_path_input, inputs=[model_list, v2_input, v_model_input], - outputs=[pretrained_model_name_or_path_input, v2_input, v_model_input], + outputs=[pretrained_model_name_or_path_input, + v2_input, v_model_input], ) - with gr.Tab("training dataset and output directory"): - train_data_dir_input = gr.inputs.Textbox( - label="Image folder", placeholder="directory where the training folders containing the images are located" - ) - reg_data_dir_input = gr.inputs.Textbox( - label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located" - ) - output_dir_input = gr.inputs.Textbox( - label="Output directory", - placeholder="directory to output trained model", - ) - logging_dir_input = gr.inputs.Textbox( - label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory" - ) - max_resolution_input = gr.inputs.Textbox( - label="Max resolution", default="512,512" - ) - with gr.Tab("training parameters"): + with gr.Tab("Directories"): with gr.Row(): - learning_rate_input = gr.inputs.Textbox(label="Learning rate", default=1e-6) + train_data_dir_input = gr.inputs.Textbox( + label="Image folder", placeholder="directory where the training folders containing the images are located" + ) + reg_data_dir_input = gr.inputs.Textbox( + label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located" + ) + + with gr.Row(): + output_dir_input = gr.inputs.Textbox( + label="Output directory", + placeholder="directory to output trained model", + ) + logging_dir_input = gr.inputs.Textbox( + label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory" + ) + with gr.Tab("Training parameters"): + with gr.Row(): + learning_rate_input = gr.inputs.Textbox( + label="Learning rate", default=1e-6) lr_scheduler_input = gr.Dropdown( label="LR Scheduler", choices=[ @@ -330,10 +362,10 @@ with interface: label="Train batch size", default=1 ) epoch_input = gr.inputs.Textbox(label="Epoch", default=1) - with gr.Row(): save_every_n_epochs_input = gr.inputs.Textbox( label="Save every N epochs", default=1 ) + with gr.Row(): mixed_precision_input = gr.Dropdown( label="Mixed precision", choices=[ @@ -352,22 +384,36 @@ with interface: ], value="fp16", ) - with gr.Row(): - seed_input = gr.inputs.Textbox(label="Seed", default=1234) num_cpu_threads_per_process_input = gr.inputs.Textbox( label="Number of CPU threads per process", default=4 ) + with gr.Row(): + seed_input = gr.inputs.Textbox(label="Seed", default=1234) + max_resolution_input = gr.inputs.Textbox( + label="Max resolution", default="512,512" + ) + caption_extention_input = gr.inputs.Textbox( + label="Caption Extension", placeholder="(Optional) Extension for caption files. default: .caption") + + with gr.Row(): + use_safetensors_input = gr.inputs.Checkbox( + label="Use safetensor when saving checkpoint", default=False + ) + enable_bucket_input = gr.inputs.Checkbox( + label="Enable buckets", default=False + ) cache_latent_input = gr.inputs.Checkbox( label="Cache latent", default=True ) - with gr.Tab("model conveersion"): + + with gr.Tab("Model conversion"): convert_to_safetensors_input = gr.inputs.Checkbox( label="Convert to SafeTensors", default=False ) convert_to_ckpt_input = gr.inputs.Checkbox( label="Convert to CKPT", default=False ) - + b3 = gr.Button("Run") b1.click( @@ -393,10 +439,14 @@ with interface: seed_input, num_cpu_threads_per_process_input, convert_to_safetensors_input, - convert_to_ckpt_input + convert_to_ckpt_input, + cache_latent_input, + caption_extention_input, + use_safetensors_input, + enable_bucket_input ] ) - + b2.click( save_variables, inputs=[ @@ -420,7 +470,11 @@ with interface: seed_input, num_cpu_threads_per_process_input, convert_to_safetensors_input, - convert_to_ckpt_input + convert_to_ckpt_input, + cache_latent_input, + caption_extention_input, + use_safetensors_input, + enable_bucket_input ] ) b3.click( @@ -446,7 +500,10 @@ with interface: num_cpu_threads_per_process_input, convert_to_safetensors_input, convert_to_ckpt_input, - cache_latent_input + cache_latent_input, + caption_extention_input, + use_safetensors_input, + enable_bucket_input ] )