ba6a4e7e94
If image_cfg_scale is 1, the original image is not used for the output, so the original CFGDenoiser can be used to get the same result while keeping support for the AND functionality. In the future, AND may also be supported together with "Image CFG Scale".
391 lines
17 KiB
Python
391 lines
17 KiB
Python
from collections import deque
|
|
import torch
|
|
import inspect
|
|
import einops
|
|
import k_diffusion.sampling
|
|
from modules import prompt_parser, devices, sd_samplers_common
|
|
|
|
from modules.shared import opts, state
|
|
import modules.shared as shared
|
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
|
|
|
# Registry of k-diffusion samplers. Each row is
#   (label, name of the function in k_diffusion.sampling, aliases, options).
# Options read elsewhere in this file:
#   'scheduler': 'karras'               -> get_sigmas() builds a Karras sigma schedule
#   'discard_next_to_last_sigma': True  -> get_sigmas() drops the penultimate sigma
samplers_k_diffusion = [
    # samplers using the default sigma schedule
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
    ('Euler', 'sample_euler', ['k_euler'], {}),
    ('LMS', 'sample_lms', ['k_lms'], {}),
    ('Heun', 'sample_heun', ['k_heun'], {}),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),

    # the same sampling functions paired with the Karras sigma schedule
    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
]
|
|
|
|
# SamplerData entries for every registry row whose sampling function exists in the
# installed k_diffusion.sampling module; rows naming a missing function are skipped.
# `funcname=funcname` binds the current value as a lambda default, so each
# constructor keeps its own function name instead of the last one iterated.
samplers_data_k_diffusion = [
    sd_samplers_common.SamplerData(
        label,
        lambda model, funcname=funcname: KDiffusionSampler(funcname, model),
        aliases,
        options,
    )
    for label, funcname, aliases, options in filter(
        lambda entry: hasattr(k_diffusion.sampling, entry[1]),
        samplers_k_diffusion,
    )
]
|
|
|
|
# Optional keyword arguments, keyed by k-diffusion function name, that
# KDiffusionSampler.initialize() copies from the processing object when present.
# The list literal is evaluated once per key, so every entry owns its own list.
sampler_extra_params = {
    funcname: ['s_churn', 's_tmin', 's_tmax', 's_noise']
    for funcname in ('sample_euler', 'sample_heun', 'sample_dpm_2')
}
|
|
|
|
class CFGDenoiserEdit(torch.nn.Module):
    """
    Classifier free guidance denoiser with an additional image guidance scale.

    Every step evaluates the wrapped model on three stacked segments:
      1. prompt conditioning + image conditioning,
      2. negative-prompt conditioning + image conditioning,
      3. negative-prompt conditioning + all-zero image conditioning,
    and mixes the three predictions using `cond_scale` and `image_cfg_scale`
    (see combine_denoised).
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        # Optional blend masks and reference latent; assigned by the owning sampler.
        self.mask = None
        self.nmask = None
        self.init_latent = None
        # Sampling step counter, passed to prompt_parser to reconstruct conditionings.
        self.step = 0

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale):
        """
        Mix the three model predictions into one denoised result per image.

        For every (cond_index, weight) pair the result is
            out_uncond + cond_scale * (out_cond - out_img_cond)
                       + image_cfg_scale * (out_img_cond - out_uncond)
        where the three terms are the equal thirds of `x_out`.

        NOTE(review): `weight` is not used, and the result is assigned (not
        accumulated), so with several sub-prompts per image only the last one
        takes effect. Splitting into equal thirds also presumes one sub-prompt
        per image - confirm before relying on multi-prompt input here.

        :param x_out: stacked model output laid out as described on the class
        :param conds_list: per image, a list of (cond_index, weight) pairs
        :param uncond: negative-prompt conditioning; only its batch size is read
        :param cond_scale: text guidance scale
        :param image_cfg_scale: image guidance scale
        :return: tensor with one denoised entry per row of `uncond`
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        # The split does not depend on the loop variables, so do it once.
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)

        for i, conds in enumerate(conds_list):
            for cond_index, _weight in conds:
                denoised[i] = (
                    out_uncond[cond_index]
                    + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index])
                    + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index])
                )

        return denoised

    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale):
        """
        Run one guided denoising step.

        :param x: noisy latents, indexed per image along dim 0
        :param sigma: noise levels, indexed per image along dim 0
        :param uncond: negative-prompt conditioning schedule (resolved for the current step)
        :param cond: prompt conditioning schedule (resolved for the current step)
        :param cond_scale: text guidance scale
        :param image_cond: image conditioning, indexed per image along dim 0
        :param image_cfg_scale: image guidance scale
        :return: combined denoised latents, blended with `init_latent` when a mask is set
        :raises sd_samplers_common.InterruptedException: if the job was interrupted or skipped
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        batch_size = len(conds_list)
        repeats = [len(conds) for conds in conds_list]

        # Layout along dim 0: [per-image inputs repeated once per sub-prompt] + [batch] + [batch].
        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
        # The last segment gets an all-zero image conditioning. It has to match
        # image_cond for the concatenation anyway, so derive it from image_cond
        # instead of self.init_latent, which may still be None.
        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(image_cond)])

        # Give registered callbacks a chance to replace the model inputs.
        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma

        if tensor.shape[1] == uncond.shape[1]:
            # Prompt and negative prompt have the same length along dim 1 and can share a batch.
            cond_in = torch.cat([tensor, uncond, uncond])

            if shared.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
        else:
            # Different lengths: the prompt segment and the negative-prompt segments
            # cannot be concatenated, so they are evaluated separately.
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size * 2 if shared.batch_cond_uncond else batch_size
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])
                # Fixed: this used to call torch.cat([tensor[a:b]], uncond), which
                # passes a tensor as `dim` and raises TypeError.
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})

            # Both trailing segments use the negative prompt. Previously only the
            # last one was evaluated, leaving the image-conditioned segment as zeros.
            n_uncond = uncond.shape[0]
            total = x_in.shape[0]
            for a in (total - 2 * n_uncond, total - n_uncond):
                b = a + n_uncond
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[a:b]]})

        devices.test_for_nans(x_out, "unet")

        if opts.live_preview_content == "Prompt":
            sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
        elif opts.live_preview_content == "Negative prompt":
            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale)

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.step += 1

        return denoised
|
|
|
|
|
|
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. Wraps the stable diffusion model (specifically the unet)
    so that a noisy picture is denoised under two guidances (prompts) rather than one.
    The second prompt was originally just an empty string; here it is the
    (possibly non-empty) negative prompt.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        # Sampling step counter, passed to prompt_parser to reconstruct conditionings.
        self.step = 0
        # Optional reference latent and blend masks; assigned by the owning sampler.
        self.init_latent = None
        self.mask = None
        self.nmask = None

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """
        Start from the negative-prompt prediction of each image and add, for each of its
        sub-prompts, the difference to that sub-prompt's prediction scaled by
        weight * cond_scale.
        """
        uncond_out = x_out[-uncond.shape[0]:]
        result = torch.clone(uncond_out)

        for image_index, weighted_conds in enumerate(conds_list):
            for cond_index, weight in weighted_conds:
                delta = x_out[cond_index] - uncond_out[image_index]
                result[image_index] += delta * (weight * cond_scale)

        return result

    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
        """
        Run one guided denoising step and return the combined prediction.

        Raises sd_samplers_common.InterruptedException when the job was interrupted or skipped.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        batch_size = len(conds_list)
        repeats = [len(weighted_conds) for weighted_conds in conds_list]

        def expand(source):
            # one copy of every image's entry per sub-prompt, followed by the plain batch
            per_prompt = [torch.stack([source[idx]] * count) for idx, count in enumerate(repeats)]
            return torch.cat(per_prompt + [source])

        x_in = expand(x)
        sigma_in = expand(sigma)
        image_cond_in = expand(image_cond)

        # Registered callbacks may swap out the model inputs.
        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
        cfg_denoiser_callback(denoiser_params)
        x_in, image_cond_in, sigma_in = denoiser_params.x, denoiser_params.image_cond, denoiser_params.sigma

        def run_model(window, crossattn):
            return self.inner_model(
                x_in[window],
                sigma_in[window],
                cond={"c_crossattn": [crossattn], "c_concat": [image_cond_in[window]]},
            )

        if tensor.shape[1] == uncond.shape[1]:
            cond_in = torch.cat([tensor, uncond])

            if shared.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
            else:
                x_out = torch.zeros_like(x_in)
                for start in range(0, x_out.shape[0], batch_size):
                    window = slice(start, start + batch_size)
                    x_out[window] = run_model(window, cond_in[window])
        else:
            x_out = torch.zeros_like(x_in)
            if shared.batch_cond_uncond:
                batch_size = batch_size * 2
            cond_total = tensor.shape[0]
            for start in range(0, cond_total, batch_size):
                window = slice(start, min(start + batch_size, cond_total))
                x_out[window] = run_model(window, tensor[window])

            tail = slice(-uncond.shape[0], None)
            x_out[tail] = run_model(tail, uncond)

        devices.test_for_nans(x_out, "unet")

        preview = opts.live_preview_content
        if preview == "Prompt":
            sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
        elif preview == "Negative prompt":
            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.step += 1

        return denoised
|
|
|
|
|
|
class TorchHijack:
    """
    Stand-in for the `torch` module that serves pre-generated noise from `randn_like`.

    Any attribute other than `randn_like` is forwarded to the real `torch` module.
    """

    def __init__(self, sampler_noises):
        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
        # implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        """
        Forward attributes this object does not define to `torch`.

        Python only calls __getattr__ after normal lookup fails, so `randn_like`
        (a regular method of this class) never reaches this point and needs no
        special case here.

        :raises AttributeError: if `torch` has no such attribute either
        """
        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        """
        Return the next queued noise tensor if its shape matches `x`, otherwise fresh noise.

        A queued tensor whose shape does not match is still consumed (dropped).
        """
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise

        if x.device.type == 'mps':
            # For tensors on an mps device the noise is generated on the CPU and then moved over.
            return torch.randn_like(x, device=devices.cpu).to(x.device)
        else:
            return torch.randn_like(x)
|
|
|
|
|
|
class KDiffusionSampler:
    """Adapter that runs one k_diffusion.sampling function against a stable diffusion model."""

    def __init__(self, funcname, sd_model):
        """
        :param funcname: name of the sampling function inside k_diffusion.sampling
        :param sd_model: model to wrap; its `parameterization` selects the denoiser wrapper
        """
        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser

        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
        self.funcname = funcname
        self.func = getattr(k_diffusion.sampling, self.funcname)
        self.extra_params = sampler_extra_params.get(funcname, [])
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config = None
        self.last_latent = None

        self.conditioning_key = sd_model.model.conditioning_key

    def callback_state(self, d):
        """
        Per-step callback handed to the sampling function.

        Remembers the latest denoised latent, updates progress state, and raises
        InterruptedException once the step index exceeds `stop_at` (when set).
        """
        step = d['i']
        latent = d["denoised"]
        if opts.live_preview_content == "Combined":
            sd_samplers_common.store_latent(latent)
        self.last_latent = latent

        if self.stop_at is not None and step > self.stop_at:
            raise sd_samplers_common.InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Reset the progress counters and run `func`; on interruption return the last stored latent."""
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except sd_samplers_common.InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        """Return the number of noise tensors to prepare: one per sampling step."""
        return p.steps

    def initialize(self, p):
        """
        Prepare the denoiser wrapper for a run and collect sampler keyword arguments.

        Switches to CFGDenoiserEdit for "edit" models when an image CFG scale other
        than 1 is requested, resets masks and the step counter, replaces the `torch`
        reference inside k_diffusion.sampling with a TorchHijack, and returns the
        optional keyword arguments the sampling function accepts.
        """
        image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        # A missing/None scale keeps the plain denoiser: the edit denoiser needs a numeric scale.
        if shared.sd_model.cond_stage_key == "edit" and image_cfg_scale is not None and image_cfg_scale != 1:
            self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap)

        self.model_wrap_cfg.mask = getattr(p, 'mask', None)
        self.model_wrap_cfg.nmask = getattr(p, 'nmask', None)
        self.model_wrap_cfg.step = 0
        self.eta = p.eta if p.eta is not None else opts.eta_ancestral

        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        func_params = inspect.signature(self.func).parameters

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in func_params:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        if 'eta' in func_params:
            if self.eta != 1.0:
                p.extra_generation_params["Eta"] = self.eta

            extra_params_kwargs['eta'] = self.eta

        return extra_params_kwargs

    def get_sigmas(self, p, steps):
        """
        Build the sigma schedule for `steps` steps.

        Source of the schedule, in priority order: `p.sampler_noise_scheduler_override`,
        the Karras scheduler (when the sampler config asks for it), the wrapped model's
        own `get_sigmas`. When the penultimate sigma is to be discarded, one extra step
        is requested and the next-to-last value is dropped.
        """
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
            discard_next_to_last_sigma = True
            p.extra_generation_params["Discard penultimate sigma"] = True

        steps += 1 if discard_next_to_last_sigma else 0

        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())

            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
        else:
            sigmas = self.model_wrap.get_sigmas(steps)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas

    def _build_extra_args(self, p, conditioning, unconditional_conditioning, image_conditioning):
        """
        Assemble the keyword arguments forwarded to the active denoiser's forward().

        `image_cfg_scale` is included exactly when that forward() declares the
        parameter, so the arguments always match whichever denoiser initialize()
        selected.
        """
        extra_args = {
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
        }

        if 'image_cfg_scale' in inspect.signature(self.model_wrap_cfg.forward).parameters:
            extra_args['image_cfg_scale'] = getattr(p, 'image_cfg_scale', None)

        return extra_args

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """
        Noise the latent `x` to the starting sigma of a truncated schedule and denoise it.

        :return: the sampled latents, or the last stored latent if sampling was interrupted
        """
        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)

        sigmas = self.get_sigmas(p, steps)

        sigma_sched = sigmas[steps - t_enc - 1:]
        xi = x + noise * sigma_sched[0]

        extra_params_kwargs = self.initialize(p)
        func_params = inspect.signature(self.func).parameters
        if 'sigma_min' in func_params:
            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
            extra_params_kwargs['sigma_min'] = sigma_sched[-2]
        if 'sigma_max' in func_params:
            extra_params_kwargs['sigma_max'] = sigma_sched[0]
        if 'n' in func_params:
            extra_params_kwargs['n'] = len(sigma_sched) - 1
        if 'sigma_sched' in func_params:
            extra_params_kwargs['sigma_sched'] = sigma_sched
        if 'sigmas' in func_params:
            extra_params_kwargs['sigmas'] = sigma_sched

        self.model_wrap_cfg.init_latent = x
        self.last_latent = x
        # Built after initialize(), which may have replaced the denoiser.
        extra_args = self._build_extra_args(p, conditioning, unconditional_conditioning, image_conditioning)

        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
        """
        Denoise `x` (scaled by the first sigma) over the full schedule.

        :return: the sampled latents, or the last stored latent if sampling was interrupted
        """
        steps = steps or p.steps

        sigmas = self.get_sigmas(p, steps)

        x = x * sigmas[0]

        extra_params_kwargs = self.initialize(p)
        func_params = inspect.signature(self.func).parameters
        if 'sigma_min' in func_params:
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            if 'n' in func_params:
                extra_params_kwargs['n'] = steps
        else:
            extra_params_kwargs['sigmas'] = sigmas

        self.last_latent = x
        # Built after initialize(), which may have replaced the denoiser.
        extra_args = self._build_extra_args(p, conditioning, unconditional_conditioning, image_conditioning)

        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        return samples
|
|
|