added token counter next to txt2img and img2img prompts
This commit is contained in:
parent
ca3e5519e8
commit
5034f7d759
13
javascript/helpers.js
Normal file
13
javascript/helpers.js
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// helper functions
|
||||||
|
|
||||||
|
function debounce(func, wait_time) {
|
||||||
|
let timeout;
|
||||||
|
return function wrapped(...args) {
|
||||||
|
let call_function = () => {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
func(...args)
|
||||||
|
}
|
||||||
|
clearTimeout(timeout);
|
||||||
|
timeout = setTimeout(call_function, wait_time);
|
||||||
|
};
|
||||||
|
}
|
@ -183,4 +183,51 @@ onUiUpdate(function(){
|
|||||||
});
|
});
|
||||||
|
|
||||||
json_elem.parentElement.style.display="none"
|
json_elem.parentElement.style.display="none"
|
||||||
|
|
||||||
|
let debounce_time = 800
|
||||||
|
if (!txt2img_textarea) {
|
||||||
|
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea")
|
||||||
|
txt2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "txt2img"), debounce_time))
|
||||||
|
}
|
||||||
|
if (!img2img_textarea) {
|
||||||
|
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea")
|
||||||
|
img2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "img2img"), debounce_time))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
let txt2img_textarea, img2img_textarea = undefined;
|
||||||
|
function submit_prompt_text(source, e) {
|
||||||
|
let prompt_text;
|
||||||
|
if (source == "txt2img")
|
||||||
|
prompt_text = txt2img_textarea.value;
|
||||||
|
else if (source == "img2img")
|
||||||
|
prompt_text = img2img_textarea.value;
|
||||||
|
if (!prompt_text)
|
||||||
|
return;
|
||||||
|
params = {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-type": "application/json"
|
||||||
|
},
|
||||||
|
body: JSON.stringify({data:[prompt_text]})
|
||||||
|
}
|
||||||
|
fetch('http://127.0.0.1:7860/api/tokenize/', params)
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data) => {
|
||||||
|
if (data?.data.length) {
|
||||||
|
let response_json = data.data[0]
|
||||||
|
if (elem = gradioApp().getElementById(source+"_token_counter")) {
|
||||||
|
if (response_json.token_count > response_json.max_length)
|
||||||
|
elem.classList.add("red");
|
||||||
|
else
|
||||||
|
elem.classList.remove("red");
|
||||||
|
elem.innerText = response_json.token_count + "/" + response_json.max_length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Error:', error);
|
||||||
|
});
|
||||||
|
}
|
@ -180,6 +180,7 @@ class StableDiffusionModelHijack:
|
|||||||
dir_mtime = None
|
dir_mtime = None
|
||||||
layers = None
|
layers = None
|
||||||
circular_enabled = False
|
circular_enabled = False
|
||||||
|
clip = None
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, dirname, model):
|
def load_textual_inversion_embeddings(self, dirname, model):
|
||||||
mt = os.path.getmtime(dirname)
|
mt = os.path.getmtime(dirname)
|
||||||
@ -242,6 +243,7 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
if cmd_opts.opt_split_attention_v1:
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||||
@ -268,6 +270,11 @@ class StableDiffusionModelHijack:
|
|||||||
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
||||||
layer.padding_mode = 'circular' if enable else 'zeros'
|
layer.padding_mode = 'circular' if enable else 'zeros'
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
max_length = self.clip.max_length - 2
|
||||||
|
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||||
|
return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length}
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
@ -294,14 +301,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
if mult != 1.0:
|
if mult != 1.0:
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
def forward(self, text):
|
def process_text(self, text):
|
||||||
self.hijack.fixes = []
|
|
||||||
self.hijack.comments = []
|
|
||||||
remade_batch_tokens = []
|
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length
|
maxlen = self.wrapped.max_length
|
||||||
used_custom_terms = []
|
used_custom_terms = []
|
||||||
|
remade_batch_tokens = []
|
||||||
|
overflowing_words = []
|
||||||
|
hijack_comments = []
|
||||||
|
hijack_fixes = []
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
@ -353,9 +362,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
ovf = remade_tokens[maxlen - 2:]
|
ovf = remade_tokens[maxlen - 2:]
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
token_count = len(remade_tokens)
|
||||||
|
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||||
@ -364,8 +372,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
remade_batch_tokens.append(remade_tokens)
|
||||||
self.hijack.fixes.append(fixes)
|
hijack_fixes.append(fixes)
|
||||||
batch_multipliers.append(multipliers)
|
batch_multipliers.append(multipliers)
|
||||||
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
self.hijack.fixes = hijack_fixes
|
||||||
|
self.hijack.comments = hijack_comments
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
@ -22,6 +22,7 @@ from modules.paths import script_path
|
|||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
|
from modules.sd_hijack import model_hijack
|
||||||
import modules.ldsr_model
|
import modules.ldsr_model
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
import modules.gfpgan_model
|
import modules.gfpgan_model
|
||||||
@ -337,11 +338,15 @@ def create_toprow(is_img2img):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
|
prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2)
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="roll_col"):
|
with gr.Column(scale=1, elem_id="roll_col"):
|
||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||||
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||||
|
token_output = gr.JSON(visible=False)
|
||||||
|
if is_img2img: # only define the api function ONCE
|
||||||
|
token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output])
|
||||||
|
|
||||||
with gr.Column(scale=10, elem_id="style_pos_col"):
|
with gr.Column(scale=10, elem_id="style_pos_col"):
|
||||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user