Merge branch 'disable_initialization'

AUTOMATIC 2023-01-10 19:11:47 +03:00
commit 50fb20cedc
3 changed files with 133 additions and 7 deletions

modules/modelloader.py

@@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
 from modules.paths import script_path, models_path

-def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
+def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
     """
     A one-and-done loader to try finding the desired models in specified directories.
@@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
             full_path = file
             if os.path.isdir(full_path):
                 continue
+            if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+                continue
             if len(ext_filter) != 0:
                 model_name, extension = os.path.splitext(file)
                 if extension not in ext_filter:
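Not part of the commit — a minimal sketch of how the new `ext_blacklist` parameter interacts with `ext_filter`. The blacklist is checked with `endswith` before the extension filter, so a compound extension like `.vae.safetensors` can be excluded even though `os.path.splitext` reports only `.safetensors`, which would pass the filter. The paths below are hypothetical.

```python
import os

ext_filter = [".ckpt", ".safetensors"]
ext_blacklist = [".vae.safetensors"]

def accepted(full_path):
    # mirrors the filtering order in load_models above
    if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
        return False
    _, extension = os.path.splitext(full_path)
    return extension in ext_filter

assert accepted("models/v1-5.safetensors")
assert not accepted("models/v1-5.vae.safetensors")  # caught by endswith, not by splitext
assert not accepted("models/notes.txt")
```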

modules/sd_disable_initialization.py (new file)

@@ -0,0 +1,95 @@
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub


class DisableInitialization:
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from working
    - changing CLIP and OpenCLIP to not download model weights
    - changing CLIP to not make requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """
    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
            # this file is always 404, prevent making request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
                raise transformers.utils.hub.EntryNotFoundError

            # try the local cache first; only fall back to a network request if the file is not cached
            try:
                return original(url, *args, local_files_only=True, **kwargs)
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
        self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
        self.init_no_grad_normal = torch.nn.init._no_grad_normal_
        self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
        self.create_model_and_transforms = open_clip.create_model_and_transforms
        self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
        self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None)
        self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None)
        self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None)
        self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None)

        torch.nn.init.kaiming_uniform_ = do_nothing
        torch.nn.init._no_grad_normal_ = do_nothing
        torch.nn.init._no_grad_uniform_ = do_nothing
        open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained

        if self.transformers_modeling_utils_load_pretrained_model is not None:
            transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model
        if self.transformers_tokenization_utils_base_cached_file is not None:
            transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file
        if self.transformers_configuration_utils_cached_file is not None:
            transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file
        if self.transformers_utils_hub_get_from_cache is not None:
            transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
        torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
        torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
        open_clip.create_model_and_transforms = self.create_model_and_transforms
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained

        if self.transformers_modeling_utils_load_pretrained_model is not None:
            transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model
        if self.transformers_tokenization_utils_base_cached_file is not None:
            # restore the same attribute that __enter__ patched (tokenization_utils_base, not utils.hub)
            transformers.tokenization_utils_base.cached_file = self.transformers_tokenization_utils_base_cached_file
        if self.transformers_configuration_utils_cached_file is not None:
            transformers.configuration_utils.cached_file = self.transformers_configuration_utils_cached_file
        if self.transformers_utils_hub_get_from_cache is not None:
            transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
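Not part of the commit — a minimal standalone sketch of why no-op'ing the torch initializers makes model creation faster. `nn.Linear` fills its weight by calling `torch.nn.init.kaiming_uniform_` inside `reset_parameters()`, and the function is looked up on the module at call time, so temporarily replacing that attribute skips the random fill; the save/patch/restore steps below mirror what `__enter__` and `__exit__` above do.

```python
import time

import torch
import torch.nn as nn


def time_linear_create():
    # constructing a large Linear layer triggers kaiming_uniform_ on its weight
    t0 = time.time()
    nn.Linear(4096, 4096)
    return time.time() - t0


baseline = time_linear_create()

saved = torch.nn.init.kaiming_uniform_            # keep the original, like __enter__
torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None  # same no-op trick
try:
    patched = time_linear_create()
finally:
    torch.nn.init.kaiming_uniform_ = saved        # always restore, like __exit__

print(f"with init: {baseline:.3f}s, init no-op'd: {patched:.3f}s")
```

The weight tensor is still allocated either way; only the random fill is skipped, which is safe here because the weights are immediately overwritten by the checkpoint's state dict.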

modules/sd_models.py

@@ -2,6 +2,7 @@ import collections
 import os.path
 import sys
 import gc
+import time
 from collections import namedtuple
 import torch
 import re
@@ -13,7 +14,7 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config

-from modules import shared, modelloader, devices, script_callbacks, sd_vae
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -61,7 +62,7 @@ def find_checkpoint_config(info):
 def list_models():
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])

     def modeltitle(path, shorthash):
         abspath = os.path.abspath(path)
@@ -288,6 +289,17 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper


+class Timer:
+    def __init__(self):
+        self.start = time.time()
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
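Not part of the diff — a quick illustration of the `Timer` semantics: `elapsed()` returns the seconds since the previous call and restarts the interval, so consecutive calls measure consecutive phases. The `time.sleep` calls are hypothetical stand-ins for real work.

```python
import time


class Timer:  # copy of the class added in the hunk above
    def __init__(self):
        self.start = time.time()

    def elapsed(self):
        end = time.time()
        res = end - self.start
        self.start = end
        return res


timer = Timer()

time.sleep(0.2)                    # stand-in for creating the model
elapsed_create = timer.elapsed()   # ~0.2s; also restarts the interval

time.sleep(0.1)                    # stand-in for loading weights
elapsed_load = timer.elapsed()     # ~0.1s, counted from the previous elapsed() call

print(f"done in {elapsed_create + elapsed_load:.1f}s ({elapsed_create:.1f}s create, {elapsed_load:.1f}s load)")
```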
@@ -319,10 +331,21 @@ def load_model(checkpoint_info=None):
     if shared.cmd_opts.no_half:
         sd_config.model.params.unet_config.params.use_fp16 = False

-    sd_model = instantiate_from_config(sd_config.model)
+    timer = Timer()
+
+    try:
+        with sd_disable_initialization.DisableInitialization():
+            sd_model = instantiate_from_config(sd_config.model)
+    except Exception as e:
+        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
+        sd_model = instantiate_from_config(sd_config.model)
+
+    elapsed_create = timer.elapsed()

     load_model_weights(sd_model, checkpoint_info)

+    elapsed_load_weights = timer.elapsed()
+
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
@@ -337,7 +360,9 @@ def load_model(checkpoint_info=None):
     script_callbacks.model_loaded_callback(sd_model)

-    print("Model loaded.")
+    elapsed_the_rest = timer.elapsed()
+
+    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")

     return sd_model
@@ -348,7 +373,7 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model

     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
@@ -370,6 +395,8 @@ def reload_model_weights(sd_model=None, info=None):
     sd_hijack.model_hijack.undo_hijack(sd_model)

+    timer = Timer()
+
     try:
         load_model_weights(sd_model, checkpoint_info)
     except Exception as e:
@@ -383,6 +410,8 @@ def reload_model_weights(sd_model=None, info=None):
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)

-    print("Weights loaded.")
+    elapsed = timer.elapsed()
+
+    print(f"Weights loaded in {elapsed:.1f}s.")

     return sd_model