Make save ckpt in fp16 optional

This commit is contained in:
Bernard Maltais 2022-11-05 16:18:07 -04:00
parent a0f1832154
commit 39c0c295fb

View File

@ -903,16 +903,20 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path)
for k, v in unet_state_dict.items():
key = "model.diffusion_model." + k
assert key in state_dict, f"Illegal key in save SD: {key}"
# state_dict[key] = v
state_dict[key] = v.half() # save to fp16
if args.save_half:
state_dict[key] = v.half() # save to fp16
else:
state_dict[key] = v
# Convert the text encoder model
text_enc_dict = text_encoder.state_dict() # 変換不要
for k, v in text_enc_dict.items():
key = "cond_stage_model.transformer." + k
assert key in state_dict, f"Illegal key in save SD: {key}"
# state_dict[key] = v
state_dict[key] = v.half() # save to fp16
if args.save_half:
state_dict[key] = v.half() # save to fp16
else:
state_dict[key] = v
# Put together new checkpoint
state_dict = {"state_dict": state_dict}