Fix steps issue when regularisation images are used

This commit is contained in:
bmaltais 2022-12-13 17:52:25 -05:00
parent 416912eee1
commit 71765ae243

View File

@ -1,3 +1,5 @@
# v1: initial release
import gradio as gr import gradio as gr
import json import json
import math import math
@ -177,9 +179,15 @@ def train_model(
# Print the result # Print the result
# print(f"{total_steps} total steps") # print(f"{total_steps} total steps")
if reg_data_dir == "":
reg_factor = 1
else:
print("Regularisation images are used... Will double the number of steps required...")
reg_factor = 2
# calculate max_train_steps # calculate max_train_steps
max_train_steps = int( max_train_steps = int(
math.ceil(float(total_steps) / int(train_batch_size) * int(epoch)) math.ceil(float(total_steps) / int(train_batch_size) * int(epoch) * int(reg_factor))
) )
print(f"max_train_steps = {max_train_steps}") print(f"max_train_steps = {max_train_steps}")
@ -332,16 +340,16 @@ with interface:
with gr.Tab("Directories"): with gr.Tab("Directories"):
with gr.Row(): with gr.Row():
train_data_dir_input = gr.Textbox( train_data_dir_input = gr.Textbox(
label="Image folder", placeholder="directory where the training folders containing the images are located" label="Image folder", placeholder="Directory where the training folders containing the images are located"
) )
reg_data_dir_input = gr.Textbox( reg_data_dir_input = gr.Textbox(
label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located" label="Regularisation folder", placeholder="(Optional) Directory where where the regularization folders containing the images are located"
) )
with gr.Row(): with gr.Row():
output_dir_input = gr.Textbox( output_dir_input = gr.Textbox(
label="Output directory", label="Output directory",
placeholder="directory to output trained model", placeholder="Directory to output trained model",
) )
logging_dir_input = gr.Textbox( logging_dir_input = gr.Textbox(
label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory" label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory"