helpful error message when trying to load 2.0 without config
failing to load model weights from settings won't break generation for currently loaded model anymore
This commit is contained in:
parent
7e549468b3
commit
02d7abf514
@ -2,9 +2,30 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
def print_error_explanation(message):
|
||||||
|
lines = message.strip().split("\n")
|
||||||
|
max_len = max([len(x) for x in lines])
|
||||||
|
|
||||||
|
print('=' * max_len, file=sys.stderr)
|
||||||
|
for line in lines:
|
||||||
|
print(line, file=sys.stderr)
|
||||||
|
print('=' * max_len, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def display(e: Exception, task):
|
||||||
|
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
message = str(e)
|
||||||
|
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||||
|
print_error_explanation("""
|
||||||
|
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
|
||||||
|
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
def run(code, task):
|
def run(code, task):
|
||||||
try:
|
try:
|
||||||
code()
|
code()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{task}: {type(e).__name__}", file=sys.stderr)
|
display(task, e)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
@ -278,6 +278,7 @@ def enable_midas_autodownload():
|
|||||||
|
|
||||||
midas.api.load_model = load_model_wrapper
|
midas.api.load_model = load_model_wrapper
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None):
|
def load_model(checkpoint_info=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
|
|||||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||||
|
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
@ -336,10 +338,12 @@ def load_model(checkpoint_info=None):
|
|||||||
def reload_model_weights(sd_model=None, info=None):
|
def reload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import lowvram, devices, sd_hijack
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = shared.sd_model
|
||||||
|
|
||||||
|
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
|
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
try:
|
||||||
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to load checkpoint, restoring previous")
|
||||||
|
load_model_weights(sd_model, current_checkpoint_info)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
sd_model.to(devices.device)
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
|
||||||
sd_model.to(devices.device)
|
|
||||||
|
|
||||||
print("Weights loaded.")
|
print("Weights loaded.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
@ -14,7 +14,7 @@ import modules.interrogate
|
|||||||
import modules.memmon
|
import modules.memmon
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import localization, sd_vae, extensions, script_loading
|
from modules import localization, sd_vae, extensions, script_loading, errors
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
|
|
||||||
@ -494,7 +494,12 @@ class Options:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if self.data_labels[key].onchange is not None:
|
if self.data_labels[key].onchange is not None:
|
||||||
self.data_labels[key].onchange()
|
try:
|
||||||
|
self.data_labels[key].onchange()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"changing setting {key} to {value}")
|
||||||
|
setattr(self, key, oldval)
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
12
webui.py
12
webui.py
@ -9,7 +9,7 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
from modules import import_hook
|
from modules import import_hook, errors
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
@ -61,7 +61,15 @@ def initialize():
|
|||||||
modelloader.load_upscalers()
|
modelloader.load_upscalers()
|
||||||
|
|
||||||
modules.sd_vae.refresh_vae_list()
|
modules.sd_vae.refresh_vae_list()
|
||||||
modules.sd_models.load_model()
|
|
||||||
|
try:
|
||||||
|
modules.sd_models.load_model()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "loading stable diffusion model")
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user