From 39c0c295fbaf179c9456b019f775dfbb983a27a7 Mon Sep 17 00:00:00 2001 From: Bernard Maltais Date: Sat, 5 Nov 2022 16:18:07 -0400 Subject: [PATCH] Make save ckpt in fp16 optional --- train_db_fixed_v6-ber.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train_db_fixed_v6-ber.py b/train_db_fixed_v6-ber.py index a693099..af904f9 100644 --- a/train_db_fixed_v6-ber.py +++ b/train_db_fixed_v6-ber.py @@ -903,16 +903,20 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path) for k, v in unet_state_dict.items(): key = "model.diffusion_model." + k assert key in state_dict, f"Illegal key in save SD: {key}" - # state_dict[key] = v - state_dict[key] = v.half() # save to fp16 + if args.save_half: + state_dict[key] = v.half() # save to fp16 + else: + state_dict[key] = v # Convert the text encoder model text_enc_dict = text_encoder.state_dict() # 変換不要 for k, v in text_enc_dict.items(): key = "cond_stage_model.transformer." + k assert key in state_dict, f"Illegal key in save SD: {key}" - # state_dict[key] = v - state_dict[key] = v.half() # save to fp16 + if args.save_half: + state_dict[key] = v.half() # save to fp16 + else: + state_dict[key] = v # Put together new checkpoint state_dict = {"state_dict": state_dict}