From b1717c0a4804f8ed3bb8cc2f3aea5d095778b447 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 2 May 2023 09:08:00 +0300 Subject: [PATCH] do not load wait for shared.sd_model to load at startup --- modules/sd_models.py | 54 ++++++++++++++++++++++++++++++++------------ modules/shared.py | 31 +++++++++++++++++++++---- modules/ui.py | 10 ++++---- webui.py | 16 ++++--------- 4 files changed, 76 insertions(+), 35 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4f7613a1..59adc7cc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,8 @@ import collections import os.path import sys import gc +import threading + import torch import re import safetensors.torch @@ -404,13 +406,39 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' -def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): + +class SdModelData: + def __init__(self): + self.sd_model = None + self.lock = threading.Lock() + + def get_sd_model(self): + if self.sd_model is None: + with self.lock: + try: + load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load", file=sys.stderr) + self.sd_model = None + + return self.sd_model + + def set_sd_model(self, v): + self.sd_model = v + + +model_data = SdModelData() + + +def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - if shared.sd_model: - sd_hijack.model_hijack.undo_hijack(shared.sd_model) - shared.sd_model = None + if model_data.sd_model: + sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + model_data.sd_model = None gc.collect() devices.torch_gc() @@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ timer.record("hijack") sd_model.eval() - shared.sd_model = sd_model + model_data.sd_model = sd_model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_info = info or select_checkpoint() if not sd_model: - sd_model = shared.sd_model + sd_model = model_data.sd_model if sd_model is None: # previous model load failed current_checkpoint_info = None @@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info, already_loaded_state_dict=state_dict) - return shared.sd_model + return model_data.sd_model try: load_model_weights(sd_model, checkpoint_info, state_dict, timer) @@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None): return sd_model + def unload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack timer = Timer() - if shared.sd_model: - - # shared.sd_model.cond_stage_model.to(devices.cpu) - # shared.sd_model.first_stage_model.to(devices.cpu) - shared.sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(shared.sd_model) - shared.sd_model = None + if model_data.sd_model: + model_data.sd_model.to(devices.cpu) + sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + model_data.sd_model = None sd_model = None gc.collect() devices.torch_gc() diff --git a/modules/shared.py b/modules/shared.py index 6a2b3c2b..151bab9e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,6 +16,7 @@ import modules.styles import modules.devices as devices from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir +from ldm.models.diffusion.ddpm import LatentDiffusion demo = None @@ -600,13 +601,37 @@ class Options: return value - opts = Options() if os.path.exists(config_filename): opts.load(config_filename) + +class Shared(sys.modules[__name__].__class__): + """ + this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than + at program startup. + """ + + sd_model_val = None + + @property + def sd_model(self): + import modules.sd_models + + return modules.sd_models.model_data.get_sd_model() + + @sd_model.setter + def sd_model(self, value): + import modules.sd_models + + modules.sd_models.model_data.set_sd_model(value) + + +sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead +sys.modules[__name__].__class__ = Shared + settings_components = None -"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings""" +"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings""" latent_upscale_default_mode = "Latent" latent_upscale_modes = { @@ -620,8 +645,6 @@ latent_upscale_modes = { sd_upscalers = [] -sd_model = None - clip_model = None progress_print_out = sys.stdout diff --git a/modules/ui.py b/modules/ui.py index 7b45f131..16c46515 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -828,7 +828,7 @@ def create_ui(): with FormGroup(): with FormRow(): cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") + image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") elif category == "seed": @@ -1693,11 +1693,9 @@ def create_ui(): show_progress=info.refresh is not None, ) - text_settings.change( - fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"), - inputs=[], - outputs=[image_cfg_scale], - ) + update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") + text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint.click( diff --git a/webui.py b/webui.py index 357bf4c1..0873a26c 100644 --- a/webui.py +++ b/webui.py @@ -6,6 +6,8 @@ import signal import re import warnings import json +from threading import Thread + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -191,18 +193,10 @@ def initialize(): modules.textual_inversion.textual_inversion.list_textual_inversion_templates() startup_timer.record("refresh textual inversion templates") - try: - modules.sd_models.load_model() - except Exception as e: - errors.display(e, "loading stable diffusion model") - print("", file=sys.stderr) - print("Stable diffusion model failed to load, exiting", file=sys.stderr) - exit(1) - startup_timer.record("load SD checkpoint") + # load model in parallel to other startup stuff + Thread(target=lambda: shared.sd_model).start() - shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title - - shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)