Remove redundant code

This commit is contained in:
bmaltais 2022-12-15 07:48:29 -05:00
parent 2b5c7312f1
commit 969eea519f
2 changed files with 29 additions and 94 deletions

View File

@ -116,7 +116,7 @@ python .\dreambooth_gui.py
## Support ## Support
Drop by the discord server for support: https://discord.com/channels/1023277529424986162/1026874833193140285 Drop by the discord server for support: https://discord.com/channels/1041518562487058594/1041518563242020906
## Manual Script Execution ## Manual Script Execution

View File

@ -13,7 +13,8 @@ from os.path import join
from easygui import fileopenbox, filesavebox, enterbox, diropenbox, msgbox from easygui import fileopenbox, filesavebox, enterbox, diropenbox, msgbox
def save_variables( def save_configuration(
save_as,
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
@ -48,6 +49,17 @@ def save_variables(
): ):
original_file_path = file_path original_file_path = file_path
save_as_bool = True if save_as.get("label") == "True" else False
if save_as_bool:
print("Save as...")
file_path = filesavebox(
"Select the config file to save",
default="finetune.json",
filetypes="*.json",
)
else:
print("Save...")
if file_path == None or file_path == "": if file_path == None or file_path == "":
file_path = filesavebox( file_path = filesavebox(
"Select the config file to save", "Select the config file to save",
@ -56,8 +68,7 @@ def save_variables(
) )
if file_path == None: if file_path == None:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action return original_file_path # In case a file_path was provided and the user decide to cancel the open action
return file_path
# Return the values of the variables as a dictionary # Return the values of the variables as a dictionary
variables = { variables = {
@ -100,91 +111,7 @@ def save_variables(
return file_path return file_path
def save_as_variables( def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization,
logging_dir,
train_data_dir,
reg_data_dir,
output_dir,
max_resolution,
learning_rate,
lr_scheduler,
lr_warmup,
train_batch_size,
epoch,
save_every_n_epochs,
mixed_precision,
save_precision,
seed,
num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent,
caption_extention,
use_safetensors,
enable_bucket,
gradient_checkpointing,
full_fp16,
no_token_padding,
stop_text_encoder_training,
use_8bit_adam,
xformers,
):
original_file_path = file_path
file_path = filesavebox(
"Select the config file to save", default="finetune.json", filetypes="*.json"
)
if file_path == None:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
return file_path
# Return the values of the variables as a dictionary
variables = {
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"v2": v2,
"v_parameterization": v_parameterization,
"logging_dir": logging_dir,
"train_data_dir": train_data_dir,
"reg_data_dir": reg_data_dir,
"output_dir": output_dir,
"max_resolution": max_resolution,
"learning_rate": learning_rate,
"lr_scheduler": lr_scheduler,
"lr_warmup": lr_warmup,
"train_batch_size": train_batch_size,
"epoch": epoch,
"save_every_n_epochs": save_every_n_epochs,
"mixed_precision": mixed_precision,
"save_precision": save_precision,
"seed": seed,
"num_cpu_threads_per_process": num_cpu_threads_per_process,
"convert_to_safetensors": convert_to_safetensors,
"convert_to_ckpt": convert_to_ckpt,
"cache_latent": cache_latent,
"caption_extention": caption_extention,
"use_safetensors": use_safetensors,
"enable_bucket": enable_bucket,
"gradient_checkpointing": gradient_checkpointing,
"full_fp16": full_fp16,
"no_token_padding": no_token_padding,
"stop_text_encoder_training": stop_text_encoder_training,
"use_8bit_adam": use_8bit_adam,
"xformers": xformers,
}
# Save the data to the selected file
with open(file_path, "w") as file:
json.dump(variables, file)
return file_path
def open_config_file(
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
@ -531,6 +458,8 @@ if os.path.exists("./style.css"):
interface = gr.Blocks(css=css) interface = gr.Blocks(css=css)
with interface: with interface:
dummy_true = gr.Label(value=True, visible=False)
dummy_false = gr.Label(value=False, visible=False)
gr.Markdown("Enter kohya finetuner parameter using this interface.") gr.Markdown("Enter kohya finetuner parameter using this interface.")
with gr.Accordion("Configuration File Load/Save", open=False): with gr.Accordion("Configuration File Load/Save", open=False):
with gr.Row(): with gr.Row():
@ -589,7 +518,9 @@ with interface:
placeholder="Directory where the training folders containing the images are located", placeholder="Directory where the training folders containing the images are located",
) )
train_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") train_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small")
train_data_dir_input_folder.click(get_folder_path, outputs=train_data_dir_input) train_data_dir_input_folder.click(
get_folder_path, outputs=train_data_dir_input
)
reg_data_dir_input = gr.Textbox( reg_data_dir_input = gr.Textbox(
label="Regularisation folder", label="Regularisation folder",
placeholder="(Optional) Directory where where the regularization folders containing the images are located", placeholder="(Optional) Directory where where the regularization folders containing the images are located",
@ -714,7 +645,7 @@ with interface:
button_run = gr.Button("Run") button_run = gr.Button("Run")
button_open_config.click( button_open_config.click(
open_config_file, open_configuration,
inputs=[ inputs=[
config_file_name, config_file_name,
pretrained_model_name_or_path_input, pretrained_model_name_or_path_input,
@ -783,9 +714,12 @@ with interface:
], ],
) )
save_as = True
not_save_as = False
button_save_config.click( button_save_config.click(
save_variables, save_configuration,
inputs=[ inputs=[
dummy_false,
config_file_name, config_file_name,
pretrained_model_name_or_path_input, pretrained_model_name_or_path_input,
v2_input, v2_input,
@ -822,8 +756,9 @@ with interface:
) )
button_save_as_config.click( button_save_as_config.click(
save_as_variables, save_configuration,
inputs=[ inputs=[
dummy_true,
config_file_name, config_file_name,
pretrained_model_name_or_path_input, pretrained_model_name_or_path_input,
v2_input, v2_input,