allow overwrite old embedding

This commit is contained in:
DepFA 2022-10-20 00:10:59 +01:00 committed by GitHub
parent 166be3919b
commit 0087079c2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'): def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists" if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name) embedding = Embedding(vec, name)
embedding.step = 0 embedding.step = 0