changed embedding accepted shape detection to use existing code and support the new alt-diffusion model, and reformatted messages a bit #6149
This commit is contained in:
parent
c24a314c5e
commit
bdbe09827b
@ -80,23 +80,8 @@ class EmbeddingDatabase:
|
|||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def get_expected_shape(self):
|
def get_expected_shape(self):
|
||||||
expected_shape = -1 # initialize with unknown
|
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||||
idx = torch.tensor(0).to(shared.device)
|
return vec.shape[1]
|
||||||
if expected_shape == -1:
|
|
||||||
try: # matches sd15 signature
|
|
||||||
first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
|
|
||||||
expected_shape = first_embedding.shape[0]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
if expected_shape == -1:
|
|
||||||
try: # matches sd20 signature
|
|
||||||
first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
|
|
||||||
expected_shape = first_embedding.shape[0]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
if expected_shape == -1:
|
|
||||||
print('Could not determine expected embeddings shape from model')
|
|
||||||
return expected_shape
|
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, force_reload = False):
|
def load_textual_inversion_embeddings(self, force_reload = False):
|
||||||
mt = os.path.getmtime(self.embeddings_dir)
|
mt = os.path.getmtime(self.embeddings_dir)
|
||||||
@ -112,8 +97,6 @@ class EmbeddingDatabase:
|
|||||||
def process_file(path, filename):
|
def process_file(path, filename):
|
||||||
name = os.path.splitext(filename)[0]
|
name = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
data = []
|
|
||||||
|
|
||||||
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||||
embed_image = Image.open(path)
|
embed_image = Image.open(path)
|
||||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||||
@ -150,11 +133,10 @@ class EmbeddingDatabase:
|
|||||||
embedding.vectors = vec.shape[0]
|
embedding.vectors = vec.shape[0]
|
||||||
embedding.shape = vec.shape[-1]
|
embedding.shape = vec.shape[-1]
|
||||||
|
|
||||||
if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
else:
|
else:
|
||||||
self.skipped_embeddings.append(name)
|
self.skipped_embeddings.append(name)
|
||||||
# print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
|
|
||||||
|
|
||||||
for fn in os.listdir(self.embeddings_dir):
|
for fn in os.listdir(self.embeddings_dir):
|
||||||
try:
|
try:
|
||||||
@ -169,9 +151,9 @@ class EmbeddingDatabase:
|
|||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
|
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||||
if (len(self.skipped_embeddings) > 0):
|
if len(self.skipped_embeddings) > 0:
|
||||||
print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
|
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
|
||||||
|
|
||||||
def find_embedding_at_position(self, tokens, offset):
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
token = tokens[offset]
|
token = tokens[offset]
|
||||||
|
Loading…
Reference in New Issue
Block a user