added support for AND from https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
This commit is contained in:
parent
67d011b02e
commit
c26732fbee
@ -360,7 +360,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
#c = p.sd_model.get_learned_conditioning(prompts)
|
#c = p.sd_model.get_learned_conditioning(prompts)
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||||
c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps)
|
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
|
@ -97,10 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
|
|
||||||
|
|
||||||
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||||
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_learned_conditioning(model, prompts, steps):
|
def get_learned_conditioning(model, prompts, steps):
|
||||||
|
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||||
|
and the sampling step at which this condition is to be replaced by the next one.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
|
||||||
|
|
||||||
|
Output:
|
||||||
|
[
|
||||||
|
[
|
||||||
|
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
|
||||||
|
],
|
||||||
|
[
|
||||||
|
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
|
||||||
|
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
|
||||||
|
]
|
||||||
|
]
|
||||||
|
"""
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||||
@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps):
|
|||||||
cache[prompt] = cond_schedule
|
cache[prompt] = cond_schedule
|
||||||
res.append(cond_schedule)
|
res.append(cond_schedule)
|
||||||
|
|
||||||
return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
|
return res
|
||||||
|
|
||||||
|
|
||||||
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
re_AND = re.compile(r"\bAND\b")
|
||||||
param = c.schedules[0][0].cond
|
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")
|
||||||
res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
|
|
||||||
for i, cond_schedule in enumerate(c.schedules):
|
|
||||||
|
def get_multicond_prompt_list(prompts):
|
||||||
|
res_indexes = []
|
||||||
|
|
||||||
|
prompt_flat_list = []
|
||||||
|
prompt_indexes = {}
|
||||||
|
|
||||||
|
for prompt in prompts:
|
||||||
|
subprompts = re_AND.split(prompt)
|
||||||
|
|
||||||
|
indexes = []
|
||||||
|
for subprompt in subprompts:
|
||||||
|
text, weight = re_weight.search(subprompt).groups()
|
||||||
|
|
||||||
|
weight = float(weight) if weight is not None else 1.0
|
||||||
|
|
||||||
|
index = prompt_indexes.get(text, None)
|
||||||
|
if index is None:
|
||||||
|
index = len(prompt_flat_list)
|
||||||
|
prompt_flat_list.append(text)
|
||||||
|
prompt_indexes[text] = index
|
||||||
|
|
||||||
|
indexes.append((index, weight))
|
||||||
|
|
||||||
|
res_indexes.append(indexes)
|
||||||
|
|
||||||
|
return res_indexes, prompt_flat_list, prompt_indexes
|
||||||
|
|
||||||
|
|
||||||
|
class ComposableScheduledPromptConditioning:
|
||||||
|
def __init__(self, schedules, weight=1.0):
|
||||||
|
self.schedules: list[ScheduledPromptConditioning] = schedules
|
||||||
|
self.weight: float = weight
|
||||||
|
|
||||||
|
|
||||||
|
class MulticondLearnedConditioning:
|
||||||
|
def __init__(self, shape, batch):
|
||||||
|
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||||
|
self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
|
||||||
|
|
||||||
|
|
||||||
|
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
||||||
|
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||||
|
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||||
|
|
||||||
|
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
|
||||||
|
"""
|
||||||
|
|
||||||
|
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||||
|
|
||||||
|
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for indexes in res_indexes:
|
||||||
|
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
||||||
|
|
||||||
|
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||||
|
param = c[0][0].cond
|
||||||
|
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||||
|
for i, cond_schedule in enumerate(c):
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
for current, (end_at, cond) in enumerate(cond_schedule):
|
||||||
if current_step <= end_at:
|
if current_step <= end_at:
|
||||||
@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||||
|
param = c.batch[0][0].schedules[0].cond
|
||||||
|
|
||||||
|
tensors = []
|
||||||
|
conds_list = []
|
||||||
|
|
||||||
|
for batch_no, composable_prompts in enumerate(c.batch):
|
||||||
|
conds_for_batch = []
|
||||||
|
|
||||||
|
for cond_index, composable_prompt in enumerate(composable_prompts):
|
||||||
|
target_index = 0
|
||||||
|
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
||||||
|
if current_step <= end_at:
|
||||||
|
target_index = current
|
||||||
|
break
|
||||||
|
|
||||||
|
conds_for_batch.append((len(tensors), composable_prompt.weight))
|
||||||
|
tensors.append(composable_prompt.schedules[target_index].cond)
|
||||||
|
|
||||||
|
conds_list.append(conds_for_batch)
|
||||||
|
|
||||||
|
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
||||||
|
|
||||||
|
|
||||||
re_attention = re.compile(r"""
|
re_attention = re.compile(r"""
|
||||||
\\\(|
|
\\\(|
|
||||||
\\\)|
|
\\\)|
|
||||||
|
@ -109,9 +109,12 @@ class VanillaStableDiffusionSampler:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
|
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||||
|
cond = tensor
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||||
@ -183,19 +186,31 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
|
batch_size = len(conds_list)
|
||||||
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||||
|
cond_in = torch.cat([tensor, uncond])
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
if shared.batch_cond_uncond:
|
||||||
x_in = torch.cat([x] * 2)
|
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
||||||
sigma_in = torch.cat([sigma] * 2)
|
|
||||||
cond_in = torch.cat([uncond, cond])
|
|
||||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
|
||||||
denoised = uncond + (cond - uncond) * cond_scale
|
|
||||||
else:
|
else:
|
||||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
x_out = torch.zeros_like(x_in)
|
||||||
cond = self.inner_model(x, sigma, cond=cond)
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||||
denoised = uncond + (cond - uncond) * cond_scale
|
a = batch_offset
|
||||||
|
b = a + batch_size
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
||||||
|
|
||||||
|
denoised_uncond = x_out[-batch_size:]
|
||||||
|
denoised = torch.clone(denoised_uncond)
|
||||||
|
|
||||||
|
for i, conds in enumerate(conds_list):
|
||||||
|
for cond_index, weight in conds:
|
||||||
|
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
@ -34,7 +34,7 @@ import modules.gfpgan_model
|
|||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.generation_parameters_copypaste
|
import modules.generation_parameters_copypaste
|
||||||
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
|
from modules import prompt_parser
|
||||||
from modules.images import apply_filename_pattern, get_next_sequence_number
|
from modules.images import apply_filename_pattern, get_next_sequence_number
|
||||||
import modules.textual_inversion.ui
|
import modules.textual_inversion.ui
|
||||||
|
|
||||||
@ -394,7 +394,9 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
|||||||
|
|
||||||
def update_token_counter(text, steps):
|
def update_token_counter(text, steps):
|
||||||
try:
|
try:
|
||||||
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
|
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||||
|
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# a parsing error can happen here during typing, and we don't want to bother the user with
|
# a parsing error can happen here during typing, and we don't want to bother the user with
|
||||||
# messages related to it in console
|
# messages related to it in console
|
||||||
|
Loading…
Reference in New Issue
Block a user