Save Optimizer next to TI embedding
Also add check to load only .PT and .BIN files as embeddings. (since we add .optim files in the same directory)
This commit is contained in:
parent
c0ee148870
commit
bddebe09ed
@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
|
@ -28,6 +28,7 @@ class Embedding:
|
||||
self.cached_checksum = None
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.optimizer_state_dict = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
@ -41,6 +42,13 @@ class Embedding:
|
||||
|
||||
torch.save(embedding_data, filename)
|
||||
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
|
||||
optimizer_saved_dict = {
|
||||
'hash': self.checksum(),
|
||||
'optimizer_state_dict': self.optimizer_state_dict,
|
||||
}
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
def checksum(self):
|
||||
if self.cached_checksum is not None:
|
||||
return self.cached_checksum
|
||||
@ -95,9 +103,10 @@ class EmbeddingDatabase:
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
def process_file(path, filename):
|
||||
name = os.path.splitext(filename)[0]
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
@ -105,8 +114,10 @@ class EmbeddingDatabase:
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
else:
|
||||
return
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||
if shared.opts.save_optimizer_state:
|
||||
optimizer_state_dict = None
|
||||
if os.path.exists(filename + '.optim'):
|
||||
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
|
||||
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
||||
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
|
||||
if optimizer_state_dict is not None:
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
else:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
batch_size = ds.batch_size
|
||||
@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
#if shared.opts.save_optimizer_state:
|
||||
#embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||
@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
pass
|
||||
@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
|
||||
return embedding, filename
|
||||
|
||||
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
old_embedding_name = embedding.name
|
||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||
@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
|
||||
if remove_cached_checksum:
|
||||
embedding.cached_checksum = None
|
||||
embedding.name = embedding_name
|
||||
embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
embedding.save(filename)
|
||||
except:
|
||||
embedding.sd_checkpoint = old_sd_checkpoint
|
||||
|
Loading…
Reference in New Issue
Block a user