Update UI code
This commit is contained in:
parent
a7cd798abf
commit
dddcffbe51
@ -244,6 +244,27 @@ def train_model(
|
|||||||
f"{output_dir}/last.yaml",
|
f"{output_dir}/last.yaml",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if pretrained_model_name_or_path == "":
|
||||||
|
msgbox("Source model information is missing")
|
||||||
|
return
|
||||||
|
|
||||||
|
if train_data_dir == "":
|
||||||
|
msgbox("Image folder path is missing")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(train_data_dir):
|
||||||
|
msgbox("Image folder does not exist")
|
||||||
|
return
|
||||||
|
|
||||||
|
if reg_data_dir != "":
|
||||||
|
if not os.path.exists(reg_data_dir):
|
||||||
|
msgbox("Regularisation folder does not exist")
|
||||||
|
return
|
||||||
|
|
||||||
|
if output_dir == "":
|
||||||
|
msgbox("Output folder path is missing")
|
||||||
|
return
|
||||||
|
|
||||||
# Get a list of all subfolders in train_data_dir
|
# Get a list of all subfolders in train_data_dir
|
||||||
subfolders = [
|
subfolders = [
|
||||||
f for f in os.listdir(train_data_dir)
|
f for f in os.listdir(train_data_dir)
|
||||||
@ -701,7 +722,7 @@ with interface:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed_input = gr.Textbox(label="Seed", value=1234)
|
seed_input = gr.Textbox(label="Seed", value=1234)
|
||||||
max_resolution_input = gr.Textbox(label="Max resolution",
|
max_resolution_input = gr.Textbox(label="Max resolution",
|
||||||
value="512,512")
|
value="512,512", placeholder="512,512")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
caption_extention_input = gr.Textbox(
|
caption_extention_input = gr.Textbox(
|
||||||
label="Caption Extension",
|
label="Caption Extension",
|
||||||
@ -715,26 +736,27 @@ with interface:
|
|||||||
step=1,
|
step=1,
|
||||||
label="Stop text encoder training",
|
label="Stop text encoder training",
|
||||||
)
|
)
|
||||||
with gr.Row():
|
|
||||||
use_safetensors_input = gr.Checkbox(
|
|
||||||
label="Use safetensor when saving", value=False)
|
|
||||||
enable_bucket_input = gr.Checkbox(label="Enable buckets",
|
|
||||||
value=True)
|
|
||||||
cache_latent_input = gr.Checkbox(label="Cache latent", value=True)
|
|
||||||
gradient_checkpointing_input = gr.Checkbox(
|
|
||||||
label="Gradient checkpointing", value=False)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
full_fp16_input = gr.Checkbox(
|
full_fp16_input = gr.Checkbox(
|
||||||
label="Full fp16 training (experimental)", value=False)
|
label="Full fp16 training (experimental)", value=False)
|
||||||
no_token_padding_input = gr.Checkbox(label="No token padding",
|
no_token_padding_input = gr.Checkbox(label="No token padding",
|
||||||
value=False)
|
value=False)
|
||||||
|
use_safetensors_input = gr.Checkbox(
|
||||||
|
label="Use safetensor when saving", value=False)
|
||||||
|
|
||||||
|
gradient_checkpointing_input = gr.Checkbox(
|
||||||
|
label="Gradient checkpointing", value=False)
|
||||||
|
with gr.Row():
|
||||||
|
enable_bucket_input = gr.Checkbox(label="Enable buckets",
|
||||||
|
value=True)
|
||||||
|
cache_latent_input = gr.Checkbox(label="Cache latent", value=True)
|
||||||
use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam",
|
use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam",
|
||||||
value=True)
|
value=True)
|
||||||
xformers_input = gr.Checkbox(label="Use xformers", value=True)
|
xformers_input = gr.Checkbox(label="Use xformers", value=True)
|
||||||
|
|
||||||
with gr.Tab("Model conversion"):
|
with gr.Tab("Model conversion"):
|
||||||
convert_to_safetensors_input = gr.Checkbox(
|
convert_to_safetensors_input = gr.Checkbox(
|
||||||
label="Convert to SafeTensors", value=False)
|
label="Convert to SafeTensors", value=True)
|
||||||
convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT",
|
convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT",
|
||||||
value=False)
|
value=False)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user