diff --git a/modules/processing.py b/modules/processing.py index 93138e7c..9b53d210 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices +from modules import devices, prompt_parser from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state @@ -247,8 +247,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] - uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) - c = p.sd_model.get_learned_conditioning(prompts) + #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) + #c = p.sd_model.get_learned_conditioning(prompts) + uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) + c = prompt_parser.get_learned_conditioning(prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py new file mode 100644 index 00000000..e918fabf --- /dev/null +++ b/modules/prompt_parser.py @@ -0,0 +1,128 @@ +import re +from collections import namedtuple +import torch + +import modules.shared as shared + +re_prompt = re.compile(r''' +(.*?) +\[ + ([^]:]+): + (?:([^]:]*):)? + ([0-9]*\.?[0-9]+) +] +| +(.+) +''', re.X) + +# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" +# will be represented with prompt_schedule like this (assuming steps=100): +# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] +# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy'] +# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful'] +# [75, 'fantasy landscape with a lake and an oak in background masterful'] +# [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] + + +def get_learned_conditioning_prompt_schedules(prompts, steps): + res = [] + cache = {} + + for prompt in prompts: + prompt_schedule: list[list[str | int]] = [[steps, ""]] + + cached = cache.get(prompt, None) + if cached is not None: + res.append(cached) + + for m in re_prompt.finditer(prompt): + plaintext = m.group(1) if m.group(5) is None else m.group(5) + concept_from = m.group(2) + concept_to = m.group(3) + if concept_to is None: + concept_to = concept_from + concept_from = "" + swap_position = float(m.group(4)) if m.group(4) is not None else None + + if swap_position is not None: + if swap_position < 1: + swap_position = swap_position * steps + swap_position = int(min(swap_position, steps)) + + swap_index = None + found_exact_index = False + for i in range(len(prompt_schedule)): + end_step = prompt_schedule[i][0] + prompt_schedule[i][1] += plaintext + + if swap_position is not None and swap_index is None: + if swap_position == end_step: + swap_index = i + found_exact_index = True + + if swap_position < end_step: + swap_index = i + + if swap_index is not None: + if not found_exact_index: + prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]]) + + for i in range(len(prompt_schedule)): + end_step = prompt_schedule[i][0] + must_replace = swap_position < end_step + + prompt_schedule[i][1] += concept_to if must_replace else concept_from + + res.append(prompt_schedule) + cache[prompt] = prompt_schedule + #for t in prompt_schedule: + # print(t) + + return res + + +ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) +ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) + + +def get_learned_conditioning(prompts, steps): + + res = [] + + prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) + cache = {} + + for prompt, prompt_schedule in zip(prompts, prompt_schedules): + + cached = cache.get(prompt, None) + if cached is not None: + res.append(cached) + + texts = [x[1] for x in prompt_schedule] + conds = shared.sd_model.get_learned_conditioning(texts) + + cond_schedule = [] + for i, (end_at_step, text) in enumerate(prompt_schedule): + cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i])) + + cache[prompt] = cond_schedule + res.append(cond_schedule) + + return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res) + + +def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): + res = torch.zeros(c.shape) + for i, cond_schedule in enumerate(c.schedules): + target_index = 0 + for curret_index, (end_at, cond) in enumerate(cond_schedule): + if current_step <= end_at: + target_index = curret_index + break + res[i] = cond_schedule[target_index].cond + + return res.to(shared.device) + + + +#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 7ef507f1..c042c5c3 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,6 +7,7 @@ from PIL import Image import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms +from modules import prompt_parser from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -53,20 +54,6 @@ def store_latent(decoded): shared.state.current_image = sample_to_image(decoded) -def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs): - if sampler_wrapper.mask is not None: - img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts) - x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec - - res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs) - - if sampler_wrapper.mask is not None: - store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1]) - else: - store_latent(res[1]) - - return res - def extended_tdqm(sequence, *args, desc=None, **kwargs): state.sampling_steps = len(sequence) @@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler: self.mask = None self.nmask = None self.init_latent = None + self.step = 0 + + def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): + cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + + if self.mask is not None: + img_orig = self.sampler.model.q_sample(self.init_latent, ts) + x_dec = img_orig * self.mask + self.nmask * x_dec + + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + + if self.mask is not None: + store_latent(self.init_latent * self.mask + self.nmask * res[1]) + else: + store_latent(res[1]) + + self.step += 1 + return res def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning): t_enc = int(min(p.denoising_strength, 0.999) * p.steps) @@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler: x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs) + self.sampler.p_sample_ddim = self.p_sample_ddim_hook self.mask = p.mask self.nmask = p.nmask self.init_latent = p.init_latent @@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler: def sample(self, p, x, conditioning, unconditional_conditioning): for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)) + setattr(self.sampler, fieldname, self.p_sample_ddim_hook) self.mask = None self.nmask = None self.init_latent = None @@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module): self.mask = None self.nmask = None self.init_latent = None + self.step = 0 def forward(self, x, sigma, uncond, cond, cond_scale): + cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + if shared.batch_cond_uncond: x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) @@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module): if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised + self.step += 1 + return denoised