Merge remote-tracking branch 'origin/master'

Spaceginner 2023-01-27 17:35:54 +05:00
commit 56c83e453a
22 changed files with 520 additions and 188 deletions

View File

@@ -0,0 +1,99 @@
+# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+# See more details in LICENSE.
+
+model:
+  base_learning_rate: 1.0e-04
+  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: edited
+    cond_stage_key: edit
+    # image_size: 64
+    # image_size: 32
+    image_size: 16
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: true
+    load_ema: true
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 0 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 128
+    num_workers: 1
+    wrap: false
+    validation:
+      target: edit_dataset.EditDataset
+      params:
+        path: data/clip-filtered-dataset
+        cache_dir: data/
+        cache_name: data_10k
+        split: val
+        min_text_sim: 0.2
+        min_image_sim: 0.75
+        min_direction_sim: 0.2
+        max_samples_per_prompt: 1
+        min_resize_res: 512
+        max_resize_res: 512
+        crop_res: 512
+        output_as_edit: False
+        real_input: True
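
The line that matters to the webui is in_channels: 8 — InstructPix2Pix concatenates the VAE-encoded input image (4 latent channels) with the noised latent (4 channels) and uses no inpainting mask, which is exactly how the config-detection code later in this commit tells it apart from 9-channel inpainting checkpoints. A minimal sketch of how such a config is consumed, assuming both the webui's modules package and the ldm package from the Stable Diffusion repo are importable:

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/instruct-pix2pix.yaml")
    model = instantiate_from_config(config.model)  # builds modules.models.diffusion.ddpm_edit.LatentDiffusion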

View File

@@ -1,8 +1,7 @@
 model:
-  base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
   params:
-    parameterization: "v"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -12,29 +11,36 @@ model:
     cond_stage_key: "txt"
     image_size: 64
     channels: 4
-    cond_stage_trainable: false
-    conditioning_key: crossattn
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid   # important
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: False # we set this to false because this is an inference only config
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
 
     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
       params:
-        use_checkpoint: True
-        use_fp16: True
         image_size: 32 # unused
-        in_channels: 4
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
         out_channels: 4
         model_channels: 320
         attention_resolutions: [ 4, 2, 1 ]
         num_res_blocks: 2
         channel_mult: [ 1, 2, 4, 4 ]
-        num_head_channels: 64 # need to fix for flash-attn
+        num_heads: 8
         use_spatial_transformer: True
-        use_linear_in_transformer: True
         transformer_depth: 1
-        context_dim: 1024
+        context_dim: 768
+        use_checkpoint: True
         legacy: False
 
     first_stage_config:
@@ -43,7 +49,6 @@ model:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
-          #attn_type: "vanilla-xformers"
           double_z: true
           z_channels: 4
           resolution: 256
@@ -62,7 +67,4 @@ model:
           target: torch.nn.Identity
 
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
-        freeze: True
-        layer: "penultimate"
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
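
Per the new in_channels comment, the inpainting UNet consumes 9 channels. A toy illustration of the arithmetic (the exact channel order must match whatever the checkpoint was trained with; this only shows how the count is reached):

    import torch

    latent = torch.randn(1, 4, 64, 64)               # noised image latent ("data")
    masked_image_latent = torch.randn(1, 4, 64, 64)  # VAE-encoded masked source image
    mask = torch.ones(1, 1, 64, 64)                  # mask downscaled to latent resolution

    unet_input = torch.cat([latent, masked_image_latent, mask], dim=1)
    assert unet_input.shape[1] == 9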

View File

@@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
@@ -387,7 +388,7 @@ class Api:
         ]
 
     def get_sd_models(self):
-        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
 
     def get_hypernetworks(self):
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

View File

@@ -228,7 +228,7 @@ class SDModelItem(BaseModel):
     hash: Optional[str] = Field(title="Short hash")
     sha256: Optional[str] = Field(title="sha256 hash")
     filename: str = Field(title="Filename")
-    config: str = Field(title="Config file")
+    config: Optional[str] = Field(title="Config file")
 
 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")

View File

@@ -34,14 +34,18 @@ def get_cuda_device_string():
     return "cuda"
 
-def get_optimal_device():
+def get_optimal_device_name():
     if torch.cuda.is_available():
-        return torch.device(get_cuda_device_string())
+        return get_cuda_device_string()
 
     if has_mps():
-        return torch.device("mps")
+        return "mps"
 
-    return cpu
+    return "cpu"
+
+
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())
 
 
 def get_device_for(task):
@@ -139,6 +143,8 @@ def test_for_nans(x, where):
     else:
         message = "A tensor with all NaNs was produced."
 
+    message += " Use --disable-nan-check commandline argument to disable this check."
+
     raise NansException(message)
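
The string-returning variant exists (at least in part) because safetensors wants a device name rather than a torch.device; read_state_dict in modules/sd_models.py below uses it for exactly that. A small sketch:

    import safetensors.torch
    from modules.devices import get_optimal_device_name

    # safetensors.torch.load_file takes a device *string* ("cuda", "mps", "cpu"),
    # which is why the helper now returns one instead of a torch.device
    state_dict = safetensors.torch.load_file("model.safetensors", device=get_optimal_device_name())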

View File

@@ -13,7 +13,7 @@ from skimage import exposure
 from typing import Any, Dict, List, Optional
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -172,7 +172,7 @@ class StableDiffusionProcessing:
         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
 
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
         conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
         conditioning = torch.nn.functional.interpolate(
             self.sd_model.depth_model(midas_in),
@@ -185,7 +185,12 @@ class StableDiffusionProcessing:
         conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
         return conditioning
 
-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+    def edit_image_conditioning(self, source_image):
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+
+        return conditioning_image
+
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
         self.is_using_inpainting_conditioning = True
 
         # Handle the different mask inputs
@@ -212,7 +217,7 @@ class StableDiffusionProcessing:
         )
 
         # Encode the new masked image using first stage of network.
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
 
         # Create the concatenated conditioning tensor to be fed to `c_concat`
         conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -228,6 +233,9 @@ class StableDiffusionProcessing:
         if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
             return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
 
+        if self.sd_model.cond_stage_key == "edit":
+            return self.edit_image_conditioning(source_image)
+
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
             return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
@@ -409,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
 
 def decode_first_stage(model, x):
     with devices.autocast(disable=x.dtype == devices.dtype_vae):
-        x = model.decode_first_stage(x)
+        x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
 
     return x
@@ -650,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
                 image = Image.fromarray(x_sample)
 
+                if p.scripts is not None:
+                    pp = scripts.PostprocessImageArgs(image)
+                    p.scripts.postprocess_image(p, pp)
+                    image = pp.image
+
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                         image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
@@ -993,7 +1006,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
 
         image = torch.from_numpy(batch_images)
         image = 2. * image - 1.
-        image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
+        image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
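
All four dtype changes in this file move tensors that are headed for the VAE from the UNet's dtype to the VAE's own dtype. A toy reproduction of the failure mode they guard against when --upcast-sampling is combined with --no-half-vae (fp16 UNet, fp32 VAE):

    import torch

    first_stage = torch.nn.Conv2d(4, 4, 1).float()    # stand-in for an fp32 VAE
    x = torch.randn(1, 4, 8, 8, dtype=torch.float16)  # latent in the UNet's dtype

    # first_stage(x) would raise: Input type (Half) and weight type (Float) should be the same
    out = first_stage(x.to(torch.float32))            # cast to the VAE dtype instead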

View File

@@ -6,12 +6,16 @@ from collections import namedtuple
 
 import gradio as gr
 
-from modules.processing import StableDiffusionProcessing
 from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
 
 AlwaysVisible = object()
 
 
+class PostprocessImageArgs:
+    def __init__(self, image):
+        self.image = image
+
+
 class Script:
     filename = None
     args_from = None
@@ -65,7 +69,7 @@ class Script:
 
         args contains all values returned by components from ui()
         """
-        raise NotImplementedError()
+        pass
 
     def process(self, p, *args):
         """
@@ -100,6 +104,13 @@ class Script:
 
         pass
 
+    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
+        """
+        Called for every image after it has been generated.
+        """
+
+        pass
+
     def postprocess(self, p, processed, *args):
         """
         This function is called after processing ends for AlwaysVisible scripts.
@@ -247,11 +258,15 @@ class ScriptRunner:
         self.infotext_fields = []
 
     def initialize_scripts(self, is_img2img):
+        from modules import scripts_auto_postprocessing
+
         self.scripts.clear()
         self.alwayson_scripts.clear()
         self.selectable_scripts.clear()
 
-        for script_class, path, basedir, script_module in scripts_data:
+        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
+
+        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
             script = script_class()
             script.filename = path
             script.is_txt2img = not is_img2img
@@ -332,7 +347,7 @@ class ScriptRunner:
 
         return inputs
 
-    def run(self, p: StableDiffusionProcessing, *args):
+    def run(self, p, *args):
         script_index = args[0]
 
         if script_index == 0:
@@ -386,6 +401,15 @@ class ScriptRunner:
                 print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
 
+    def postprocess_image(self, p, pp: PostprocessImageArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_image(p, pp, *script_args)
+            except Exception:
+                print(f"Error running postprocess_image: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
     def before_component(self, component, **kwargs):
         for script in self.scripts:
             try:
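
A minimal always-on extension script exercising the new hook could look like this (a hypothetical script, using only the API added above):

    from modules import scripts

    class ExampleGrayscale(scripts.Script):
        def title(self):
            return "Example grayscale"

        def show(self, is_img2img):
            return scripts.AlwaysVisible

        def ui(self, is_img2img):
            return []

        def postprocess_image(self, p, pp: scripts.PostprocessImageArgs, *args):
            # pp.image is the finished PIL image; whatever is assigned back
            # replaces the image before color correction and saving
            pp.image = pp.image.convert("L").convert("RGB")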

View File

@@ -0,0 +1,42 @@
+from modules import scripts, scripts_postprocessing, shared
+
+
+class ScriptPostprocessingForMainUI(scripts.Script):
+    def __init__(self, script_postproc):
+        self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
+        self.postprocessing_controls = None
+
+    def title(self):
+        return self.script.name
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def ui(self, is_img2img):
+        self.postprocessing_controls = self.script.ui()
+        return self.postprocessing_controls.values()
+
+    def postprocess_image(self, p, script_pp, *args):
+        args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+
+        pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
+        pp.info = {}
+        self.script.process(pp, **args_dict)
+        p.extra_generation_params.update(pp.info)
+        script_pp.image = pp.image
+
+
+def create_auto_preprocessing_script_data():
+    from modules import scripts
+
+    res = []
+
+    for name in shared.opts.postprocessing_enable_in_main_ui:
+        script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
+        if script is None:
+            continue
+
+        constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
+        res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
+
+    return res
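
The constructor = lambda s=script: ... line uses the default-argument idiom to bind the current loop value; a plain closure would capture the loop variable itself, so every constructor would build the last script. A quick demonstration of the difference:

    funcs_wrong = [lambda: i for i in range(3)]
    funcs_right = [lambda i=i: i for i in range(3)]

    print([f() for f in funcs_wrong])  # [2, 2, 2] -- all share the final i
    print([f() for f in funcs_right])  # [0, 1, 2] -- each bound at definition time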

View File

@@ -46,6 +46,8 @@ class ScriptPostprocessing:
         pass
 
 
 def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
     try:
         res = func(*args, **kwargs)
@@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
             script: ScriptPostprocessing = script_class()
             script.filename = path
 
+            if script.name == "Simple Upscale":
+                continue
+
             self.scripts.append(script)
 
     def create_script_ui(self, script, inputs):
@@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
             import modules.scripts
             self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
 
-        scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+        scripts_order = shared.opts.postprocessing_operation_order
 
         def script_score(name):
-            name = name.lower()
             for i, possible_match in enumerate(scripts_order):
-                if possible_match in name:
+                if possible_match == name:
                     return i
 
             return len(self.scripts)
@@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
     def image_changed(self):
         for script in self.scripts_in_preferred_order():
             script.image_changed()

View File

@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
 
     return x_prev, pred_x0, e_t
 
 
-def should_hijack_inpainting(checkpoint_info):
-    from modules import sd_models
-
-    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
-    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
-
-    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
-
-
 def do_inpainting_hijack():
     # p_sample_plms is needed because PLMS can't work with dicts as conditionings

View File

@@ -5,7 +5,7 @@ class CondFunc:
         self = super(CondFunc, cls).__new__(cls)
         if isinstance(orig_func, str):
             func_path = orig_func.split('.')
-            for i in range(len(func_path)-2, -1, -1):
+            for i in range(len(func_path)-1, -1, -1):
                 try:
                     resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                     break
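
The one-character change makes the resolver try the longest importable module prefix first. With the old bound, for a path like 'a.b.c' it never attempted to import 'a.b' directly and relied on 'b' already being reachable as an attribute of 'a', which fails for submodules that are not imported by their parent package. A standalone sketch of the (fixed) resolution strategy:

    import importlib

    def resolve(path):
        parts = path.split('.')
        for i in range(len(parts) - 1, -1, -1):  # fixed bound: tries 'os.path' before 'os'
            try:
                obj = importlib.import_module('.'.join(parts[:i]))
                break
            except ImportError:
                pass
        for name in parts[i:]:
            obj = getattr(obj, name)
        return obj

    print(resolve('os.path.join'))  # <function join at 0x...>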

View File

@@ -2,8 +2,6 @@ import collections
 import os.path
 import sys
 import gc
-import time
-from collections import namedtuple
 import torch
 import re
 import safetensors.torch
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -99,17 +97,6 @@ def checkpoint_tiles():
     return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
 
 
-def find_checkpoint_config(info):
-    if info is None:
-        return shared.cmd_opts.config
-
-    config = os.path.splitext(info.filename)[0] + ".yaml"
-    if os.path.exists(config):
-        return config
-
-    return shared.cmd_opts.config
-
-
 def list_models():
     checkpoints_list.clear()
     checkpoint_alisases.clear()
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
-        device = map_location or shared.weight_load_location
-        if device is None:
-            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
         pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -229,32 +214,44 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
 
     return sd
 
 
-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+    sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
+    if checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        return checkpoints_loaded[checkpoint_info]
+
+    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    res = read_state_dict(checkpoint_info.filename)
+    timer.record("load weights from disk")
+
+    return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     title = checkpoint_info.title
     sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
     if checkpoint_info.title != title:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
-    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+    if state_dict is None:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    if cache_enabled and checkpoint_info in checkpoints_loaded:
-        # use checkpoint cache
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-    else:
-        # load from file
-        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
-
-        sd = read_state_dict(checkpoint_info.filename)
-        model.load_state_dict(sd, strict=False)
-        del sd
-
-        if cache_enabled:
-            # cache newly loaded model
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+    model.load_state_dict(state_dict, strict=False)
+    del state_dict
+    timer.record("apply weights to model")
+
+    if shared.opts.sd_checkpoint_cache > 0:
+        # cache newly loaded model
+        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
 
     if shared.cmd_opts.opt_channelslast:
         model.to(memory_format=torch.channels_last)
+        timer.record("apply channels_last")
 
     if not shared.cmd_opts.no_half:
         vae = model.first_stage_model
@@ -272,17 +269,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
         if depth_model:
             model.depth_model = depth_model
 
+        timer.record("apply half()")
+
     devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
     devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
     devices.dtype_unet = model.model.diffusion_model.dtype
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
     model.first_stage_model.to(devices.dtype_vae)
+    timer.record("apply dtype to VAE")
 
     # clean up cache if limit is reached
-    if cache_enabled:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
-            checkpoints_loaded.popitem(last=False)  # LRU
+    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+        checkpoints_loaded.popitem(last=False)
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_info.filename
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
         sd_vae.clear_loaded_vae()
         vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
         sd_vae.load_vae(model, vae_file, vae_source)
+        timer.record("load VAE")
 
 
 def enable_midas_autodownload():
@@ -340,24 +340,20 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper
 
 
-class Timer:
-    def __init__(self):
-        self.start = time.time()
+def repair_config(sd_config):
 
-    def elapsed(self):
-        end = time.time()
-        res = end - self.start
-        self.start = end
-        return res
+    if not hasattr(sd_config.model.params, "use_ema"):
+        sd_config.model.params.use_ema = False
 
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = True
 
-def load_model(checkpoint_info=None):
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
-    checkpoint_config = find_checkpoint_config(checkpoint_info)
-
-    if checkpoint_config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_config}")
-
     if shared.sd_model:
         sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
     gc.collect()
     devices.torch_gc()
 
-    sd_config = OmegaConf.load(checkpoint_config)
-
-    if should_hijack_inpainting(checkpoint_info):
-        # Hardcoded config for now...
-        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.unet_config.params.in_channels = 9
-        sd_config.model.params.finetune_keys = None
-
-    if should_hijack_ip2p(checkpoint_info):
-        sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.first_stage_key = "edited"
-        sd_config.model.params.cond_stage_key = "edit"
-        sd_config.model.params.image_size = 16
-        sd_config.model.params.unet_config.params.in_channels = 8
-        sd_config.model.params.unet_config.params.out_channels = 4
-
-    if not hasattr(sd_config.model.params, "use_ema"):
-        sd_config.model.params.use_ema = False
-
     do_inpainting_hijack()
 
-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
-
     timer = Timer()
 
-    sd_model = None
+    if already_loaded_state_dict is not None:
+        state_dict = already_loaded_state_dict
+    else:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    sd_config = OmegaConf.load(checkpoint_config)
+    repair_config(sd_config)
+
+    timer.record("load config")
+
+    print(f"Creating model from config: {checkpoint_config}")
 
+    sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
@@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         sd_model = instantiate_from_config(sd_config.model)
 
-    elapsed_create = timer.elapsed()
+    sd_model.used_config = checkpoint_config
 
-    load_model_weights(sd_model, checkpoint_info)
+    timer.record("create model")
 
-    elapsed_load_weights = timer.elapsed()
+    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
         sd_model.to(shared.device)
 
+    timer.record("move model to device")
+
     sd_hijack.model_hijack.hijack(sd_model)
 
+    timer.record("hijack")
+
     sd_model.eval()
     shared.sd_model = sd_model
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
 
+    timer.record("load textual inversion embeddings")
+
     script_callbacks.model_loaded_callback(sd_model)
 
-    elapsed_the_rest = timer.elapsed()
+    timer.record("scripts callbacks")
 
-    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
+    print(f"Model loaded in {timer.summary()}.")
 
     return sd_model
@@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model
+
     if sd_model is None: # previous model load failed
         current_checkpoint_info = None
     else:
@@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return
 
-    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
-
-    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
-        del sd_model
-        checkpoints_loaded.clear()
-        load_model(checkpoint_info)
-        return shared.sd_model
-
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.send_everything_to_cpu()
     else:
@@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
 
     timer = Timer()
 
+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    if sd_model is None or checkpoint_config != sd_model.used_config:
+        del sd_model
+        checkpoints_loaded.clear()
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+        return shared.sd_model
+
     try:
-        load_model_weights(sd_model, checkpoint_info)
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     except Exception as e:
         print("Failed to load checkpoint, restoring previous")
-        load_model_weights(sd_model, current_checkpoint_info)
+        load_model_weights(sd_model, current_checkpoint_info, None, timer)
         raise
     finally:
         sd_hijack.model_hijack.hijack(sd_model)
+        timer.record("hijack")
+
         script_callbacks.model_loaded_callback(sd_model)
+        timer.record("script callbacks")
 
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
+        timer.record("move model to device")
 
-    elapsed = timer.elapsed()
-
-    print(f"Weights loaded in {elapsed:.1f}s.")
+    print(f"Weights loaded in {timer.summary()}.")
 
     return sd_model
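
Taken together, the refactor threads a single Timer and a single state_dict through the whole load path; the rough shape, condensed from the hunks above (not runnable on its own):

    timer = Timer()
    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)      # RAM cache or disk
    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)                                            # use_ema / use_fp16 fixups

    sd_model = instantiate_from_config(sd_config.model)
    sd_model.used_config = checkpoint_config                            # lets reload decide rebuild-vs-reuse
    load_model_weights(sd_model, checkpoint_info, state_dict, timer)    # consumes state_dict

    print(f"Model loaded in {timer.summary()}.")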

View File

@@ -0,0 +1,68 @@
+import re
+import os
+
+from modules import shared, paths
+
+sd_configs_path = shared.sd_configs_path
+sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+
+
+config_default = shared.sd_default_config
+config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
+config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
+config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
+config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
+config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
+
+re_parametrization_v = re.compile(r'-v\b')
+
+
+def guess_model_config_from_state_dict(sd, filename):
+    fn = os.path.basename(filename)
+
+    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
+    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
+
+    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+        return config_depth_model
+
+    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
+        if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
+            return config_sd2v
+        else:
+            return config_sd2
+
+    if diffusion_model_input is not None:
+        if diffusion_model_input.shape[1] == 9:
+            return config_inpainting
+        if diffusion_model_input.shape[1] == 8:
+            return config_instruct_pix2pix
+
+    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+        return config_alt_diffusion
+
+    return config_default
+
+
+def find_checkpoint_config(state_dict, info):
+    if info is None:
+        return guess_model_config_from_state_dict(state_dict, "")
+
+    config = find_checkpoint_config_near_filename(info)
+    if config is not None:
+        return config
+
+    return guess_model_config_from_state_dict(state_dict, info.filename)
+
+
+def find_checkpoint_config_near_filename(info):
+    if info is None:
+        return None
+
+    config = os.path.splitext(info.filename)[0] + ".yaml"
+    if os.path.exists(config):
+        return config
+
+    return None
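
The detector keys entirely off tensor shapes in the state dict, so it can be exercised with fabricated tensors (inside a running webui process, since the module imports shared on load). For instance, a UNet input convolution with 8 input channels is classified as InstructPix2Pix, while 9 would mean inpainting:

    import torch
    from modules import sd_models_config

    fake_sd = {'model.diffusion_model.input_blocks.0.0.weight': torch.zeros(320, 8, 3, 3)}
    config = sd_models_config.guess_model_config_from_state_dict(fake_sd, "some-model.ckpt")
    # -> the instruct-pix2pix.yaml path; a (320, 9, 3, 3) weight would yield the inpainting config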

View File

@@ -454,7 +454,7 @@ class KDiffusionSampler:
     def initialize(self, p):
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
-        self.model_wrap.step = 0
+        self.model_wrap_cfg.step = 0
         self.eta = p.eta or opts.eta_ancestral
 
         k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

View File

@@ -13,13 +13,14 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
-from modules.paths import models_path, script_path, sd_path
+from modules import localization, extensions, script_loading, errors, ui_components, shared_items
+from modules.paths import models_path, script_path
 
 demo = None
 
-sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
 
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
@@ -264,12 +265,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
 
 face_restorers = []
 
 
-def realesrgan_models_names():
-    import modules.realesrgan_model
-    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
-
-
 class OptionInfo:
     def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
         self.default = default
@@ -360,7 +355,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
 options_templates.update(options_section(('upscaling', "Upscaling"), {
     "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
     "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
-    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
+    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
     "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
 }))
@@ -397,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
@@ -483,7 +478,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
 }))
 
 options_templates.update(options_section(('postprocessing', "Postprocessing"), {
-    'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
+    'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+    'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
    'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
 }))

modules/shared_items.py (new file, 23 lines)
View File

@@ -0,0 +1,23 @@
+
+def realesrgan_models_names():
+    import modules.realesrgan_model
+    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
+
+
+def postprocessing_scripts():
+    import modules.scripts
+
+    return modules.scripts.scripts_postproc.scripts
+
+
+def sd_vae_items():
+    import modules.sd_vae
+
+    return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
+
+
+def refresh_vae_list():
+    import modules.sd_vae
+
+    return modules.sd_vae.refresh_vae_list

View File

@@ -194,7 +194,7 @@ class EmbeddingDatabase:
         if not os.path.isdir(embdir.path):
             return
 
-        for root, dirs, fns in os.walk(embdir.path):
+        for root, dirs, fns in os.walk(embdir.path, followlinks=True):
             for fn in fns:
                 try:
                     fullfn = os.path.join(root, fn)

modules/timer.py (new file, 35 lines)
View File

@@ -0,0 +1,35 @@
+import time
+
+
+class Timer:
+    def __init__(self):
+        self.start = time.time()
+        self.records = {}
+        self.total = 0
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+    def record(self, category, extra_time=0):
+        e = self.elapsed()
+        if category not in self.records:
+            self.records[category] = 0
+
+        self.records[category] += e + extra_time
+        self.total += e + extra_time
+
+    def summary(self):
+        res = f"{self.total:.1f}s"
+
+        additions = [x for x in self.records.items() if x[1] >= 0.1]
+        if not additions:
+            return res
+
+        res += " ("
+        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
+        res += ")"
+
+        return res
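
Usage is straightforward; categories under 0.1s are recorded but left out of the summary string:

    import time
    from modules.timer import Timer

    timer = Timer()
    time.sleep(0.2)
    timer.record("load weights from disk")
    time.sleep(0.02)
    timer.record("hijack")          # recorded, but below the 0.1s reporting threshold

    print(f"Model loaded in {timer.summary()}.")
    # e.g. "Model loaded in 0.2s (load weights from disk: 0.2s)."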

View File

@@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
 
     def get_block_name(self):
         return "colorpicker"
+
+
+class DropdownMulti(gr.Dropdown):
+    """Same as gr.Dropdown but always multiselect"""
+    def __init__(self, **kwargs):
+        super().__init__(multiselect=True, **kwargs)
+
+    def get_block_name(self):
+        return "dropdown"

View File

@@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
 
     def image_changed(self):
         upscale_cache.clear()
+
+
+class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
+    name = "Simple Upscale"
+    order = 900
+
+    def ui(self):
+        with FormRow():
+            upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+            upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
+
+        return {
+            "upscale_by": upscale_by,
+            "upscaler_name": upscaler_name,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+        if upscaler_name is None or upscaler_name == "None":
+            return
+
+        upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
+        assert upscaler1, f'could not find upscaler named {upscaler_name}'
+
+        pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
+        pp.info[f"Postprocess upscaler"] = upscaler1.name

View File

@@ -164,7 +164,7 @@
     min-height: 3.2em;
 }
 
-#txt2img_styles ul, #img2img_styles ul{
+ul.list-none{
     max-height: 35em;
     z-index: 2000;
 }
@@ -714,9 +714,6 @@ footer {
     white-space: nowrap;
     min-width: auto;
 }
 
-#txt2img_hires_fix{
-    margin-left: -0.8em;
-}
-
 #img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
     margin-left: 0em;
@@ -744,7 +741,6 @@ footer {
 
 .dark .gr-compact{
     background-color: rgb(31 41 55 / var(--tw-bg-opacity));
-    margin-left: 0.8em;
 }
 
 .gr-compact{

View File

@@ -10,7 +10,7 @@ then
 fi
 
 export install_dir="$HOME"
-export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
+export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate"
 export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
 export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
 export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"