Update UI code

This commit is contained in:
bmaltais 2022-12-15 19:04:26 -05:00
parent a7cd798abf
commit dddcffbe51

View File

@ -244,6 +244,27 @@ def train_model(
f"{output_dir}/last.yaml", f"{output_dir}/last.yaml",
) )
if pretrained_model_name_or_path == "":
msgbox("Source model information is missing")
return
if train_data_dir == "":
msgbox("Image folder path is missing")
return
if not os.path.exists(train_data_dir):
msgbox("Image folder does not exist")
return
if reg_data_dir != "":
if not os.path.exists(reg_data_dir):
msgbox("Regularisation folder does not exist")
return
if output_dir == "":
msgbox("Output folder path is missing")
return
# Get a list of all subfolders in train_data_dir # Get a list of all subfolders in train_data_dir
subfolders = [ subfolders = [
f for f in os.listdir(train_data_dir) f for f in os.listdir(train_data_dir)
@ -701,7 +722,7 @@ with interface:
with gr.Row(): with gr.Row():
seed_input = gr.Textbox(label="Seed", value=1234) seed_input = gr.Textbox(label="Seed", value=1234)
max_resolution_input = gr.Textbox(label="Max resolution", max_resolution_input = gr.Textbox(label="Max resolution",
value="512,512") value="512,512", placeholder="512,512")
with gr.Row(): with gr.Row():
caption_extention_input = gr.Textbox( caption_extention_input = gr.Textbox(
label="Caption Extension", label="Caption Extension",
@ -715,26 +736,27 @@ with interface:
step=1, step=1,
label="Stop text encoder training", label="Stop text encoder training",
) )
with gr.Row():
use_safetensors_input = gr.Checkbox(
label="Use safetensor when saving", value=False)
enable_bucket_input = gr.Checkbox(label="Enable buckets",
value=True)
cache_latent_input = gr.Checkbox(label="Cache latent", value=True)
gradient_checkpointing_input = gr.Checkbox(
label="Gradient checkpointing", value=False)
with gr.Row(): with gr.Row():
full_fp16_input = gr.Checkbox( full_fp16_input = gr.Checkbox(
label="Full fp16 training (experimental)", value=False) label="Full fp16 training (experimental)", value=False)
no_token_padding_input = gr.Checkbox(label="No token padding", no_token_padding_input = gr.Checkbox(label="No token padding",
value=False) value=False)
use_safetensors_input = gr.Checkbox(
label="Use safetensor when saving", value=False)
gradient_checkpointing_input = gr.Checkbox(
label="Gradient checkpointing", value=False)
with gr.Row():
enable_bucket_input = gr.Checkbox(label="Enable buckets",
value=True)
cache_latent_input = gr.Checkbox(label="Cache latent", value=True)
use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam", use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam",
value=True) value=True)
xformers_input = gr.Checkbox(label="Use xformers", value=True) xformers_input = gr.Checkbox(label="Use xformers", value=True)
with gr.Tab("Model conversion"): with gr.Tab("Model conversion"):
convert_to_safetensors_input = gr.Checkbox( convert_to_safetensors_input = gr.Checkbox(
label="Convert to SafeTensors", value=False) label="Convert to SafeTensors", value=True)
convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT", convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT",
value=False) value=False)