From fe4e3c267391b4745efe562967f3c8b0b1037a0d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 8 Sep 2022 19:34:20 +0300 Subject: [PATCH] fix for PLMS live previews in txt2img --- modules/sd_samplers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index fd63e47f..6b7979e2 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -87,7 +87,7 @@ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdq class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) - self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None + self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms self.mask = None self.nmask = None self.init_latent = None @@ -113,7 +113,9 @@ class VanillaStableDiffusionSampler: return samples def sample(self, p, x, conditioning, unconditional_conditioning): - self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs) + for fieldname in ['p_sample_ddim', 'p_sample_plms']: + if hasattr(self.sampler, fieldname): + setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)) self.mask = None self.nmask = None self.init_latent = None