Make save ckpt in fp16 optional
This commit is contained in:
parent
a0f1832154
commit
39c0c295fb
@ -903,16 +903,20 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path)
|
||||
for k, v in unet_state_dict.items():
|
||||
key = "model.diffusion_model." + k
|
||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||
# state_dict[key] = v
|
||||
state_dict[key] = v.half() # save to fp16
|
||||
if args.save_half:
|
||||
state_dict[key] = v.half() # save to fp16
|
||||
else:
|
||||
state_dict[key] = v
|
||||
|
||||
# Convert the text encoder model
|
||||
text_enc_dict = text_encoder.state_dict() # 変換不要
|
||||
for k, v in text_enc_dict.items():
|
||||
key = "cond_stage_model.transformer." + k
|
||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||
# state_dict[key] = v
|
||||
state_dict[key] = v.half() # save to fp16
|
||||
if args.save_half:
|
||||
state_dict[key] = v.half() # save to fp16
|
||||
else:
|
||||
state_dict[key] = v
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {"state_dict": state_dict}
|
||||
|
Loading…
Reference in New Issue
Block a user