add textual inversion hashes to infotext
This commit is contained in:
parent
127635409a
commit
2b1bae0d75
@ -732,9 +732,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
p.setup_conds()
|
p.setup_conds()
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
for comment in model_hijack.comments:
|
||||||
for comment in model_hijack.comments:
|
comments[comment] = 1
|
||||||
comments[comment] = 1
|
|
||||||
|
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
|
||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
@ -147,7 +147,6 @@ def undo_weighted_forward(sd_model):
|
|||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
comments = []
|
|
||||||
layers = None
|
layers = None
|
||||||
circular_enabled = False
|
circular_enabled = False
|
||||||
clip = None
|
clip = None
|
||||||
@ -156,6 +155,9 @@ class StableDiffusionModelHijack:
|
|||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.extra_generation_params = {}
|
||||||
|
self.comments = []
|
||||||
|
|
||||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def apply_optimizations(self, option=None):
|
def apply_optimizations(self, option=None):
|
||||||
@ -236,6 +238,7 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
def clear_comments(self):
|
def clear_comments(self):
|
||||||
self.comments = []
|
self.comments = []
|
||||||
|
self.extra_generation_params = {}
|
||||||
|
|
||||||
def get_prompt_lengths(self, text):
|
def get_prompt_lengths(self, text):
|
||||||
if self.clip is None:
|
if self.clip is None:
|
||||||
|
@ -229,9 +229,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
z = self.process_tokens(tokens, multipliers)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
zs.append(z)
|
zs.append(z)
|
||||||
|
|
||||||
if len(used_embeddings) > 0:
|
if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
|
||||||
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
|
hashes = []
|
||||||
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
|
for name, embedding in used_embeddings.items():
|
||||||
|
shorthash = embedding.shorthash
|
||||||
|
if not shorthash:
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = name.replace(":", "").replace(",", "")
|
||||||
|
hashes.append(f"{name}: {shorthash}")
|
||||||
|
|
||||||
|
if hashes:
|
||||||
|
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||||
|
|
||||||
return torch.hstack(zs)
|
return torch.hstack(zs)
|
||||||
|
|
||||||
|
@ -472,6 +472,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
||||||
|
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ import numpy as np
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
@ -49,6 +49,8 @@ class Embedding:
|
|||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
self.optimizer_state_dict = None
|
self.optimizer_state_dict = None
|
||||||
self.filename = None
|
self.filename = None
|
||||||
|
self.hash = None
|
||||||
|
self.shorthash = None
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
embedding_data = {
|
embedding_data = {
|
||||||
@ -82,6 +84,10 @@ class Embedding:
|
|||||||
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
||||||
return self.cached_checksum
|
return self.cached_checksum
|
||||||
|
|
||||||
|
def set_hash(self, v):
|
||||||
|
self.hash = v
|
||||||
|
self.shorthash = self.hash[0:12]
|
||||||
|
|
||||||
|
|
||||||
class DirWithTextualInversionEmbeddings:
|
class DirWithTextualInversionEmbeddings:
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
@ -199,6 +205,7 @@ class EmbeddingDatabase:
|
|||||||
embedding.vectors = vec.shape[0]
|
embedding.vectors = vec.shape[0]
|
||||||
embedding.shape = vec.shape[-1]
|
embedding.shape = vec.shape[-1]
|
||||||
embedding.filename = path
|
embedding.filename = path
|
||||||
|
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
||||||
|
|
||||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
Loading…
Reference in New Issue
Block a user