remove the need to place configs near models

This commit is contained in:
AUTOMATIC 2023-01-27 11:28:12 +03:00
parent 7a14c8ab45
commit d2ac95fa7b
10 changed files with 361 additions and 152 deletions

View File

@ -0,0 +1,99 @@
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.
model:
base_learning_rate: 1.0e-04
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: edited
cond_stage_key: edit
# image_size: 64
# image_size: 32
image_size: 16
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: true
load_ema: true
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 0 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 128
num_workers: 1
wrap: false
validation:
target: edit_dataset.EditDataset
params:
path: data/clip-filtered-dataset
cache_dir: data/
cache_name: data_10k
split: val
min_text_sim: 0.2
min_image_sim: 0.75
min_direction_sim: 0.2
max_samples_per_prompt: 1
min_resize_res: 512
max_resize_res: 512
crop_res: 512
output_as_edit: False
real_input: True

View File

@ -1,8 +1,7 @@
model: model:
base_learning_rate: 1.0e-4 base_learning_rate: 7.5e-05
target: ldm.models.diffusion.ddpm.LatentDiffusion target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params: params:
parameterization: "v"
linear_start: 0.00085 linear_start: 0.00085
linear_end: 0.0120 linear_end: 0.0120
num_timesteps_cond: 1 num_timesteps_cond: 1
@ -12,29 +11,36 @@ model:
cond_stage_key: "txt" cond_stage_key: "txt"
image_size: 64 image_size: 64
channels: 4 channels: 4
cond_stage_trainable: false cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn conditioning_key: hybrid # important
monitor: val/loss_simple_ema monitor: val/loss_simple_ema
scale_factor: 0.18215 scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config finetune_keys: null
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config: unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params: params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused image_size: 32 # unused
in_channels: 4 in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4 out_channels: 4
model_channels: 320 model_channels: 320
attention_resolutions: [ 4, 2, 1 ] attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2 num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ] channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn num_heads: 8
use_spatial_transformer: True use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1 transformer_depth: 1
context_dim: 1024 context_dim: 768
use_checkpoint: True
legacy: False legacy: False
first_stage_config: first_stage_config:
@ -43,7 +49,6 @@ model:
embed_dim: 4 embed_dim: 4
monitor: val/rec_loss monitor: val/rec_loss
ddconfig: ddconfig:
#attn_type: "vanilla-xformers"
double_z: true double_z: true
z_channels: 4 z_channels: 4
resolution: 256 resolution: 256
@ -62,7 +67,4 @@ model:
target: torch.nn.Identity target: torch.nn.Identity
cond_stage_config: cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.sd_models import checkpoints_list
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import List from typing import List
@ -387,7 +388,7 @@ class Api:
] ]
def get_sd_models(self): def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self): def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

View File

@ -34,14 +34,18 @@ def get_cuda_device_string():
return "cuda" return "cuda"
def get_optimal_device(): def get_optimal_device_name():
if torch.cuda.is_available(): if torch.cuda.is_available():
return torch.device(get_cuda_device_string()) return get_cuda_device_string()
if has_mps(): if has_mps():
return torch.device("mps") return "mps"
return cpu return "cpu"
def get_optimal_device():
return torch.device(get_optimal_device_name())
def get_device_for(task): def get_device_for(task):

View File

@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t return x_prev, pred_x0, e_t
def should_hijack_inpainting(checkpoint_info):
from modules import sd_models
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
def do_inpainting_hijack(): def do_inpainting_hijack():
# p_sample_plms is needed because PLMS can't work with dicts as conditionings # p_sample_plms is needed because PLMS can't work with dicts as conditionings

View File

@ -2,8 +2,6 @@ import collections
import os.path import os.path
import sys import sys
import gc import gc
import time
from collections import namedtuple
import torch import torch
import re import re
import safetensors.torch import safetensors.torch
@ -14,10 +12,10 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.paths import models_path from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.sd_hijack_ip2p import should_hijack_ip2p from modules.timer import Timer
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
@ -99,17 +97,6 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
if info is None:
return shared.cmd_opts.config
config = os.path.splitext(info.filename)[0] + ".yaml"
if os.path.exists(config):
return config
return shared.cmd_opts.config
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
checkpoint_alisases.clear() checkpoint_alisases.clear()
@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file) _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
if device is None:
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else: else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@ -229,32 +214,44 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd return sd
def load_model_weights(model, checkpoint_info: CheckpointInfo): def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache")
return checkpoints_loaded[checkpoint_info]
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
res = read_state_dict(checkpoint_info.filename)
timer.record("load weights from disk")
return res
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
title = checkpoint_info.title title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash() sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if checkpoint_info.title != title: if checkpoint_info.title != title:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
cache_enabled = shared.opts.sd_checkpoint_cache > 0 if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
if cache_enabled and checkpoint_info in checkpoints_loaded: model.load_state_dict(state_dict, strict=False)
# use checkpoint cache del state_dict
print(f"Loading weights [{sd_model_hash}] from cache") timer.record("apply weights to model")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
sd = read_state_dict(checkpoint_info.filename) if shared.opts.sd_checkpoint_cache > 0:
model.load_state_dict(sd, strict=False)
del sd
if cache_enabled:
# cache newly loaded model # cache newly loaded model
checkpoints_loaded[checkpoint_info] = model.state_dict().copy() checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
if shared.cmd_opts.opt_channelslast: if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
vae = model.first_stage_model vae = model.first_stage_model
@ -272,17 +269,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
if depth_model: if depth_model:
model.depth_model = depth_model model.depth_model = depth_model
timer.record("apply half()")
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
devices.dtype_unet = model.model.diffusion_model.dtype devices.dtype_unet = model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
timer.record("apply dtype to VAE")
# clean up cache if limit is reached # clean up cache if limit is reached
if cache_enabled: while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model checkpoints_loaded.popitem(last=False)
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_info.filename model.sd_model_checkpoint = checkpoint_info.filename
@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
sd_vae.clear_loaded_vae() sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
sd_vae.load_vae(model, vae_file, vae_source) sd_vae.load_vae(model, vae_file, vae_source)
timer.record("load VAE")
def enable_midas_autodownload(): def enable_midas_autodownload():
@ -340,24 +340,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper midas.api.load_model = load_model_wrapper
class Timer: def repair_config(sd_config):
def __init__(self):
self.start = time.time()
def elapsed(self): if not hasattr(sd_config.model.params, "use_ema"):
end = time.time() sd_config.model.params.use_ema = False
res = end - self.start
self.start = end if shared.cmd_opts.no_half:
return res sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True
def load_model(checkpoint_info=None): def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
checkpoint_config = find_checkpoint_config(checkpoint_info)
if checkpoint_config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_config}")
if shared.sd_model: if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model) sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
sd_config = OmegaConf.load(checkpoint_config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
sd_config.model.params.finetune_keys = None
if should_hijack_ip2p(checkpoint_info):
sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.first_stage_key = "edited"
sd_config.model.params.cond_stage_key = "edit"
sd_config.model.params.image_size = 16
sd_config.model.params.unet_config.params.in_channels = 8
sd_config.model.params.unet_config.params.out_channels = 4
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
do_inpainting_hijack() do_inpainting_hijack()
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True
timer = Timer() timer = Timer()
sd_model = None if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
timer.record("find config")
sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config)
timer.record("load config")
print(f"Creating model from config: {checkpoint_config}")
sd_model = None
try: try:
with sd_disable_initialization.DisableInitialization(): with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
elapsed_create = timer.elapsed() sd_model.used_config = checkpoint_config
load_model_weights(sd_model, checkpoint_info) timer.record("create model")
elapsed_load_weights = timer.elapsed() load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else: else:
sd_model.to(shared.device) sd_model.to(shared.device)
timer.record("move model to device")
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
sd_model.eval() sd_model.eval()
shared.sd_model = sd_model shared.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
timer.record("load textual inversion embeddings")
script_callbacks.model_loaded_callback(sd_model) script_callbacks.model_loaded_callback(sd_model)
elapsed_the_rest = timer.elapsed() timer.record("scripts callbacks")
print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") print(f"Model loaded in {timer.summary()}.")
return sd_model return sd_model
@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = shared.sd_model
if sd_model is None: # previous model load failed if sd_model is None: # previous model load failed
current_checkpoint_info = None current_checkpoint_info = None
else: else:
@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
else: else:
@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
timer = Timer() timer = Timer()
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
timer.record("find config")
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
return shared.sd_model
try: try:
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info, state_dict, timer)
except Exception as e: except Exception as e:
print("Failed to load checkpoint, restoring previous") print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info) load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise raise
finally: finally:
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
script_callbacks.model_loaded_callback(sd_model) script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device) sd_model.to(devices.device)
timer.record("move model to device")
elapsed = timer.elapsed() print(f"Weights loaded in {timer.summary()}.")
print(f"Weights loaded in {elapsed:.1f}s.")
return sd_model return sd_model

View File

@ -0,0 +1,65 @@
import re
import os
from modules import shared, paths
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
re_parametrization_v = re.compile(r'-v\b')
def guess_model_config_from_state_dict(sd, filename):
fn = os.path.basename(filename)
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
return config_sd2v
else:
return config_sd2
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
return config_inpainting
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
if roberta_weight is not None:
return config_alt_diffusion
return config_default
def find_checkpoint_config(state_dict, info):
if info is None:
return guess_model_config_from_state_dict(state_dict, "")
config = find_checkpoint_config_near_filename(info)
if config is not None:
return config
return guess_model_config_from_state_dict(state_dict, info.filename)
def find_checkpoint_config_near_filename(info):
if info is None:
return None
config = os.path.splitext(info.filename)[0] + ".yaml"
if os.path.exists(config):
return config
return None

View File

@ -13,13 +13,14 @@ import modules.interrogate
import modules.memmon import modules.memmon
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items from modules import localization, extensions, script_loading, errors, ui_components, shared_items
from modules.paths import models_path, script_path from modules.paths import models_path, script_path
demo = None demo = None
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),

View File

@ -4,7 +4,20 @@ def realesrgan_models_names():
import modules.realesrgan_model import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
def postprocessing_scripts(): def postprocessing_scripts():
import modules.scripts import modules.scripts
return modules.scripts.scripts_postproc.scripts return modules.scripts.scripts_postproc.scripts
def sd_vae_items():
import modules.sd_vae
return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
def refresh_vae_list():
import modules.sd_vae
return modules.sd_vae.refresh_vae_list

35
modules/timer.py Normal file
View File

@ -0,0 +1,35 @@
import time
class Timer:
def __init__(self):
self.start = time.time()
self.records = {}
self.total = 0
def elapsed(self):
end = time.time()
res = end - self.start
self.start = end
return res
def record(self, category, extra_time=0):
e = self.elapsed()
if category not in self.records:
self.records[category] = 0
self.records[category] += e + extra_time
self.total += e + extra_time
def summary(self):
res = f"{self.total:.1f}s"
additions = [x for x in self.records.items() if x[1] >= 0.1]
if not additions:
return res
res += " ("
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
res += ")"
return res