Separate .optim file from model
This commit is contained in:
parent
7ea5956ad5
commit
0b143c1163
@ -161,6 +161,7 @@ class Hypernetwork:
|
|||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
optimizer_saved_dict = {}
|
||||||
|
|
||||||
for k, v in self.layers.items():
|
for k, v in self.layers.items():
|
||||||
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
||||||
@ -175,9 +176,10 @@ class Hypernetwork:
|
|||||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||||
if self.optimizer_name is not None:
|
if self.optimizer_name is not None:
|
||||||
state_dict['optimizer_name'] = self.optimizer_name
|
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||||
if self.optimizer_state_dict:
|
if self.optimizer_state_dict:
|
||||||
state_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||||
|
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||||
|
|
||||||
torch.save(state_dict, filename)
|
torch.save(state_dict, filename)
|
||||||
|
|
||||||
@ -198,9 +200,11 @@ class Hypernetwork:
|
|||||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||||
self.use_dropout = state_dict.get('use_dropout', False)
|
self.use_dropout = state_dict.get('use_dropout', False)
|
||||||
print(f"Dropout usage is set to {self.use_dropout}")
|
print(f"Dropout usage is set to {self.use_dropout}")
|
||||||
self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
|
|
||||||
|
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||||
|
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||||
print(f"Optimizer name is {self.optimizer_name}")
|
print(f"Optimizer name is {self.optimizer_name}")
|
||||||
self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
|
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||||
if self.optimizer_state_dict:
|
if self.optimizer_state_dict:
|
||||||
print("Loaded existing optimizer from checkpoint")
|
print("Loaded existing optimizer from checkpoint")
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user