update ui with extra training options
This commit is contained in:
parent
eb7ba4b713
commit
4d663055de
@ -1206,6 +1206,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
new_embedding_name = gr.Textbox(label="Name")
|
new_embedding_name = gr.Textbox(label="Name")
|
||||||
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||||
|
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||||
|
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
@ -1247,14 +1249,17 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||||
|
|
||||||
with gr.Tab(label="Train"):
|
with gr.Tab(label="Train"):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images<br/>Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
with gr.Row():
|
||||||
|
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
|
||||||
|
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||||
|
|
||||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
@ -1288,6 +1293,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
new_embedding_name,
|
new_embedding_name,
|
||||||
initialization_text,
|
initialization_text,
|
||||||
nvpt,
|
nvpt,
|
||||||
|
overwrite_old_embedding,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
@ -1303,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
new_hypernetwork_sizes,
|
new_hypernetwork_sizes,
|
||||||
new_hypernetwork_layer_structure,
|
new_hypernetwork_layer_structure,
|
||||||
new_hypernetwork_add_layer_norm,
|
new_hypernetwork_add_layer_norm,
|
||||||
|
overwrite_old_hypernetwork,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user