make it possible for extensions/scripts to add their own embedding directories
This commit is contained in:
parent
a0c87f1fdf
commit
085427de0e
@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
|
||||
clip = None
|
||||
optimization_method = None
|
||||
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
|
||||
def __init__(self):
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
|
||||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
|
@ -66,17 +66,41 @@ class Embedding:
|
||||
return self.cached_checksum
|
||||
|
||||
|
||||
class DirWithTextualInversionEmbeddings:
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.mtime = None
|
||||
|
||||
def has_changed(self):
|
||||
if not os.path.isdir(self.path):
|
||||
return False
|
||||
|
||||
mt = os.path.getmtime(self.path)
|
||||
if self.mtime is None or mt > self.mtime:
|
||||
return True
|
||||
|
||||
def update(self):
|
||||
if not os.path.isdir(self.path):
|
||||
return
|
||||
|
||||
self.mtime = os.path.getmtime(self.path)
|
||||
|
||||
|
||||
class EmbeddingDatabase:
|
||||
def __init__(self, embeddings_dir):
|
||||
def __init__(self):
|
||||
self.ids_lookup = {}
|
||||
self.word_embeddings = {}
|
||||
self.skipped_embeddings = {}
|
||||
self.dir_mtime = None
|
||||
self.embeddings_dir = embeddings_dir
|
||||
self.expected_shape = -1
|
||||
self.embedding_dirs = {}
|
||||
|
||||
def add_embedding_dir(self, path):
|
||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||
|
||||
def clear_embedding_dirs(self):
|
||||
self.embedding_dirs.clear()
|
||||
|
||||
def register_embedding(self, embedding, model):
|
||||
|
||||
self.word_embeddings[embedding.name] = embedding
|
||||
|
||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||
@ -93,69 +117,62 @@ class EmbeddingDatabase:
|
||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
return vec.shape[1]
|
||||
|
||||
def load_textual_inversion_embeddings(self, force_reload = False):
|
||||
mt = os.path.getmtime(self.embeddings_dir)
|
||||
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||
return
|
||||
def load_from_file(self, path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
self.dir_mtime = mt
|
||||
self.ids_lookup.clear()
|
||||
self.word_embeddings.clear()
|
||||
self.skipped_embeddings.clear()
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
def process_file(path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
_, second_ext = os.path.splitext(name)
|
||||
if second_ext.upper() == '.PREVIEW':
|
||||
return
|
||||
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
else:
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
_, second_ext = os.path.splitext(name)
|
||||
if second_ext.upper() == '.PREVIEW':
|
||||
return
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
else:
|
||||
return
|
||||
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
else:
|
||||
self.skipped_embeddings[name] = embedding
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
for root, dirs, fns in os.walk(self.embeddings_dir):
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
else:
|
||||
self.skipped_embeddings[name] = embedding
|
||||
|
||||
def load_from_dir(self, embdir):
|
||||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
||||
for root, dirs, fns in os.walk(embdir.path):
|
||||
for fn in fns:
|
||||
try:
|
||||
fullfn = os.path.join(root, fn)
|
||||
@ -163,12 +180,32 @@ class EmbeddingDatabase:
|
||||
if os.stat(fullfn).st_size == 0:
|
||||
continue
|
||||
|
||||
process_file(fullfn, fn)
|
||||
self.load_from_file(fullfn, fn)
|
||||
except Exception:
|
||||
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
continue
|
||||
|
||||
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||
if not force_reload:
|
||||
need_reload = False
|
||||
for path, embdir in self.embedding_dirs.items():
|
||||
if embdir.has_changed():
|
||||
need_reload = True
|
||||
break
|
||||
|
||||
if not need_reload:
|
||||
return
|
||||
|
||||
self.ids_lookup.clear()
|
||||
self.word_embeddings.clear()
|
||||
self.skipped_embeddings.clear()
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
for path, embdir in self.embedding_dirs.items():
|
||||
self.load_from_dir(embdir)
|
||||
embdir.update()
|
||||
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
||||
assert steps, "Max steps is empty or 0"
|
||||
assert isinstance(steps, int), "Max steps must be integer"
|
||||
assert steps > 0 , "Max steps must be positive"
|
||||
assert steps > 0, "Max steps must be positive"
|
||||
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
||||
assert save_model_every >= 0 , "Save {name} must be positive or 0"
|
||||
assert save_model_every >= 0, "Save {name} must be positive or 0"
|
||||
assert isinstance(create_image_every, int), "Create image must be integer"
|
||||
assert create_image_every >= 0 , "Create image must be positive or 0"
|
||||
assert create_image_every >= 0, "Create image must be positive or 0"
|
||||
if save_model_every or create_image_every:
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
|
Loading…
Reference in New Issue
Block a user