add checkpoint info to saved embeddings
This commit is contained in:
parent
71fe7fa49f
commit
4ec4af6e0b
@ -7,7 +7,7 @@ import tqdm
|
|||||||
import html
|
import html
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
|
||||||
|
|
||||||
@ -17,6 +17,8 @@ class Embedding:
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.step = step
|
self.step = step
|
||||||
self.cached_checksum = None
|
self.cached_checksum = None
|
||||||
|
self.sd_checkpoint = None
|
||||||
|
self.sd_checkpoint_name = None
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
embedding_data = {
|
embedding_data = {
|
||||||
@ -24,6 +26,8 @@ class Embedding:
|
|||||||
"string_to_param": {"*": self.vec},
|
"string_to_param": {"*": self.vec},
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"step": self.step,
|
"step": self.step,
|
||||||
|
"sd_checkpoint": self.sd_checkpoint,
|
||||||
|
"sd_checkpoint_name": self.sd_checkpoint_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
torch.save(embedding_data, filename)
|
torch.save(embedding_data, filename)
|
||||||
@ -41,6 +45,7 @@ class Embedding:
|
|||||||
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
||||||
return self.cached_checksum
|
return self.cached_checksum
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingDatabase:
|
class EmbeddingDatabase:
|
||||||
def __init__(self, embeddings_dir):
|
def __init__(self, embeddings_dir):
|
||||||
self.ids_lookup = {}
|
self.ids_lookup = {}
|
||||||
@ -96,6 +101,8 @@ class EmbeddingDatabase:
|
|||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
embedding = Embedding(vec, name)
|
embedding = Embedding(vec, name)
|
||||||
embedding.step = data.get('step', None)
|
embedding.step = data.get('step', None)
|
||||||
|
embedding.sd_checkpoint = data.get('hash', None)
|
||||||
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
|
||||||
for fn in os.listdir(self.embeddings_dir):
|
for fn in os.listdir(self.embeddings_dir):
|
||||||
@ -249,6 +256,10 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
|
embedding.sd_checkpoint = checkpoint.hash
|
||||||
|
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||||
embedding.cached_checksum = None
|
embedding.cached_checksum = None
|
||||||
embedding.save(filename)
|
embedding.save(filename)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user