add checkpoints tab for extra networks UI
This commit is contained in:
parent
91c8d0dcfc
commit
1d8e06d542
@ -20,7 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
preview = None
|
preview = None
|
||||||
for file in previews:
|
for file in previews:
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
preview = self.link_preview(file)
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
|
@ -309,3 +309,10 @@ function updateInput(target){
|
|||||||
Object.defineProperty(e, "target", {value: target})
|
Object.defineProperty(e, "target", {value: target})
|
||||||
target.dispatchEvent(e);
|
target.dispatchEvent(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
var desiredCheckpointName = null;
|
||||||
|
function selectCheckpoint(name){
|
||||||
|
desiredCheckpointName = name;
|
||||||
|
gradioApp().getElementById('change_checkpoint').click()
|
||||||
|
}
|
||||||
|
@ -1560,6 +1560,14 @@ def create_ui():
|
|||||||
outputs=[component, text_settings],
|
outputs=[component, text_settings],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||||
|
button_set_checkpoint.click(
|
||||||
|
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
|
||||||
|
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
||||||
|
inputs=[component_dict['sd_model_checkpoint'], dummy_component],
|
||||||
|
outputs=[component_dict['sd_model_checkpoint'], text_settings],
|
||||||
|
)
|
||||||
|
|
||||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
||||||
|
|
||||||
def get_settings_values():
|
def get_settings_values():
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import os.path
|
import os.path
|
||||||
|
import urllib.parse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -8,12 +10,31 @@ import html
|
|||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
|
||||||
extra_pages = []
|
extra_pages = []
|
||||||
|
allowed_dirs = set()
|
||||||
|
|
||||||
|
|
||||||
def register_page(page):
|
def register_page(page):
|
||||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||||
|
|
||||||
extra_pages.append(page)
|
extra_pages.append(page)
|
||||||
|
allowed_dirs.clear()
|
||||||
|
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
||||||
|
|
||||||
|
|
||||||
|
def add_pages_to_demo(app):
|
||||||
|
def fetch_file(filename: str = ""):
|
||||||
|
from starlette.responses import FileResponse
|
||||||
|
|
||||||
|
if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
|
||||||
|
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||||
|
|
||||||
|
if os.path.splitext(filename)[1].lower() != ".png":
|
||||||
|
raise ValueError(f"File cannot be fetched: {filename}. Only png.")
|
||||||
|
|
||||||
|
# would profit from returning 304
|
||||||
|
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||||
|
|
||||||
|
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPage:
|
class ExtraNetworksPage:
|
||||||
@ -26,6 +47,9 @@ class ExtraNetworksPage:
|
|||||||
def refresh(self):
|
def refresh(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def link_preview(self, filename):
|
||||||
|
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
|
||||||
|
|
||||||
def create_html(self, tabname):
|
def create_html(self, tabname):
|
||||||
view = shared.opts.extra_networks_default_view
|
view = shared.opts.extra_networks_default_view
|
||||||
items_html = ''
|
items_html = ''
|
||||||
@ -54,13 +78,17 @@ class ExtraNetworksPage:
|
|||||||
def create_html_for_item(self, item, tabname):
|
def create_html_for_item(self, item, tabname):
|
||||||
preview = item.get("preview", None)
|
preview = item.get("preview", None)
|
||||||
|
|
||||||
|
onclick = item.get("onclick", None)
|
||||||
|
if onclick is None:
|
||||||
|
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||||
"prompt": item["prompt"],
|
"prompt": item.get("prompt", None),
|
||||||
"tabname": json.dumps(tabname),
|
"tabname": json.dumps(tabname),
|
||||||
"local_preview": json.dumps(item["local_preview"]),
|
"local_preview": json.dumps(item["local_preview"]),
|
||||||
"name": item["name"],
|
"name": item["name"],
|
||||||
"card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
|
"card_clicked": onclick,
|
||||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path):
|
|||||||
parent_path = os.path.abspath(parent_path)
|
parent_path = os.path.abspath(parent_path)
|
||||||
child_path = os.path.abspath(child_path)
|
child_path = os.path.abspath(child_path)
|
||||||
|
|
||||||
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
|
return child_path.startswith(parent_path)
|
||||||
|
|
||||||
|
|
||||||
def setup_ui(ui, gallery):
|
def setup_ui(ui, gallery):
|
||||||
@ -173,7 +201,8 @@ def setup_ui(ui, gallery):
|
|||||||
|
|
||||||
ui.button_save_preview.click(
|
ui.button_save_preview.click(
|
||||||
fn=save_preview,
|
fn=save_preview,
|
||||||
_js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
|
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
||||||
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
||||||
outputs=[*ui.pages]
|
outputs=[*ui.pages]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
38
modules/ui_extra_networks_checkpoints.py
Normal file
38
modules/ui_extra_networks_checkpoints.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import html
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
from modules import shared, ui_extra_networks, sd_models
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('Checkpoints')
|
||||||
|
|
||||||
|
def refresh(self):
|
||||||
|
shared.refresh_checkpoints()
|
||||||
|
|
||||||
|
def list_items(self):
|
||||||
|
for name, checkpoint1 in sd_models.checkpoints_list.items():
|
||||||
|
checkpoint: sd_models.CheckpointInfo = checkpoint1
|
||||||
|
path, ext = os.path.splitext(checkpoint.filename)
|
||||||
|
previews = [path + ".png", path + ".preview.png"]
|
||||||
|
|
||||||
|
preview = None
|
||||||
|
for file in previews:
|
||||||
|
if os.path.isfile(file):
|
||||||
|
preview = self.link_preview(file)
|
||||||
|
break
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"name": checkpoint.model_name,
|
||||||
|
"filename": path,
|
||||||
|
"preview": preview,
|
||||||
|
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
|
||||||
|
"local_preview": path + ".png",
|
||||||
|
}
|
||||||
|
|
||||||
|
def allowed_directories_for_previews(self):
|
||||||
|
return [shared.cmd_opts.ckpt_dir, sd_models.model_path]
|
||||||
|
|
@ -19,7 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
preview = None
|
preview = None
|
||||||
for file in previews:
|
for file in previews:
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
preview = self.link_preview(file)
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
|
@ -19,7 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
preview = None
|
preview = None
|
||||||
if os.path.isfile(preview_file):
|
if os.path.isfile(preview_file):
|
||||||
preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
|
preview = self.link_preview(preview_file)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": embedding.name,
|
"name": embedding.name,
|
||||||
|
6
webui.py
6
webui.py
@ -12,7 +12,7 @@ from packaging import version
|
|||||||
import logging
|
import logging
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
from modules import import_hook, errors, extra_networks
|
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
|
||||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||||
|
|
||||||
@ -119,6 +119,7 @@ def initialize():
|
|||||||
ui_extra_networks.intialize()
|
ui_extra_networks.intialize()
|
||||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||||
|
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||||
|
|
||||||
extra_networks.initialize()
|
extra_networks.initialize()
|
||||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
@ -227,6 +228,8 @@ def webui():
|
|||||||
if launch_api:
|
if launch_api:
|
||||||
create_api(app)
|
create_api(app)
|
||||||
|
|
||||||
|
ui_extra_networks.add_pages_to_demo(app)
|
||||||
|
|
||||||
modules.script_callbacks.app_started_callback(shared.demo, app)
|
modules.script_callbacks.app_started_callback(shared.demo, app)
|
||||||
|
|
||||||
wait_on_server(shared.demo)
|
wait_on_server(shared.demo)
|
||||||
@ -254,6 +257,7 @@ def webui():
|
|||||||
ui_extra_networks.intialize()
|
ui_extra_networks.intialize()
|
||||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||||
|
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||||
|
|
||||||
extra_networks.initialize()
|
extra_networks.initialize()
|
||||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
|
Loading…
Reference in New Issue
Block a user