diff --git a/modules/sd_models.py b/modules/sd_models.py index 3a01c93d..3aa21ec1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,4 +1,4 @@ -import glob +import collections import os.path import sys from collections import namedtuple @@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir)) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} +checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -132,41 +133,45 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + if checkpoint_info not in checkpoints_loaded: + print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") + sd = get_state_dict_from_checkpoint(pl_sd) + model.load_state_dict(sd, strict=False) - sd = get_state_dict_from_checkpoint(pl_sd) + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) - model.load_state_dict(sd, strict=False) + if not shared.cmd_opts.no_half: + model.half() - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - if not shared.cmd_opts.no_half: - model.half() + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: + vae_file = shared.cmd_opts.vae_path - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + model.first_stage_model.load_state_dict(vae_dict) - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path + model.first_stage_model.to(devices.dtype_vae) - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} - - model.first_stage_model.load_state_dict(vae_dict) - - model.first_stage_model.to(devices.dtype_vae) + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + else: + print(f"Loading weights [{sd_model_hash}] from cache") + checkpoints_loaded.move_to_end(checkpoint_info) + model.load_state_dict(checkpoints_loaded[checkpoint_info]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file @@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None): return if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model diff --git a/modules/shared.py b/modules/shared.py index d41a7ab3..aa69bedf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -242,6 +242,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),