allow overwrite old embedding
This commit is contained in:
parent
166be3919b
commit
0087079c2d
@ -153,7 +153,7 @@ class EmbeddingDatabase:
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def create_embedding(name, num_vectors_per_token, init_text='*'):
|
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||||
|
|
||||||
@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|||||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
if not overwrite_old:
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
embedding = Embedding(vec, name)
|
embedding = Embedding(vec, name)
|
||||||
embedding.step = 0
|
embedding.step = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user