Fix logspam and live previews

This commit is contained in:
space-nuko 2023-02-10 04:47:08 -08:00
parent 1253199889
commit 21880eb9e5
3 changed files with 41 additions and 31 deletions

View File

@ -19,9 +19,10 @@ class UniPCSampler(object):
attr = attr.to(torch.device("cuda")) attr = attr.to(torch.device("cuda"))
setattr(self, name, attr) setattr(self, name, attr)
def set_hooks(self, before, after): def set_hooks(self, before_sample, after_sample, after_update):
self.before_sample = before self.before_sample = before_sample
self.after_sample = after self.after_sample = after_sample
self.after_update = after_update
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(self,
@ -50,9 +51,17 @@ class UniPCSampler(object):
): ):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0] ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else: else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
@ -60,6 +69,7 @@ class UniPCSampler(object):
# sampling # sampling
C, H, W = shape C, H, W = shape
size = (batch_size, C, H, W) size = (batch_size, C, H, W)
print(f'Data shape for UniPC sampling is {size}, eta {eta}')
device = self.model.betas.device device = self.model.betas.device
if x_T is None: if x_T is None:
@ -79,7 +89,7 @@ class UniPCSampler(object):
guidance_scale=unconditional_guidance_scale, guidance_scale=unconditional_guidance_scale,
) )
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
return x.to(device), None return x.to(device), None

View File

@ -378,7 +378,8 @@ class UniPC:
condition=None, condition=None,
unconditional_condition=None, unconditional_condition=None,
before_sample=None, before_sample=None,
after_sample=None after_sample=None,
after_update=None
): ):
"""Construct a UniPC. """Construct a UniPC.
@ -394,6 +395,7 @@ class UniPC:
self.unconditional_condition = unconditional_condition self.unconditional_condition = unconditional_condition
self.before_sample = before_sample self.before_sample = before_sample
self.after_sample = after_sample self.after_sample = after_sample
self.after_update = after_update
def dynamic_thresholding_fn(self, x0, t=None): def dynamic_thresholding_fn(self, x0, t=None):
""" """
@ -434,15 +436,6 @@ class UniPC:
noise = self.noise_prediction_fn(x, t) noise = self.noise_prediction_fn(x, t)
dims = x.dim() dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
from pprint import pp
print("X:")
pp(x)
print("sigma_t:")
pp(sigma_t)
print("noise:")
pp(noise)
print("alpha_t:")
pp(alpha_t)
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
if self.thresholding: if self.thresholding:
p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
@ -524,7 +517,7 @@ class UniPC:
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule ns = self.noise_schedule
assert order <= len(model_prev_list) assert order <= len(model_prev_list)
@ -568,7 +561,7 @@ class UniPC:
A_p = C_inv_p A_p = C_inv_p
if use_corrector: if use_corrector:
print('using corrector') #print('using corrector')
C_inv = torch.linalg.inv(C) C_inv = torch.linalg.inv(C)
A_c = C_inv A_c = C_inv
@ -627,7 +620,7 @@ class UniPC:
return x_t, model_t return x_t, model_t
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
print(f'using unified predictor-corrector with order {order} (solver type: B(h))') #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
ns = self.noise_schedule ns = self.noise_schedule
assert order <= len(model_prev_list) assert order <= len(model_prev_list)
dims = x.dim() dims = x.dim()
@ -695,7 +688,7 @@ class UniPC:
D1s = None D1s = None
if use_corrector: if use_corrector:
print('using corrector') #print('using corrector')
# for order 1, we use a simplified version # for order 1, we use a simplified version
if order == 1: if order == 1:
rhos_c = torch.tensor([0.5], device=b.device) rhos_c = torch.tensor([0.5], device=b.device)
@ -755,8 +748,9 @@ class UniPC:
t_T = self.noise_schedule.T if t_start is None else t_start t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device device = x.device
if method == 'multistep': if method == 'multistep':
assert steps >= order assert steps >= order, "UniPC order must be < sampling steps"
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps")
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
with torch.no_grad(): with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
@ -768,6 +762,8 @@ class UniPC:
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None: if model_x is None:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x) model_prev_list.append(model_x)
t_prev_list.append(vec_t) t_prev_list.append(vec_t)
for step in range(order, steps + 1): for step in range(order, steps + 1):
@ -776,13 +772,15 @@ class UniPC:
step_order = min(order, steps + 1 - step) step_order = min(order, steps + 1 - step)
else: else:
step_order = order step_order = order
print('this step order:', step_order) #print('this step order:', step_order)
if step == steps: if step == steps:
print('do not run corrector at the last step') #print('do not run corrector at the last step')
use_corrector = False use_corrector = False
else: else:
use_corrector = True use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1): for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1] t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1]

View File

@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler:
return x, ts, cond, unconditional_conditioning return x, ts, cond, unconditional_conditioning
def after_sample(self, x, ts, cond, uncond, res): def update_step(self, last_latent):
if self.is_unipc:
# unipc model_fn returns (pred_x0)
# p_sample_ddim returns (x_prev, pred_x0)
res = (None, res[0])
if self.mask is not None: if self.mask is not None:
self.last_latent = self.init_latent * self.mask + self.nmask * res[1] self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else: else:
self.last_latent = res[1] self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent) sd_samplers_common.store_latent(self.last_latent)
@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler:
state.sampling_step = self.step state.sampling_step = self.step
shared.total_tqdm.update() shared.total_tqdm.update()
def after_sample(self, x, ts, cond, uncond, res):
if not self.is_unipc:
self.update_step(res[1])
return x, ts, cond, uncond, res return x, ts, cond, uncond, res
def unipc_after_update(self, x, model_x):
self.update_step(x)
def initialize(self, p): def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0: if self.eta != 0.0:
@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler:
if hasattr(self.sampler, fieldname): if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook) setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
if self.is_unipc: if self.is_unipc:
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r)) self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None