added prompt verification: if it's too long, a warning is returned in the text field along with the part of prompt that has been truncated
This commit is contained in:
parent
e996f3c118
commit
7fd0f31661
34
webui.py
34
webui.py
@ -45,6 +45,7 @@ parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-i
|
||||
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default='./GFPGAN')
|
||||
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
|
||||
opt = parser.parse_args()
|
||||
|
||||
GFPGAN_dir = opt.gfpgan_dir
|
||||
@ -231,6 +232,25 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
return result
|
||||
|
||||
|
||||
def check_prompt_length(prompt, comments):
|
||||
"""this function tests if prompt is too long, and if so, adds a message to comments"""
|
||||
|
||||
tokenizer = model.cond_stage_model.tokenizer
|
||||
max_length = model.cond_stage_model.max_length
|
||||
|
||||
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
|
||||
ovf = info['overflowing_tokens'][0]
|
||||
overflowing_count = ovf.shape[0]
|
||||
if overflowing_count == 0:
|
||||
return
|
||||
|
||||
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
|
||||
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
@ -248,6 +268,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
comments = []
|
||||
|
||||
prompt_matrix_parts = []
|
||||
if prompt_matrix:
|
||||
all_prompts = []
|
||||
@ -267,6 +289,15 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
|
||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
|
||||
else:
|
||||
|
||||
if not opt.no_verify_input:
|
||||
try:
|
||||
check_prompt_length(prompt, comments)
|
||||
except:
|
||||
import traceback
|
||||
print("Error verifying input:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
all_prompts = batch_size * n_iter * [prompt]
|
||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||
|
||||
@ -333,6 +364,9 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||
""".strip()
|
||||
|
||||
for comment in comments:
|
||||
info += "\n\n" + comment
|
||||
|
||||
return output_images, seed, info
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user