Fix deprecation warning

This commit is contained in:
bmaltais 2022-12-13 11:26:21 -05:00
parent 0d52ff4842
commit 01eb9486d3

View File

@ -41,7 +41,6 @@ def save_variables(
"pretrained_model_name_or_path": pretrained_model_name_or_path, "pretrained_model_name_or_path": pretrained_model_name_or_path,
"v2": v2, "v2": v2,
"v_model": v_model, "v_model": v_model,
# "model_list": model_list,
"logging_dir": logging_dir, "logging_dir": logging_dir,
"train_data_dir": train_data_dir, "train_data_dir": train_data_dir,
"reg_data_dir": reg_data_dir, "reg_data_dir": reg_data_dir,
@ -280,29 +279,25 @@ def set_pretrained_model_name_or_path_input(value, v2, v_model):
return value, v2, v_model return value, v2, v_model
# Define the output element
output = gr.outputs.Textbox(label="Values of variables")
interface = gr.Blocks() interface = gr.Blocks()
with interface: with interface:
gr.Markdown("Enter kohya finetuner parameter using this interface.") gr.Markdown("Enter kohya finetuner parameter using this interface.")
with gr.Accordion("Configuration File Load/Save", open=False): with gr.Accordion("Configuration File Load/Save", open=False):
with gr.Row(): with gr.Row():
config_file_name = gr.inputs.Textbox( config_file_name = gr.Textbox(
label="Config file name", default="") label="Config file name")
b1 = gr.Button("Load config") b1 = gr.Button("Load config")
b2 = gr.Button("Save config") b2 = gr.Button("Save config")
with gr.Tab("Source model"): with gr.Tab("Source model"):
# Define the input elements # Define the input elements
with gr.Row(): with gr.Row():
pretrained_model_name_or_path_input = gr.inputs.Textbox( pretrained_model_name_or_path_input = gr.Textbox(
label="Pretrained model name or path", label="Pretrained model name or path",
placeholder="enter the path to custom model or name of pretrained model", placeholder="enter the path to custom model or name of pretrained model",
) )
model_list = gr.Dropdown( model_list = gr.Dropdown(
label="Model Quick Pick", label="(Optional) Model Quick Pick",
choices=[ choices=[
"custom", "custom",
"stabilityai/stable-diffusion-2-1-base", "stabilityai/stable-diffusion-2-1-base",
@ -312,11 +307,10 @@ with interface:
"runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-v1-5",
"CompVis/stable-diffusion-v1-4" "CompVis/stable-diffusion-v1-4"
], ],
value="custom",
) )
with gr.Row(): with gr.Row():
v2_input = gr.inputs.Checkbox(label="v2", default=True) v2_input = gr.Checkbox(label="v2", value=True)
v_model_input = gr.inputs.Checkbox(label="v_model", default=False) v_model_input = gr.Checkbox(label="v_model", value=False)
model_list.change( model_list.change(
set_pretrained_model_name_or_path_input, set_pretrained_model_name_or_path_input,
inputs=[model_list, v2_input, v_model_input], inputs=[model_list, v2_input, v_model_input],
@ -325,25 +319,25 @@ with interface:
) )
with gr.Tab("Directories"): with gr.Tab("Directories"):
with gr.Row(): with gr.Row():
train_data_dir_input = gr.inputs.Textbox( train_data_dir_input = gr.Textbox(
label="Image folder", placeholder="directory where the training folders containing the images are located" label="Image folder", placeholder="directory where the training folders containing the images are located"
) )
reg_data_dir_input = gr.inputs.Textbox( reg_data_dir_input = gr.Textbox(
label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located" label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located"
) )
with gr.Row(): with gr.Row():
output_dir_input = gr.inputs.Textbox( output_dir_input = gr.Textbox(
label="Output directory", label="Output directory",
placeholder="directory to output trained model", placeholder="directory to output trained model",
) )
logging_dir_input = gr.inputs.Textbox( logging_dir_input = gr.Textbox(
label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory" label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory"
) )
with gr.Tab("Training parameters"): with gr.Tab("Training parameters"):
with gr.Row(): with gr.Row():
learning_rate_input = gr.inputs.Textbox( learning_rate_input = gr.Textbox(
label="Learning rate", default=1e-6) label="Learning rate", value=1e-6)
lr_scheduler_input = gr.Dropdown( lr_scheduler_input = gr.Dropdown(
label="LR Scheduler", label="LR Scheduler",
choices=[ choices=[
@ -356,14 +350,14 @@ with interface:
], ],
value="constant", value="constant",
) )
lr_warmup_input = gr.inputs.Textbox(label="LR warmup", default=0) lr_warmup_input = gr.Textbox(label="LR warmup", value=0)
with gr.Row(): with gr.Row():
train_batch_size_input = gr.inputs.Textbox( train_batch_size_input = gr.Textbox(
label="Train batch size", default=1 label="Train batch size", value=1
) )
epoch_input = gr.inputs.Textbox(label="Epoch", default=1) epoch_input = gr.Textbox(label="Epoch", value=1)
save_every_n_epochs_input = gr.inputs.Textbox( save_every_n_epochs_input = gr.Textbox(
label="Save every N epochs", default=1 label="Save every N epochs", value=1
) )
with gr.Row(): with gr.Row():
mixed_precision_input = gr.Dropdown( mixed_precision_input = gr.Dropdown(
@ -384,34 +378,34 @@ with interface:
], ],
value="fp16", value="fp16",
) )
num_cpu_threads_per_process_input = gr.inputs.Textbox( num_cpu_threads_per_process_input = gr.Textbox(
label="Number of CPU threads per process", default=4 label="Number of CPU threads per process", value=4
) )
with gr.Row(): with gr.Row():
seed_input = gr.inputs.Textbox(label="Seed", default=1234) seed_input = gr.Textbox(label="Seed", value=1234)
max_resolution_input = gr.inputs.Textbox( max_resolution_input = gr.Textbox(
label="Max resolution", default="512,512" label="Max resolution", value="512,512"
) )
caption_extention_input = gr.inputs.Textbox( caption_extention_input = gr.Textbox(
label="Caption Extension", placeholder="(Optional) Extension for caption files. default: .caption") label="Caption Extension", placeholder="(Optional) Extension for caption files. default: .caption")
with gr.Row(): with gr.Row():
use_safetensors_input = gr.inputs.Checkbox( use_safetensors_input = gr.Checkbox(
label="Use safetensor when saving checkpoint", default=False label="Use safetensor when saving checkpoint", value=False
) )
enable_bucket_input = gr.inputs.Checkbox( enable_bucket_input = gr.Checkbox(
label="Enable buckets", default=False label="Enable buckets", value=False
) )
cache_latent_input = gr.inputs.Checkbox( cache_latent_input = gr.Checkbox(
label="Cache latent", default=True label="Cache latent", value=True
) )
with gr.Tab("Model conversion"): with gr.Tab("Model conversion"):
convert_to_safetensors_input = gr.inputs.Checkbox( convert_to_safetensors_input = gr.Checkbox(
label="Convert to SafeTensors", default=False label="Convert to SafeTensors", value=False
) )
convert_to_ckpt_input = gr.inputs.Checkbox( convert_to_ckpt_input = gr.Checkbox(
label="Convert to CKPT", default=False label="Convert to CKPT", value=False
) )
b3 = gr.Button("Run") b3 = gr.Button("Run")