let user choose his own prompt token count limit
This commit is contained in:
parent
87db6f01cc
commit
706d5944a0
@ -123,6 +123,7 @@ class Processed:
|
|||||||
self.index_of_first_image = index_of_first_image
|
self.index_of_first_image = index_of_first_image
|
||||||
self.styles = p.styles
|
self.styles = p.styles
|
||||||
self.job_timestamp = state.job_timestamp
|
self.job_timestamp = state.job_timestamp
|
||||||
|
self.max_prompt_tokens = opts.max_prompt_tokens
|
||||||
|
|
||||||
self.eta = p.eta
|
self.eta = p.eta
|
||||||
self.ddim_discretize = p.ddim_discretize
|
self.ddim_discretize = p.ddim_discretize
|
||||||
@ -141,6 +142,7 @@ class Processed:
|
|||||||
self.all_subseeds = all_subseeds or [self.subseed]
|
self.all_subseeds = all_subseeds or [self.subseed]
|
||||||
self.infotexts = infotexts or [info]
|
self.infotexts = infotexts or [info]
|
||||||
|
|
||||||
|
|
||||||
def js(self):
|
def js(self):
|
||||||
obj = {
|
obj = {
|
||||||
"prompt": self.prompt,
|
"prompt": self.prompt,
|
||||||
@ -169,6 +171,7 @@ class Processed:
|
|||||||
"infotexts": self.infotexts,
|
"infotexts": self.infotexts,
|
||||||
"styles": self.styles,
|
"styles": self.styles,
|
||||||
"job_timestamp": self.job_timestamp,
|
"job_timestamp": self.job_timestamp,
|
||||||
|
"max_prompt_tokens": self.max_prompt_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.dumps(obj)
|
return json.dumps(obj)
|
||||||
@ -266,6 +269,8 @@ def fix_seed(p):
|
|||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
|
max_tokens = getattr(p, 'max_prompt_tokens', opts.max_prompt_tokens)
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": sd_samplers.samplers[p.sampler_index].name,
|
"Sampler": sd_samplers.samplers[p.sampler_index].name,
|
||||||
@ -281,6 +286,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
|
"Max tokens": (None if max_tokens == shared.vanilla_max_prompt_tokens else max_tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
|||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations():
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
@ -83,7 +82,7 @@ class StableDiffusionModelHijack:
|
|||||||
layer.padding_mode = 'circular' if enable else 'zeros'
|
layer.padding_mode = 'circular' if enable else 'zeros'
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
max_length = self.clip.max_length - 2
|
max_length = opts.max_prompt_tokens - 2
|
||||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||||
return remade_batch_tokens[0], token_count, max_length
|
return remade_batch_tokens[0], token_count, max_length
|
||||||
|
|
||||||
@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.hijack: StableDiffusionModelHijack = hijack
|
self.hijack: StableDiffusionModelHijack = hijack
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
self.max_length = wrapped.max_length
|
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
|
|
||||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||||
@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length
|
maxlen = opts.max_prompt_tokens
|
||||||
|
|
||||||
if opts.enable_emphasis:
|
if opts.enable_emphasis:
|
||||||
parsed = prompt_parser.parse_prompt_attention(line)
|
parsed = prompt_parser.parse_prompt_attention(line)
|
||||||
@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
def process_text_old(self, text):
|
def process_text_old(self, text):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length
|
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||||
used_custom_terms = []
|
used_custom_terms = []
|
||||||
remade_batch_tokens = []
|
remade_batch_tokens = []
|
||||||
overflowing_words = []
|
overflowing_words = []
|
||||||
@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
|
position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76]
|
||||||
|
position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
|
|
||||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||||
|
@ -118,8 +118,8 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
|||||||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
# This was moved to webui.py with the other model "setup" calls.
|
|
||||||
# modules.sd_models.list_models()
|
vanilla_max_prompt_tokens = 77
|
||||||
|
|
||||||
|
|
||||||
def realesrgan_models_names():
|
def realesrgan_models_names():
|
||||||
@ -221,6 +221,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||||
|
"max_prompt_tokens": OptionInfo(vanilla_max_prompt_tokens, f"Max prompt token count. Two tokens are reserved for for start and end. Default is {vanilla_max_prompt_tokens}. Setting this to a different value will result in different pictures for same seed.", gr.Number, {"precision": 0}),
|
||||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user