Make selected tab configurable with UI config

This commit is contained in:
space-nuko 2023-03-29 13:04:02 -05:00
parent 22bcc7be42
commit 67955ca9e5
4 changed files with 34 additions and 16 deletions

View File

@ -94,6 +94,9 @@ def send_gradio_gallery_to_image(x):
def visit(x, func, path=""): def visit(x, func, path=""):
if hasattr(x, 'children'): if hasattr(x, 'children'):
if isinstance(x, gr.Tabs) and x.elem_id is not None:
# Tabs element can't have a label, have to use elem_id instead
func(f"{path}/Tabs@{x.elem_id}", x)
for c in x.children: for c in x.children:
visit(c, func, path) visit(c, func, path)
elif x.label is not None: elif x.label is not None:
@ -1048,7 +1051,7 @@ def create_ui():
with gr.Row(variant="compact").style(equal_height=False): with gr.Row(variant="compact").style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"): with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"): with gr.Tab(label="Create embedding", id="create_embedding"):
new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
@ -1061,7 +1064,7 @@ def create_ui():
with gr.Column(): with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
with gr.Tab(label="Create hypernetwork"): with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
@ -1079,7 +1082,7 @@ def create_ui():
with gr.Column(): with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
with gr.Tab(label="Preprocess images"): with gr.Tab(label="Preprocess images", id="preprocess_images"):
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
@ -1146,7 +1149,7 @@ def create_ui():
def get_textual_inversion_template_names(): def get_textual_inversion_template_names():
return sorted([x for x in textual_inversion.textual_inversion_templates]) return sorted([x for x in textual_inversion.textual_inversion_templates])
with gr.Tab(label="Train"): with gr.Tab(label="Train", id="train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with FormRow(): with FormRow():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
@ -1479,7 +1482,7 @@ def create_ui():
current_row.__exit__() current_row.__exit__()
current_tab.__exit__() current_tab.__exit__()
with gr.TabItem("Actions"): with gr.TabItem("Actions", id="actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization") download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
@ -1487,7 +1490,7 @@ def create_ui():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
with gr.TabItem("Licenses"): with gr.TabItem("Licenses", id="licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages") gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@ -1735,12 +1738,27 @@ def create_ui():
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
def check_tab_id(tab_id):
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
if type(tab_id) == str:
tab_ids = [t.id for t in tab_items]
return tab_id in tab_ids
elif type(tab_id) == int:
return tab_id >= 0 and tab_id < len(tab_items)
else:
return False
if type(x) == gr.Tabs:
apply_field(x, 'selected', check_tab_id)
visit(txt2img_interface, loadsave, "txt2img") visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img") visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras") visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger") visit(modelmerger_interface, loadsave, "modelmerger")
visit(train_interface, loadsave, "train") visit(train_interface, loadsave, "train")
loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file: with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4) json.dump(ui_settings, file, indent=4)

View File

@ -294,7 +294,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as ui: with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"): with gr.TabItem("Installed", id="installed"):
with gr.Row(elem_id="extensions_installed_top"): with gr.Row(elem_id="extensions_installed_top"):
apply = gr.Button(value="Apply and restart UI", variant="primary") apply = gr.Button(value="Apply and restart UI", variant="primary")
@ -327,7 +327,7 @@ def create_ui():
outputs=[extensions_table, info], outputs=[extensions_table, info],
) )
with gr.TabItem("Available"): with gr.TabItem("Available", id="available"):
with gr.Row(): with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False) available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
@ -374,7 +374,7 @@ def create_ui():
outputs=[available_extensions_table, install_result] outputs=[available_extensions_table, install_result]
) )
with gr.TabItem("Install from URL"): with gr.TabItem("Install from URL", id="install_from_url"):
install_url = gr.Text(label="URL for extension's git repository") install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
install_button = gr.Button(value="Install", variant="primary") install_button = gr.Button(value="Install", variant="primary")

View File

@ -241,9 +241,9 @@ def create_ui(container, button, tabname):
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages: for page in ui.stored_extra_pages:
with gr.Tab(page.title): with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")):
page_elem = gr.HTML(page.create_html(ui.tabname)) page_elem = gr.HTML("")
ui.pages.append(page_elem) ui.pages.append(page_elem)
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
@ -284,7 +284,7 @@ def setup_ui(ui, gallery):
def save_preview(index, images, filename): def save_preview(index, images, filename):
if len(images) == 0: if len(images) == 0:
print("There is no image in gallery to save as a preview.") print("There is no image in gallery to save as a preview.")
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] return ["" for page in ui.stored_extra_pages]
index = int(index) index = int(index)
index = 0 if index < 0 else index index = 0 if index < 0 else index
@ -309,7 +309,7 @@ def setup_ui(ui, gallery):
else: else:
image.save(filename) image.save(filename)
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] return ["" for page in ui.stored_extra_pages]
ui.button_save_preview.click( ui.button_save_preview.click(
fn=save_preview, fn=save_preview,

View File

@ -9,13 +9,13 @@ def create_ui():
with gr.Row().style(equal_height=False, variant='compact'): with gr.Row().style(equal_height=False, variant='compact'):
with gr.Column(variant='compact'): with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"): with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single: with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")