From 21ee46eea791d83b3b49cedd2306c7f0f1807250 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 19 May 2023 15:35:16 +0300 Subject: [PATCH] Deduplicate default extra network registration --- modules/extra_networks.py | 5 +++++ modules/ui_extra_networks.py | 9 +++++++++ webui.py | 16 ++++++---------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/modules/extra_networks.py b/modules/extra_networks.py index f9db41bc..94347275 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -14,6 +14,11 @@ def register_extra_network(extra_network): extra_network_registry[extra_network.name] = extra_network +def register_default_extra_networks(): + from modules.extra_networks_hypernet import ExtraNetworkHypernet + register_extra_network(ExtraNetworkHypernet()) + + class ExtraNetworkParams: def __init__(self, items=None): self.items = items or [] diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 24eeef0e..19fbaae5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -236,6 +236,15 @@ def initialize(): extra_pages.clear() +def register_default_pages(): + from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion + from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks + from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints + register_page(ExtraNetworksPageTextualInversion()) + register_page(ExtraNetworksPageHypernetworks()) + register_page(ExtraNetworksPageCheckpoints()) + + class ExtraNetworksUi: def __init__(self): self.pages = None diff --git a/webui.py b/webui.py index 30e4f239..ad6be239 100644 --- a/webui.py +++ b/webui.py @@ -34,8 +34,7 @@ startup_timer.record("import gradio") import ldm.modules.encoders.modules # noqa: F401 startup_timer.record("import ldm") -from modules import extra_networks, ui_extra_networks_checkpoints -from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion +from modules import extra_networks from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors @@ -214,12 +213,11 @@ def initialize(): startup_timer.record("reload hypernets") ui_extra_networks.initialize() - ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) - ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) - ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + ui_extra_networks.register_default_pages() extra_networks.initialize() - extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + extra_networks.register_default_extra_networks() + startup_timer.record("extra networks") if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: @@ -420,12 +418,10 @@ def webui(): startup_timer.record("reload hypernetworks") ui_extra_networks.initialize() - ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) - ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) - ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + ui_extra_networks.register_default_pages() extra_networks.initialize() - extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + extra_networks.register_default_extra_networks() startup_timer.record("initialize extra networks")