diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 948ba54..926c643 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -244,6 +244,27 @@ def train_model( f"{output_dir}/last.yaml", ) + if pretrained_model_name_or_path == "": + msgbox("Source model information is missing") + return + + if train_data_dir == "": + msgbox("Image folder path is missing") + return + + if not os.path.exists(train_data_dir): + msgbox("Image folder does not exist") + return + + if reg_data_dir != "": + if not os.path.exists(reg_data_dir): + msgbox("Regularisation folder does not exist") + return + + if output_dir == "": + msgbox("Output folder path is missing") + return + # Get a list of all subfolders in train_data_dir subfolders = [ f for f in os.listdir(train_data_dir) @@ -701,7 +722,7 @@ with interface: with gr.Row(): seed_input = gr.Textbox(label="Seed", value=1234) max_resolution_input = gr.Textbox(label="Max resolution", - value="512,512") + value="512,512", placeholder="512,512") with gr.Row(): caption_extention_input = gr.Textbox( label="Caption Extension", @@ -715,26 +736,27 @@ with interface: step=1, label="Stop text encoder training", ) - with gr.Row(): - use_safetensors_input = gr.Checkbox( - label="Use safetensor when saving", value=False) - enable_bucket_input = gr.Checkbox(label="Enable buckets", - value=True) - cache_latent_input = gr.Checkbox(label="Cache latent", value=True) - gradient_checkpointing_input = gr.Checkbox( - label="Gradient checkpointing", value=False) with gr.Row(): full_fp16_input = gr.Checkbox( label="Full fp16 training (experimental)", value=False) no_token_padding_input = gr.Checkbox(label="No token padding", value=False) + use_safetensors_input = gr.Checkbox( + label="Use safetensor when saving", value=False) + + gradient_checkpointing_input = gr.Checkbox( + label="Gradient checkpointing", value=False) + with gr.Row(): + enable_bucket_input = gr.Checkbox(label="Enable buckets", + value=True) + cache_latent_input = gr.Checkbox(label="Cache latent", value=True) use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam", value=True) xformers_input = gr.Checkbox(label="Use xformers", value=True) with gr.Tab("Model conversion"): convert_to_safetensors_input = gr.Checkbox( - label="Convert to SafeTensors", value=False) + label="Convert to SafeTensors", value=True) convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT", value=False)