Implement PR #3309 but for embeddings.
This commit is contained in:
parent
c2dc9bfa89
commit
4875a6c217
@ -167,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
|||||||
for i in range(num_vectors_per_token):
|
for i in range(num_vectors_per_token):
|
||||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
|
# Remove illegal characters from name.
|
||||||
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||||
if not overwrite_old:
|
if not overwrite_old:
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
@ -287,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||||
|
|
||||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
# Before saving, change name to match current checkpoint.
|
||||||
|
embedding.name = f'{embedding_name}-{embedding.step}'
|
||||||
|
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
|
||||||
embedding.save(last_saved_file)
|
embedding.save(last_saved_file)
|
||||||
embedding_yet_to_be_embedded = True
|
embedding_yet_to_be_embedded = True
|
||||||
|
|
||||||
@ -374,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
embedding.sd_checkpoint = checkpoint.hash
|
embedding.sd_checkpoint = checkpoint.hash
|
||||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||||
embedding.cached_checksum = None
|
embedding.cached_checksum = None
|
||||||
|
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
|
||||||
|
embedding.name = embedding_name
|
||||||
|
filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
|
||||||
embedding.save(filename)
|
embedding.save(filename)
|
||||||
|
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
Loading…
x
Reference in New Issue
Block a user