make it possible for StableDiffusionProcessing to accept multiple different negative prompts in a batch
This commit is contained in:
parent
e35d8b493f
commit
617c5b486f
@ -124,6 +124,7 @@ class StableDiffusionProcessing():
|
|||||||
self.scripts = None
|
self.scripts = None
|
||||||
self.script_args = None
|
self.script_args = None
|
||||||
self.all_prompts = None
|
self.all_prompts = None
|
||||||
|
self.all_negative_prompts = None
|
||||||
self.all_seeds = None
|
self.all_seeds = None
|
||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
|
|
||||||
@ -202,7 +203,7 @@ class StableDiffusionProcessing():
|
|||||||
|
|
||||||
|
|
||||||
class Processed:
|
class Processed:
|
||||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||||
self.images = images_list
|
self.images = images_list
|
||||||
self.prompt = p.prompt
|
self.prompt = p.prompt
|
||||||
self.negative_prompt = p.negative_prompt
|
self.negative_prompt = p.negative_prompt
|
||||||
@ -241,16 +242,18 @@ class Processed:
|
|||||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||||
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
||||||
|
|
||||||
self.all_prompts = all_prompts or [self.prompt]
|
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
|
||||||
self.all_seeds = all_seeds or [self.seed]
|
self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
|
||||||
self.all_subseeds = all_subseeds or [self.subseed]
|
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
||||||
|
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
||||||
self.infotexts = infotexts or [info]
|
self.infotexts = infotexts or [info]
|
||||||
|
|
||||||
def js(self):
|
def js(self):
|
||||||
obj = {
|
obj = {
|
||||||
"prompt": self.prompt,
|
"prompt": self.all_prompts[0],
|
||||||
"all_prompts": self.all_prompts,
|
"all_prompts": self.all_prompts,
|
||||||
"negative_prompt": self.negative_prompt,
|
"negative_prompt": self.all_negative_prompts[0],
|
||||||
|
"all_negative_prompts": self.all_negative_prompts,
|
||||||
"seed": self.seed,
|
"seed": self.seed,
|
||||||
"all_seeds": self.all_seeds,
|
"all_seeds": self.all_seeds,
|
||||||
"subseed": self.subseed,
|
"subseed": self.subseed,
|
||||||
@ -411,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else ""
|
||||||
|
|
||||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
@ -440,10 +443,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
else:
|
else:
|
||||||
assert p.prompt is not None
|
assert p.prompt is not None
|
||||||
|
|
||||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
|
||||||
processed = Processed(p, [], p.seed, "")
|
|
||||||
file.write(processed.infotext(p, 0))
|
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
seed = get_fixed_seed(p.seed)
|
seed = get_fixed_seed(p.seed)
|
||||||
@ -453,15 +452,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
modules.sd_hijack.model_hijack.clear_comments()
|
modules.sd_hijack.model_hijack.clear_comments()
|
||||||
|
|
||||||
comments = {}
|
comments = {}
|
||||||
prompt_tmp = p.prompt
|
|
||||||
negative_prompt_tmp = p.negative_prompt
|
|
||||||
|
|
||||||
shared.prompt_styles.apply_styles(p)
|
|
||||||
|
|
||||||
if type(p.prompt) == list:
|
if type(p.prompt) == list:
|
||||||
p.all_prompts = p.prompt
|
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
|
||||||
else:
|
else:
|
||||||
p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
|
||||||
|
|
||||||
|
if type(p.negative_prompt) == list:
|
||||||
|
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
|
||||||
|
else:
|
||||||
|
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
|
||||||
|
|
||||||
if type(seed) == list:
|
if type(seed) == list:
|
||||||
p.all_seeds = seed
|
p.all_seeds = seed
|
||||||
@ -476,6 +476,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0):
|
||||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||||
|
|
||||||
|
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||||
|
processed = Processed(p, [], p.seed, "")
|
||||||
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
@ -500,6 +504,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
break
|
break
|
||||||
|
|
||||||
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
@ -510,7 +515,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
|
||||||
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
@ -596,14 +601,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess(p, res)
|
p.scripts.postprocess(p, res)
|
||||||
|
|
||||||
p.prompt = prompt_tmp
|
|
||||||
p.negative_prompt = negative_prompt_tmp
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,17 +65,6 @@ class StyleDatabase:
|
|||||||
def apply_negative_styles_to_prompt(self, prompt, styles):
|
def apply_negative_styles_to_prompt(self, prompt, styles):
|
||||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
||||||
|
|
||||||
def apply_styles(self, p: StableDiffusionProcessing) -> None:
|
|
||||||
if isinstance(p.prompt, list):
|
|
||||||
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
|
|
||||||
else:
|
|
||||||
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
|
|
||||||
|
|
||||||
if isinstance(p.negative_prompt, list):
|
|
||||||
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
|
|
||||||
else:
|
|
||||||
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
|
|
||||||
|
|
||||||
def save_styles(self, path: str) -> None:
|
def save_styles(self, path: str) -> None:
|
||||||
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
||||||
fd, temp_path = tempfile.mkstemp(".csv")
|
fd, temp_path = tempfile.mkstemp(".csv")
|
||||||
|
Loading…
Reference in New Issue
Block a user