support for sd-concepts as alternatives for textual inversion #151

This commit is contained in:
AUTOMATIC 2022-09-08 15:36:50 +03:00
parent f5001246e2
commit 62ce77e245
2 changed files with 17 additions and 6 deletions

3
.gitignore vendored
View File

@ -9,4 +9,5 @@ __pycache__
/outputs /outputs
/config.json /config.json
/log /log
webui.settings.bat /webui.settings.bat
/embeddings

View File

@ -73,11 +73,21 @@ class StableDiffusionModelHijack:
name = os.path.splitext(filename)[0] name = os.path.splitext(filename)[0]
data = torch.load(path) data = torch.load(path)
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param'] param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'): if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it' assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1] emb = next(iter(param_dict.items()))[1]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
self.word_embeddings[name] = emb.detach() self.word_embeddings[name] = emb.detach()
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}' self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'