Fix steps issue when regularisation images are used

This commit is contained in:
bmaltais 2022-12-13 17:52:25 -05:00
parent 416912eee1
commit 71765ae243

View File

@ -1,3 +1,5 @@
# v1: initial release
import gradio as gr
import json
import math
@ -177,9 +179,15 @@ def train_model(
# Print the result
# print(f"{total_steps} total steps")
if reg_data_dir == "":
reg_factor = 1
else:
print("Regularisation images are used... Will double the number of steps required...")
reg_factor = 2
# calculate max_train_steps
max_train_steps = int(
math.ceil(float(total_steps) / int(train_batch_size) * int(epoch))
math.ceil(float(total_steps) / int(train_batch_size) * int(epoch) * int(reg_factor))
)
print(f"max_train_steps = {max_train_steps}")
@ -332,16 +340,16 @@ with interface:
with gr.Tab("Directories"):
with gr.Row():
train_data_dir_input = gr.Textbox(
label="Image folder", placeholder="directory where the training folders containing the images are located"
label="Image folder", placeholder="Directory where the training folders containing the images are located"
)
reg_data_dir_input = gr.Textbox(
label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located"
label="Regularisation folder", placeholder="(Optional) Directory where where the regularization folders containing the images are located"
)
with gr.Row():
output_dir_input = gr.Textbox(
label="Output directory",
placeholder="directory to output trained model",
placeholder="Directory to output trained model",
)
logging_dir_input = gr.Textbox(
label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory"