do not wait for shared.sd_model to load at startup
parent 696c338ee2, commit b1717c0a48
modules/sd_models.py

@@ -2,6 +2,8 @@ import collections
 import os.path
 import sys
 import gc
+import threading
+
 import torch
 import re
 import safetensors.torch
@@ -404,13 +406,39 @@ def repair_config(sd_config):
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
+
+class SdModelData:
+    def __init__(self):
+        self.sd_model = None
+        self.lock = threading.Lock()
+
+    def get_sd_model(self):
+        if self.sd_model is None:
+            with self.lock:
+                try:
+                    load_model()
+                except Exception as e:
+                    errors.display(e, "loading stable diffusion model")
+                    print("", file=sys.stderr)
+                    print("Stable diffusion model failed to load", file=sys.stderr)
+                    self.sd_model = None
+
+        return self.sd_model
+
+    def set_sd_model(self, v):
+        self.sd_model = v
+
+
+model_data = SdModelData()
+
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
-    if shared.sd_model:
-        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
-        shared.sd_model = None
+    if model_data.sd_model:
+        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        model_data.sd_model = None
         gc.collect()
         devices.torch_gc()
 
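get_sd_model above is lazy, lock-guarded initialization: the first read triggers load_model(), and the lock serializes concurrent first reads. A minimal self-contained sketch of the same pattern (illustrative names, not webui code; unlike this sketch, the webui version does not re-check the field after taking the lock):

import threading

class LazyResource:
    """Build an expensive object on first use rather than at startup."""

    def __init__(self, factory):
        self._factory = factory          # callable that performs the slow load
        self._value = None
        self._lock = threading.Lock()

    def get(self):
        if self._value is None:          # fast path once loaded
            with self._lock:             # serialize the first load
                if self._value is None:  # double-check after acquiring the lock
                    self._value = self._factory()
        return self._value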
@@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     timer.record("hijack")
 
     sd_model.eval()
-    shared.sd_model = sd_model
+    model_data.sd_model = sd_model
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
 
@@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
     checkpoint_info = info or select_checkpoint()
 
     if not sd_model:
-        sd_model = shared.sd_model
+        sd_model = model_data.sd_model
 
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
@@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
         del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
-        return shared.sd_model
+        return model_data.sd_model
 
     try:
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
@@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
 
     return sd_model
 
 
 def unload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     timer = Timer()
 
-    if shared.sd_model:
-        # shared.sd_model.cond_stage_model.to(devices.cpu)
-        # shared.sd_model.first_stage_model.to(devices.cpu)
-        shared.sd_model.to(devices.cpu)
-        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
-        shared.sd_model = None
+    if model_data.sd_model:
+        model_data.sd_model.to(devices.cpu)
+        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        model_data.sd_model = None
     sd_model = None
     gc.collect()
     devices.torch_gc()
modules/shared.py

@@ -16,6 +16,7 @@ import modules.styles
 import modules.devices as devices
 from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from ldm.models.diffusion.ddpm import LatentDiffusion
 
 demo = None
 
@@ -600,13 +601,37 @@ class Options:
         return value
 
 
 opts = Options()
 if os.path.exists(config_filename):
     opts.load(config_filename)
 
 
+class Shared(sys.modules[__name__].__class__):
+    """
+    this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
+    at program startup.
+    """
+
+    sd_model_val = None
+
+    @property
+    def sd_model(self):
+        import modules.sd_models
+
+        return modules.sd_models.model_data.get_sd_model()
+
+    @sd_model.setter
+    def sd_model(self, value):
+        import modules.sd_models
+
+        modules.sd_models.model_data.set_sd_model(value)
+
+
+sd_model: LatentDiffusion = None  # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
+sys.modules[__name__].__class__ = Shared
+
 settings_components = None
-"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
+"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
 
 latent_upscale_default_mode = "Latent"
 latent_upscale_modes = {
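The Shared class in the hunk above relies on swapping a module's __class__: assigning a ModuleType subclass to sys.modules[__name__].__class__ turns shared.sd_model into a property, so every read can defer to modules.sd_models.model_data.get_sd_model(). A small self-contained sketch of that mechanism (hypothetical names, not webui code):

import sys
import types

def compute_answer():
    return 42  # stands in for an expensive, on-demand computation

class _LazyModule(types.ModuleType):
    @property
    def answer(self):
        return compute_answer()  # runs on every attribute access

sys.modules[__name__].__class__ = _LazyModule
# Other modules now see `answer` as a computed attribute:
# `import this_module; this_module.answer` -> 42, with the getter running each time.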
@@ -620,8 +645,6 @@ latent_upscale_modes = {
 
 sd_upscalers = []
 
-sd_model = None
-
 clip_model = None
 
 progress_print_out = sys.stdout
modules/ui.py

@@ -828,7 +828,7 @@ def create_ui():
                         with FormGroup():
                             with FormRow():
                                 cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
-                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
                                 denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
 
                     elif category == "seed":
@@ -1693,11 +1693,9 @@ def create_ui():
             show_progress=info.refresh is not None,
         )
 
-        text_settings.change(
-            fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
-            inputs=[],
-            outputs=[image_cfg_scale],
-        )
+        update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+        text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+        demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
 
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(
webui.py
@@ -6,6 +6,8 @@ import signal
 import re
 import warnings
 import json
+from threading import Thread
+
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
@@ -191,18 +193,10 @@ def initialize():
     modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
     startup_timer.record("refresh textual inversion templates")
 
-    try:
-        modules.sd_models.load_model()
-    except Exception as e:
-        errors.display(e, "loading stable diffusion model")
-        print("", file=sys.stderr)
-        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
-        exit(1)
-    startup_timer.record("load SD checkpoint")
-
-    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
-
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+    # load model in parallel to other startup stuff
+    Thread(target=lambda: shared.sd_model).start()
+
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
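The webui.py hunk swaps the blocking startup load for a background warm-up: merely reading shared.sd_model from a thread invokes the new property getter, which loads the checkpoint while the rest of initialize() keeps running. A hedged sketch of the idea (assumes the webui's modules.shared; not a verbatim excerpt):

from threading import Thread

from modules import shared  # webui module whose sd_model attribute is now a lazy property

def _warm_up_model():
    # Touching the property forces SdModelData.get_sd_model(), i.e. the first checkpoint load.
    _ = shared.sd_model

Thread(target=_warm_up_model).start()
# Callers that need the model before the warm-up finishes still get it:
# the lock inside SdModelData serializes the first load rather than failing.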