more careful loading of model weights (eliminates some issues with checkpoints that have weird cond_stage_model layer names)
This commit is contained in:
parent
c1093b8051
commit
10aca1ca3e
@ -122,11 +122,33 @@ def select_checkpoint():
|
|||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
|
# Prefix rewrites applied to checkpoint state-dict keys. Some checkpoints name
# the `cond_stage_model.transformer` layers without the `text_model.` segment;
# each entry maps such a prefix to the same prefix with `text_model.` inserted.
checkpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}

# Backward-compatible alias for the original, misspelled name. Both names refer
# to the same dict object, so in-place changes made through either are shared.
chckpoint_dict_replacements = checkpoint_dict_replacements


def transform_checkpoint_dict_key(k: str) -> str:
    """Return the state-dict key `k` with a known prefix rewritten.

    If `k` starts with one of the prefixes listed in
    `checkpoint_dict_replacements`, that prefix is swapped for its mapped
    value and the remainder of the key is kept as is. Keys that match no
    prefix (including keys already in the `text_model.` form) are returned
    unchanged.
    """
    for text, replacement in checkpoint_dict_replacements.items():
        if k.startswith(text):
            # Stop at the first match so a key is rewritten at most once.
            return replacement + k[len(text):]

    return k
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict_from_checkpoint(pl_sd):
    """Build a new state dict from a loaded checkpoint, with keys normalised.

    If `pl_sd` has a "state_dict" entry, the weights are read from that entry;
    otherwise `pl_sd` itself is treated as the state dict. Every key is passed
    through `transform_checkpoint_dict_key`, and entries whose transformed key
    is None are left out. The input mapping is not modified.
    """
    # Unwrap the nested mapping when the checkpoint stores weights under "state_dict".
    source = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    renamed = (
        (transform_checkpoint_dict_key(name), value)
        for name, value in source.items()
    )

    # A None key means "drop this entry".
    return {name: value for name, value in renamed if name is not None}
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info):
|
def load_model_weights(model, checkpoint_info):
|
||||||
@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
|
|||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
|
||||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||||
model.load_state_dict(sd, strict=False)
|
missing, extra = model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
if shared.cmd_opts.opt_channelslast:
|
if shared.cmd_opts.opt_channelslast:
|
||||||
model.to(memory_format=torch.channels_last)
|
model.to(memory_format=torch.channels_last)
|
||||||
|
Loading…
Reference in New Issue
Block a user