Fix deprecation warning

This commit is contained in:
bmaltais 2022-12-13 11:26:21 -05:00
parent 0d52ff4842
commit 01eb9486d3

View File

@ -41,7 +41,6 @@ def save_variables(
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"v2": v2,
"v_model": v_model,
# "model_list": model_list,
"logging_dir": logging_dir,
"train_data_dir": train_data_dir,
"reg_data_dir": reg_data_dir,
@ -280,29 +279,25 @@ def set_pretrained_model_name_or_path_input(value, v2, v_model):
return value, v2, v_model
# Define the output element
output = gr.outputs.Textbox(label="Values of variables")
interface = gr.Blocks()
with interface:
gr.Markdown("Enter kohya finetuner parameter using this interface.")
with gr.Accordion("Configuration File Load/Save", open=False):
with gr.Row():
config_file_name = gr.inputs.Textbox(
label="Config file name", default="")
config_file_name = gr.Textbox(
label="Config file name")
b1 = gr.Button("Load config")
b2 = gr.Button("Save config")
with gr.Tab("Source model"):
# Define the input elements
with gr.Row():
pretrained_model_name_or_path_input = gr.inputs.Textbox(
pretrained_model_name_or_path_input = gr.Textbox(
label="Pretrained model name or path",
placeholder="enter the path to custom model or name of pretrained model",
)
model_list = gr.Dropdown(
label="Model Quick Pick",
label="(Optional) Model Quick Pick",
choices=[
"custom",
"stabilityai/stable-diffusion-2-1-base",
@ -312,11 +307,10 @@ with interface:
"runwayml/stable-diffusion-v1-5",
"CompVis/stable-diffusion-v1-4"
],
value="custom",
)
with gr.Row():
v2_input = gr.inputs.Checkbox(label="v2", default=True)
v_model_input = gr.inputs.Checkbox(label="v_model", default=False)
v2_input = gr.Checkbox(label="v2", value=True)
v_model_input = gr.Checkbox(label="v_model", value=False)
model_list.change(
set_pretrained_model_name_or_path_input,
inputs=[model_list, v2_input, v_model_input],
@ -325,25 +319,25 @@ with interface:
)
with gr.Tab("Directories"):
with gr.Row():
train_data_dir_input = gr.inputs.Textbox(
train_data_dir_input = gr.Textbox(
label="Image folder", placeholder="directory where the training folders containing the images are located"
)
reg_data_dir_input = gr.inputs.Textbox(
reg_data_dir_input = gr.Textbox(
label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located"
)
with gr.Row():
output_dir_input = gr.inputs.Textbox(
output_dir_input = gr.Textbox(
label="Output directory",
placeholder="directory to output trained model",
)
logging_dir_input = gr.inputs.Textbox(
logging_dir_input = gr.Textbox(
label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory"
)
with gr.Tab("Training parameters"):
with gr.Row():
learning_rate_input = gr.inputs.Textbox(
label="Learning rate", default=1e-6)
learning_rate_input = gr.Textbox(
label="Learning rate", value=1e-6)
lr_scheduler_input = gr.Dropdown(
label="LR Scheduler",
choices=[
@ -356,14 +350,14 @@ with interface:
],
value="constant",
)
lr_warmup_input = gr.inputs.Textbox(label="LR warmup", default=0)
lr_warmup_input = gr.Textbox(label="LR warmup", value=0)
with gr.Row():
train_batch_size_input = gr.inputs.Textbox(
label="Train batch size", default=1
train_batch_size_input = gr.Textbox(
label="Train batch size", value=1
)
epoch_input = gr.inputs.Textbox(label="Epoch", default=1)
save_every_n_epochs_input = gr.inputs.Textbox(
label="Save every N epochs", default=1
epoch_input = gr.Textbox(label="Epoch", value=1)
save_every_n_epochs_input = gr.Textbox(
label="Save every N epochs", value=1
)
with gr.Row():
mixed_precision_input = gr.Dropdown(
@ -384,34 +378,34 @@ with interface:
],
value="fp16",
)
num_cpu_threads_per_process_input = gr.inputs.Textbox(
label="Number of CPU threads per process", default=4
num_cpu_threads_per_process_input = gr.Textbox(
label="Number of CPU threads per process", value=4
)
with gr.Row():
seed_input = gr.inputs.Textbox(label="Seed", default=1234)
max_resolution_input = gr.inputs.Textbox(
label="Max resolution", default="512,512"
seed_input = gr.Textbox(label="Seed", value=1234)
max_resolution_input = gr.Textbox(
label="Max resolution", value="512,512"
)
caption_extention_input = gr.inputs.Textbox(
caption_extention_input = gr.Textbox(
label="Caption Extension", placeholder="(Optional) Extension for caption files. default: .caption")
with gr.Row():
use_safetensors_input = gr.inputs.Checkbox(
label="Use safetensor when saving checkpoint", default=False
use_safetensors_input = gr.Checkbox(
label="Use safetensor when saving checkpoint", value=False
)
enable_bucket_input = gr.inputs.Checkbox(
label="Enable buckets", default=False
enable_bucket_input = gr.Checkbox(
label="Enable buckets", value=False
)
cache_latent_input = gr.inputs.Checkbox(
label="Cache latent", default=True
cache_latent_input = gr.Checkbox(
label="Cache latent", value=True
)
with gr.Tab("Model conversion"):
convert_to_safetensors_input = gr.inputs.Checkbox(
label="Convert to SafeTensors", default=False
convert_to_safetensors_input = gr.Checkbox(
label="Convert to SafeTensors", value=False
)
convert_to_ckpt_input = gr.inputs.Checkbox(
label="Convert to CKPT", default=False
convert_to_ckpt_input = gr.Checkbox(
label="Convert to CKPT", value=False
)
b3 = gr.Button("Run")