diff --git a/README.md b/README.md index 33508f31..ba9f3952 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) +- Can use safetensors to safely load model files without python pickle ## Where are Aesthetic Gradients?!?! Aesthetic Gradients are now an extension. You can install it using git: diff --git a/modules/modelloader.py b/modules/modelloader.py index e4a6f8ac..7d2f0ade 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -82,6 +82,7 @@ def cleanup_models(): src_path = models_path dest_path = os.path.join(models_path, "Stable-diffusion") move_files(src_path, dest_path, ".ckpt") + move_files(src_path, dest_path, ".safetensors") src_path = os.path.join(root_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN") move_files(src_path, dest_path) diff --git a/modules/sd_models.py b/modules/sd_models.py index c59151e0..4ccdf30b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,6 +4,7 @@ import sys import gc from collections import namedtuple import torch +from safetensors.torch import load_file import re from omegaconf import OmegaConf @@ -16,9 +17,10 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config', 'exttype']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() +checkpoint_types = {'.ckpt':'pickle','.safetensors':'safetensors'} try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -45,7 +47,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt",".safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -60,15 +62,15 @@ def list_models(): if name.startswith("\\") or name.startswith("/"): name = name[1:] - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + shortname, ext = os.path.splitext(name.replace("/", "_").replace("\\", "_")) - return f'{name} [{shorthash}]', shortname + return f'{name} [{checkpoint_types[ext]}] [{shorthash}]', shortname cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config, '') shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) @@ -76,12 +78,12 @@ def list_models(): h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - basename, _ = os.path.splitext(filename) + basename, ext = os.path.splitext(filename) config = basename + ".yaml" if not os.path.exists(config): config = shared.cmd_opts.config - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config, ext) def get_closet_checkpoint_match(searchString): @@ -173,7 +175,13 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'): + # safely load weights + # TODO: safetensors supports zero copy fast load to gpu, see issue #684 + pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") diff --git a/requirements.txt b/requirements.txt index 762db4f3..f7de9f70 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ kornia lark inflection GitPython +safetensors