From 71765ae243ec615efc4eab133336f5d7accb1fb2 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 13 Dec 2022 17:52:25 -0500 Subject: [PATCH] Fix steps issue when regularisation images are used --- dreambooth_gui.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 2300fd0..f2ba27c 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -1,3 +1,5 @@ +# v1: initial release + import gradio as gr import json import math @@ -177,9 +179,15 @@ def train_model( # Print the result # print(f"{total_steps} total steps") + if reg_data_dir == "": + reg_factor = 1 + else: + print("Regularisation images are used... Will double the number of steps required...") + reg_factor = 2 + # calculate max_train_steps max_train_steps = int( - math.ceil(float(total_steps) / int(train_batch_size) * int(epoch)) + math.ceil(float(total_steps) / int(train_batch_size) * int(epoch) * int(reg_factor)) ) print(f"max_train_steps = {max_train_steps}") @@ -332,16 +340,16 @@ with interface: with gr.Tab("Directories"): with gr.Row(): train_data_dir_input = gr.Textbox( - label="Image folder", placeholder="directory where the training folders containing the images are located" + label="Image folder", placeholder="Directory where the training folders containing the images are located" ) reg_data_dir_input = gr.Textbox( - label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located" + label="Regularisation folder", placeholder="(Optional) Directory where where the regularization folders containing the images are located" ) with gr.Row(): output_dir_input = gr.Textbox( label="Output directory", - placeholder="directory to output trained model", + placeholder="Directory to output trained model", ) logging_dir_input = gr.Textbox( label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory"