From 449a35368f681a9bf62b67eba6ea38a395df96c4 Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Mon, 5 Dec 2022 11:13:41 -0500
Subject: [PATCH] Update model conversion util

---
 tools/convert_diffusers20_original_sd.py |   2 +-
 tools/model_util.py                      | 178 ++++++++++++++++++++---
 2 files changed, 158 insertions(+), 22 deletions(-)

diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py
index 3f583bf..4809455 100644
--- a/tools/convert_diffusers20_original_sd.py
+++ b/tools/convert_diffusers20_original_sd.py
@@ -6,7 +6,7 @@ import os
 import torch
 from diffusers import StableDiffusionPipeline
 
-import model_util as model_util
+import model_util
 
 
 def convert(args):
diff --git a/tools/model_util.py b/tools/model_util.py
index d74b0c5..74650bf 100644
--- a/tools/model_util.py
+++ b/tools/model_util.py
@@ -1,9 +1,12 @@
 # v1: split from train_db_fixed.py.
+# v2: support safetensors
 
+import math
+import os
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-
+from safetensors.torch import load_file, save_file
 
 # model parameters of the Diffusers version of Stable Diffusion
 NUM_TRAIN_TIMESTEPS = 1000
@@ -34,7 +37,7 @@ V2_UNET_PARAMS_CONTEXT_DIM = 1024
 
 # region StableDiffusion->Diffusers conversion code
 
-# copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
+# copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
 
 
 def shave_segments(path, n_shave_prefix_segments=1):
@@ -240,21 +243,21 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
   # Retrieves the keys for the input blocks only
   num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
   input_blocks = {
-      layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+      layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
       for layer_id in range(num_input_blocks)
   }
 
   # Retrieves the keys for the middle blocks only
   num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
   middle_blocks = {
-      layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+      layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
       for layer_id in range(num_middle_blocks)
   }
 
   # Retrieves the keys for the output blocks only
   num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
   output_blocks = {
-      layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+      layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
       for layer_id in range(num_output_blocks)
   }
 
@@ -329,14 +332,22 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
             paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
         )
 
-      if ["conv.weight", "conv.bias"] in output_block_list.values():
-        index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
-        new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
-            f"output_blocks.{i}.{index}.conv.weight"
-        ]
+      # original:
+      # if ["conv.weight", "conv.bias"] in output_block_list.values():
+      #   index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+
+      # avoid depending on the order of bias and weight: there is probably a better way
+      for l in output_block_list.values():
+        l.sort()
+
+      if ["conv.bias", "conv.weight"] in output_block_list.values():
+        index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
         new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
             f"output_blocks.{i}.{index}.conv.bias"
         ]
+        new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+            f"output_blocks.{i}.{index}.conv.weight"
+        ]
 
         # Clear attentions as they have been attributed above.
         if len(attentions) == 2:
@@ -617,7 +628,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
 
 # region Diffusers->StableDiffusion conversion code
 
-# copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
+# copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
 
 def conv_transformer_to_linear(checkpoint):
   keys = list(checkpoint.keys())
@@ -794,7 +805,7 @@ def convert_vae_state_dict(vae_state_dict):
   for k, v in new_state_dict.items():
     for weight_name in weights_to_convert:
       if f"mid.attn_1.{weight_name}.weight" in k:
-        print(f"Reshaping {k} for SD format")
+        # print(f"Reshaping {k} for SD format")
         new_state_dict[k] = reshape_weight_for_sd(v)
 
   return new_state_dict
@@ -802,6 +813,11 @@ def convert_vae_state_dict(vae_state_dict):
 
 # endregion
 
+# region custom model loading/saving
+
+def is_safetensors(path):
+  return os.path.splitext(path)[1].lower() == '.safetensors'
+
 
 def load_checkpoint_with_text_encoder_conversion(ckpt_path):
   # handle models that store the text encoder in a different format (without 'text_model')
@@ -811,8 +827,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
       ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
   ]
 
-  checkpoint = torch.load(ckpt_path, map_location="cpu")
-  state_dict = checkpoint["state_dict"]
+  if is_safetensors(ckpt_path):
+    checkpoint = None
+    state_dict = load_file(ckpt_path, "cpu")
+  else:
+    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    if "state_dict" in checkpoint:
+      state_dict = checkpoint["state_dict"]
+    else:
+      state_dict = checkpoint
+      checkpoint = None
 
   key_reps = []
   for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
@@ -825,13 +849,12 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
     state_dict[new_key] = state_dict[key]
     del state_dict[key]
 
-  return checkpoint
+  return checkpoint, state_dict
 
 
 # TODO the behavior when dtype is specified is questionable and should be checked; creating the text_encoder in the specified dtype is unverified
 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
-  checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
-  state_dict = checkpoint["state_dict"]
+  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
   if dtype is not None:
     for k, v in state_dict.items():
       if type(v) is torch.Tensor:
@@ -962,9 +985,14 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
 def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
   if ckpt_path is not None:
     # reference epoch/step from the checkpoint; also reload it including the VAE, e.g. when the VAE is not in memory
-    checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
-    state_dict = checkpoint["state_dict"]
-    strict = True
+    checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+    if checkpoint is None:  # safetensors, or a ckpt that is a bare state_dict
+      checkpoint = {}
+      strict = False
+    else:
+      strict = True
+    if "state_dict" in state_dict:
+      del state_dict["state_dict"]
   else:
     # create a new one
     checkpoint = {}
@@ -1009,7 +1037,11 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
   new_ckpt['epoch'] = epochs
   new_ckpt['global_step'] = steps
 
-  torch.save(new_ckpt, output_file)
+  if is_safetensors(output_file):
+    # TODO should non-Tensor values be removed from the dict?
+    save_file(state_dict, output_file)
+  else:
+    torch.save(new_ckpt, output_file)
 
   return key_count
 
@@ -1028,3 +1060,107 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
       requires_safety_checker=None,
   )
   pipeline.save_pretrained(output_dir)
+
+
+VAE_PREFIX = "first_stage_model."
+
+
+def load_vae(vae_id, dtype):
+  print(f"load VAE: {vae_id}")
+  if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+    # Diffusers local/remote
+    try:
+      vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+    except EnvironmentError as e:
+      print(f"exception occurs in loading vae: {e}")
+      print("retry with subfolder='vae'")
+      vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+    return vae
+
+  # local
+  vae_config = create_vae_diffusers_config()
+
+  if vae_id.endswith(".bin"):
+    # SD 1.5 VAE on Huggingface
+    vae_sd = torch.load(vae_id, map_location="cpu")
+    converted_vae_checkpoint = vae_sd
+  else:
+    # StableDiffusion
+    vae_model = torch.load(vae_id, map_location="cpu")
+    vae_sd = vae_model['state_dict']
+
+    # vae only or full model
+    full_model = False
+    for vae_key in vae_sd:
+      if vae_key.startswith(VAE_PREFIX):
+        full_model = True
+        break
+    if not full_model:
+      sd = {}
+      for key, value in vae_sd.items():
+        sd[VAE_PREFIX + key] = value
+      vae_sd = sd
+      del sd
+
+    # Convert the VAE model.
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+  vae = AutoencoderKL(**vae_config)
+  vae.load_state_dict(converted_vae_checkpoint)
+  return vae
+
+
+def get_epoch_ckpt_name(use_safetensors, epoch):
+  return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
+
+
+def get_last_ckpt_name(use_safetensors):
+  return f"last" + (".safetensors" if use_safetensors else ".ckpt")
+
+# endregion
+
+
+def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
+  max_width, max_height = max_reso
+  max_area = (max_width // divisible) * (max_height // divisible)
+
+  resos = set()
+
+  size = int(math.sqrt(max_area)) * divisible
+  resos.add((size, size))
+
+  size = min_size
+  while size <= max_size:
+    width = size
+    height = min(max_size, (max_area // (width // divisible)) * divisible)
+    resos.add((width, height))
+    resos.add((height, width))
+
+    # # make additional resos
+    # if width >= height and width - divisible >= min_size:
+    #   resos.add((width - divisible, height))
+    #   resos.add((height, width - divisible))
+    # if height >= width and height - divisible >= min_size:
+    #   resos.add((width, height - divisible))
+    #   resos.add((height - divisible, width))
+
+    size += divisible
+
+  resos = list(resos)
+  resos.sort()
+
+  aspect_ratios = [w / h for w, h in resos]
+  return resos, aspect_ratios
+
+
+if __name__ == '__main__':
+  resos, aspect_ratios = make_bucket_resolutions((512, 768))
+  print(len(resos))
+  print(resos)
+  print(aspect_ratios)
+
+  ars = set()
+  for ar in aspect_ratios:
+    if ar in ars:
+      print("error! duplicate ar:", ar)
+    ars.add(ar)
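
Usage sketch (illustrative, not part of the patch): converting a legacy .ckpt to a .safetensors checkpoint with the updated helpers. The file paths, the v2 flag and the save dtype below are placeholder assumptions, and load_models_from_stable_diffusion_checkpoint() is assumed to return the text encoder, VAE and U-Net in that order.

# Hypothetical driver script, run from the tools/ directory.
import torch
import model_util

src = "model.ckpt"         # placeholder input checkpoint
dst = "model.safetensors"  # the extension routes saving through save_file() via is_safetensors()
v2 = False                 # set True for Stable Diffusion 2.x checkpoints

# Assumed return order: text encoder, VAE, U-Net.
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2, src)

# Passing src as ckpt_path lets the helper reference epoch/step and the VAE
# weights from the source file before writing in the format chosen by dst.
key_count = model_util.save_stable_diffusion_checkpoint(
    v2, dst, text_encoder, unet, src, epochs=0, steps=0, save_dtype=torch.float16)
print(f"saved {key_count} keys to {dst}")

Because is_safetensors() only inspects the output extension, the same call writes a classic pickle checkpoint when dst ends in .ckpt.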