Update UI code
This commit is contained in:
parent
a7cd798abf
commit
dddcffbe51
@ -244,6 +244,27 @@ def train_model(
|
||||
f"{output_dir}/last.yaml",
|
||||
)
|
||||
|
||||
if pretrained_model_name_or_path == "":
|
||||
msgbox("Source model information is missing")
|
||||
return
|
||||
|
||||
if train_data_dir == "":
|
||||
msgbox("Image folder path is missing")
|
||||
return
|
||||
|
||||
if not os.path.exists(train_data_dir):
|
||||
msgbox("Image folder does not exist")
|
||||
return
|
||||
|
||||
if reg_data_dir != "":
|
||||
if not os.path.exists(reg_data_dir):
|
||||
msgbox("Regularisation folder does not exist")
|
||||
return
|
||||
|
||||
if output_dir == "":
|
||||
msgbox("Output folder path is missing")
|
||||
return
|
||||
|
||||
# Get a list of all subfolders in train_data_dir
|
||||
subfolders = [
|
||||
f for f in os.listdir(train_data_dir)
|
||||
@ -701,7 +722,7 @@ with interface:
|
||||
with gr.Row():
|
||||
seed_input = gr.Textbox(label="Seed", value=1234)
|
||||
max_resolution_input = gr.Textbox(label="Max resolution",
|
||||
value="512,512")
|
||||
value="512,512", placeholder="512,512")
|
||||
with gr.Row():
|
||||
caption_extention_input = gr.Textbox(
|
||||
label="Caption Extension",
|
||||
@ -715,26 +736,27 @@ with interface:
|
||||
step=1,
|
||||
label="Stop text encoder training",
|
||||
)
|
||||
with gr.Row():
|
||||
use_safetensors_input = gr.Checkbox(
|
||||
label="Use safetensor when saving", value=False)
|
||||
enable_bucket_input = gr.Checkbox(label="Enable buckets",
|
||||
value=True)
|
||||
cache_latent_input = gr.Checkbox(label="Cache latent", value=True)
|
||||
gradient_checkpointing_input = gr.Checkbox(
|
||||
label="Gradient checkpointing", value=False)
|
||||
with gr.Row():
|
||||
full_fp16_input = gr.Checkbox(
|
||||
label="Full fp16 training (experimental)", value=False)
|
||||
no_token_padding_input = gr.Checkbox(label="No token padding",
|
||||
value=False)
|
||||
use_safetensors_input = gr.Checkbox(
|
||||
label="Use safetensor when saving", value=False)
|
||||
|
||||
gradient_checkpointing_input = gr.Checkbox(
|
||||
label="Gradient checkpointing", value=False)
|
||||
with gr.Row():
|
||||
enable_bucket_input = gr.Checkbox(label="Enable buckets",
|
||||
value=True)
|
||||
cache_latent_input = gr.Checkbox(label="Cache latent", value=True)
|
||||
use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam",
|
||||
value=True)
|
||||
xformers_input = gr.Checkbox(label="Use xformers", value=True)
|
||||
|
||||
with gr.Tab("Model conversion"):
|
||||
convert_to_safetensors_input = gr.Checkbox(
|
||||
label="Convert to SafeTensors", value=False)
|
||||
label="Convert to SafeTensors", value=True)
|
||||
convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT",
|
||||
value=False)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user