find configs for models at runtime rather than when starting
This commit is contained in:
parent
02d7abf514
commit
8d8a05a3bb
@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
|
|
||||||
|
|
||||||
def should_hijack_inpainting(checkpoint_info):
|
def should_hijack_inpainting(checkpoint_info):
|
||||||
|
from modules import sd_models
|
||||||
|
|
||||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||||
cfg_basename = os.path.basename(checkpoint_info.config).lower()
|
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||||
|
|
||||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
|
|||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
@ -48,6 +48,14 @@ def checkpoint_tiles():
|
|||||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||||
|
|
||||||
|
|
||||||
|
def find_checkpoint_config(info):
|
||||||
|
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||||
|
if os.path.exists(config):
|
||||||
|
return config
|
||||||
|
|
||||||
|
return shared.cmd_opts.config
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
||||||
@ -73,7 +81,7 @@ def list_models():
|
|||||||
if os.path.exists(cmd_ckpt):
|
if os.path.exists(cmd_ckpt):
|
||||||
h = model_hash(cmd_ckpt)
|
h = model_hash(cmd_ckpt)
|
||||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||||
shared.opts.data['sd_model_checkpoint'] = title
|
shared.opts.data['sd_model_checkpoint'] = title
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
@ -81,12 +89,7 @@ def list_models():
|
|||||||
h = model_hash(filename)
|
h = model_hash(filename)
|
||||||
title, short_model_name = modeltitle(filename, h)
|
title, short_model_name = modeltitle(filename, h)
|
||||||
|
|
||||||
basename, _ = os.path.splitext(filename)
|
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||||
config = basename + ".yaml"
|
|
||||||
if not os.path.exists(config):
|
|
||||||
config = shared.cmd_opts.config
|
|
||||||
|
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(searchString):
|
def get_closet_checkpoint_match(searchString):
|
||||||
@ -282,9 +285,10 @@ def enable_midas_autodownload():
|
|||||||
def load_model(checkpoint_info=None):
|
def load_model(checkpoint_info=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||||
|
|
||||||
if checkpoint_info.config != shared.cmd_opts.config:
|
if checkpoint_config != shared.cmd_opts.config:
|
||||||
print(f"Loading config from: {checkpoint_info.config}")
|
print(f"Loading config from: {checkpoint_config}")
|
||||||
|
|
||||||
if shared.sd_model:
|
if shared.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||||
@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
sd_config = OmegaConf.load(checkpoint_config)
|
||||||
|
|
||||||
if should_hijack_inpainting(checkpoint_info):
|
if should_hijack_inpainting(checkpoint_info):
|
||||||
# Hardcoded config for now...
|
# Hardcoded config for now...
|
||||||
@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
|
|||||||
sd_config.model.params.finetune_keys = None
|
sd_config.model.params.finetune_keys = None
|
||||||
|
|
||||||
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
||||||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
|
||||||
|
|
||||||
if not hasattr(sd_config.model.params, "use_ema"):
|
if not hasattr(sd_config.model.params, "use_ema"):
|
||||||
sd_config.model.params.use_ema = False
|
sd_config.model.params.use_ema = False
|
||||||
@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
sd_model = shared.sd_model
|
sd_model = shared.sd_model
|
||||||
|
|
||||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
|
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||||
|
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||||
del sd_model
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
checkpoints_loaded.clear()
|
||||||
load_model(checkpoint_info)
|
load_model(checkpoint_info)
|
||||||
|
Loading…
Reference in New Issue
Block a user