Fix steps issue when regularisation images are used
This commit is contained in:
parent
416912eee1
commit
71765ae243
@ -1,3 +1,5 @@
|
|||||||
|
# v1: initial release
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
@ -177,9 +179,15 @@ def train_model(
|
|||||||
# Print the result
|
# Print the result
|
||||||
# print(f"{total_steps} total steps")
|
# print(f"{total_steps} total steps")
|
||||||
|
|
||||||
|
if reg_data_dir == "":
|
||||||
|
reg_factor = 1
|
||||||
|
else:
|
||||||
|
print("Regularisation images are used... Will double the number of steps required...")
|
||||||
|
reg_factor = 2
|
||||||
|
|
||||||
# calculate max_train_steps
|
# calculate max_train_steps
|
||||||
max_train_steps = int(
|
max_train_steps = int(
|
||||||
math.ceil(float(total_steps) / int(train_batch_size) * int(epoch))
|
math.ceil(float(total_steps) / int(train_batch_size) * int(epoch) * int(reg_factor))
|
||||||
)
|
)
|
||||||
print(f"max_train_steps = {max_train_steps}")
|
print(f"max_train_steps = {max_train_steps}")
|
||||||
|
|
||||||
@ -332,16 +340,16 @@ with interface:
|
|||||||
with gr.Tab("Directories"):
|
with gr.Tab("Directories"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_data_dir_input = gr.Textbox(
|
train_data_dir_input = gr.Textbox(
|
||||||
label="Image folder", placeholder="directory where the training folders containing the images are located"
|
label="Image folder", placeholder="Directory where the training folders containing the images are located"
|
||||||
)
|
)
|
||||||
reg_data_dir_input = gr.Textbox(
|
reg_data_dir_input = gr.Textbox(
|
||||||
label="Regularisation folder", placeholder="directory where where the regularization folders containing the images are located"
|
label="Regularisation folder", placeholder="(Optional) Directory where where the regularization folders containing the images are located"
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_dir_input = gr.Textbox(
|
output_dir_input = gr.Textbox(
|
||||||
label="Output directory",
|
label="Output directory",
|
||||||
placeholder="directory to output trained model",
|
placeholder="Directory to output trained model",
|
||||||
)
|
)
|
||||||
logging_dir_input = gr.Textbox(
|
logging_dir_input = gr.Textbox(
|
||||||
label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory"
|
label="Logging directory", placeholder="Optional: enable logging and output TensorBoard log to this directory"
|
||||||
|
Loading…
Reference in New Issue
Block a user