Update model conversion util

bmaltais 2022-12-05 11:13:41 -05:00
parent e8db30b9d1
commit 449a35368f
2 changed files with 158 additions and 22 deletions

File 1 of 2: model conversion script

@@ -6,7 +6,7 @@ import os
import torch
from diffusers import StableDiffusionPipeline
-import model_util as model_util
+import model_util
def convert(args):

File 2 of 2: model_util.py

@@ -1,9 +1,12 @@
# v1: split from train_db_fixed.py.
+# v2: support safetensors
+import math
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from safetensors.torch import load_file, save_file
# model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
@@ -34,7 +37,7 @@ V2_UNET_PARAMS_CONTEXT_DIM = 1024
# region conversion code: StableDiffusion -> Diffusers
-# copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
+# copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
def shave_segments(path, n_shave_prefix_segments=1):
@@ -240,21 +243,21 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
-layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
-layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
-layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
for layer_id in range(num_output_blocks)
}
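Why the trailing dot matters: once a UNet has ten or more input blocks, the bare substring "input_blocks.1" also matches "input_blocks.10", "input_blocks.11", and so on, so keys from later blocks leak into earlier groups. A minimal illustration (the key names are made up):

keys = ["input_blocks.1.0.weight", "input_blocks.10.0.weight", "input_blocks.11.0.weight"]
print([k for k in keys if "input_blocks.1" in k])   # all three keys match
print([k for k in keys if "input_blocks.1." in k])  # only ['input_blocks.1.0.weight']

The same off-by-prefix issue affects the middle_block and output_blocks groupings, hence the identical fix in all three comprehensions.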
@@ -329,14 +332,22 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
# オリジナル:
# if ["conv.weight", "conv.bias"] in output_block_list.values():
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
# biasとweightの順番に依存しないようにするもっといいやり方がありそうだが
for l in output_block_list.values():
l.sort()
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
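The upsampler detection compares each block's key list against a literal two-element list, so it breaks whenever the state dict enumerates bias before weight; key order in a checkpoint is not guaranteed. Sorting every list first makes the membership test order-independent, which is what the "there is probably a better way" comment concedes. A small demonstration:

block_a = ["conv.weight", "conv.bias"]   # one enumeration order
block_b = ["conv.bias", "conv.weight"]   # the other order, same tensors
print(block_a == ["conv.weight", "conv.bias"])  # True
print(block_b == ["conv.weight", "conv.bias"])  # False, despite identical keys
block_a.sort(); block_b.sort()
print(block_a == block_b == ["conv.bias", "conv.weight"])  # True after sorting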
@@ -617,7 +628,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
# region conversion code: Diffusers -> StableDiffusion
-# copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
+# copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
def conv_transformer_to_linear(checkpoint):
keys = list(checkpoint.keys())
@@ -794,7 +805,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
# print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@@ -802,6 +813,11 @@ def convert_vae_state_dict(vae_state_dict):
# endregion
# region custom model load/save
+def is_safetensors(path):
+    return os.path.splitext(path)[1].lower() == '.safetensors'
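Note that is_safetensors decides purely by file extension (lower-cased), never by inspecting the file contents:

print(is_safetensors("model.safetensors"))   # True
print(is_safetensors("model.SafeTensors"))   # True: the comparison is case-insensitive
print(is_safetensors("model.ckpt"))          # False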
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
# handle models that store the text encoder keys differently (without 'text_model')
@@ -811,8 +827,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
]
-checkpoint = torch.load(ckpt_path, map_location="cpu")
-state_dict = checkpoint["state_dict"]
+if is_safetensors(ckpt_path):
+    checkpoint = None
+    state_dict = load_file(ckpt_path, "cpu")
+else:
+    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    if "state_dict" in checkpoint:
+        state_dict = checkpoint["state_dict"]
+    else:
+        state_dict = checkpoint
+        checkpoint = None
key_reps = []
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
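Together with the return-value change in the next hunk, the loader now accepts three on-disk layouts through one entry point: a .safetensors file (a bare tensor dict, so there is no checkpoint wrapper to return), a classic .ckpt whose tensors live under a "state_dict" key, and a .ckpt that is itself a bare state dict. A usage sketch with hypothetical paths:

for path in ("v1-5-pruned.safetensors", "v1-5-pruned.ckpt", "bare-state-dict.ckpt"):
    checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(path)
    # checkpoint is None except for a full .ckpt that carries metadata
    # such as epoch/global_step; state_dict always holds the tensors
    print(path, checkpoint is not None, len(state_dict))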
@@ -825,13 +849,12 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
state_dict[new_key] = state_dict[key]
del state_dict[key]
-return checkpoint
+return checkpoint, state_dict
# TODO: dtype handling is questionable and should be verified; whether the text_encoder can be created in the specified dtype is unconfirmed
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
-checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
-state_dict = checkpoint["state_dict"]
+_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
if dtype is not None:
for k, v in state_dict.items():
if type(v) is torch.Tensor:
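Call sites only need the tensors, so this one unpacks the tuple and discards the wrapper, which is what lets the same function load either format. A hedged usage sketch (the path is hypothetical, and the triple return of text encoder, VAE, and UNet is assumed from how training scripts consume this function; it is not shown in this hunk):

import torch
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    False, "v1-5-pruned.safetensors", dtype=torch.float16)  # v2=False: an SD v1.x model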
@@ -962,9 +985,14 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
if ckpt_path is not None:
# reuse epoch/step from the existing checkpoint; also reload it including the VAE, e.g. when the VAE is not in memory
-checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
-state_dict = checkpoint["state_dict"]
-strict = True
+checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+if checkpoint is None:  # safetensors, or a ckpt holding a bare state_dict
+    checkpoint = {}
+    strict = False
+else:
+    strict = True
+if "state_dict" in state_dict:
+    del state_dict["state_dict"]
else:
# create a new one
checkpoint = {}
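The strict flag exists because a safetensors file or bare state dict provides no metadata and may not enumerate every key the saver expects, so key checking against the reference has to be relaxed. Its consumer is outside this hunk; a hypothetical sketch of the kind of gate it presumably controls:

def copy_keys(prefix, sd, state_dict, new_ckpt, strict):
    # when strict, refuse keys that the reference checkpoint does not contain
    for k, v in sd.items():
        key = prefix + k
        assert not strict or key in state_dict, f"unexpected key: {key}"
        new_ckpt[key] = v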
@@ -1009,7 +1037,11 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
new_ckpt['epoch'] = epochs
new_ckpt['global_step'] = steps
-torch.save(new_ckpt, output_file)
+if is_safetensors(output_file):
+    # TODO: would it be better to drop non-Tensor values from the dict?
+    save_file(state_dict, output_file)
+else:
+    torch.save(new_ckpt, output_file)
return key_count
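save_file serializes only a flat dict of tensors, which is why epoch and global_step are written on the torch.save path but silently dropped on the safetensors path (the TODO above asks whether the remaining non-Tensor values should be stripped first). If those counters are worth keeping, safetensors does support a string-to-string metadata dict; a sketch of that option, not what this commit does:

from safetensors.torch import save_file
save_file(state_dict, output_file,
          metadata={"epoch": str(epochs), "global_step": str(steps)})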
@@ -1028,3 +1060,107 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
requires_safety_checker=None,
)
pipeline.save_pretrained(output_dir)
+VAE_PREFIX = "first_stage_model."
+
+def load_vae(vae_id, dtype):
+    print(f"load VAE: {vae_id}")
+    if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+        # Diffusers local/remote
+        try:
+            vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+        except EnvironmentError as e:
+            print(f"exception occurs in loading vae: {e}")
+            print("retry with subfolder='vae'")
+            vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+        return vae
+
+    # local
+    vae_config = create_vae_diffusers_config()
+
+    if vae_id.endswith(".bin"):
+        # SD 1.5 VAE on Huggingface
+        vae_sd = torch.load(vae_id, map_location="cpu")
+        converted_vae_checkpoint = vae_sd
+    else:
+        # StableDiffusion
+        vae_model = torch.load(vae_id, map_location="cpu")
+        vae_sd = vae_model['state_dict']
+
+        # vae only or full model
+        full_model = False
+        for vae_key in vae_sd:
+            if vae_key.startswith(VAE_PREFIX):
+                full_model = True
+                break
+        if not full_model:
+            sd = {}
+            for key, value in vae_sd.items():
+                sd[VAE_PREFIX + key] = value
+            vae_sd = sd
+            del sd
+
+        # Convert the VAE model.
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
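load_vae accepts either a Diffusers-layout model (local directory or Hub ID, with the vae subfolder used as a fallback) or a single-file state dict; the VAE_PREFIX logic lets a full StableDiffusion checkpoint double as a VAE source. Usage sketches, with hypothetical local paths:

import torch
vae = load_vae("stabilityai/sd-vae-ft-mse", torch.float32)  # Hub ID, Diffusers layout
vae = load_vae("./my-finetune", torch.float32)              # local dir; retries subfolder='vae'
vae = load_vae("./some-model.ckpt", torch.float32)          # full SD checkpoint or bare VAE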
+def get_epoch_ckpt_name(use_safetensors, epoch):
+    return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
+
+def get_last_ckpt_name(use_safetensors):
+    return f"last" + (".safetensors" if use_safetensors else ".ckpt")
+
+# endregion
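The two name helpers centralize the extension choice so every save path honors the safetensors switch, for example:

print(get_epoch_ckpt_name(True, 15))   # epoch-000015.safetensors
print(get_epoch_ckpt_name(False, 15))  # epoch-000015.ckpt
print(get_last_ckpt_name(True))        # last.safetensors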
+def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
+    max_width, max_height = max_reso
+    max_area = (max_width // divisible) * (max_height // divisible)
+
+    resos = set()
+
+    size = int(math.sqrt(max_area)) * divisible
+    resos.add((size, size))
+
+    size = min_size
+    while size <= max_size:
+        width = size
+        height = min(max_size, (max_area // (width // divisible)) * divisible)
+        resos.add((width, height))
+        resos.add((height, width))
+
+        # # make additional resos
+        # if width >= height and width - divisible >= min_size:
+        #     resos.add((width - divisible, height))
+        #     resos.add((height, width - divisible))
+        # if height >= width and height - divisible >= min_size:
+        #     resos.add((width, height - divisible))
+        #     resos.add((height - divisible, width))
+
+        size += divisible
+
+    resos = list(resos)
+    resos.sort()
+
+    aspect_ratios = [w / h for w, h in resos]
+    return resos, aspect_ratios
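The area budget is counted in divisible-sized cells, so every bucket stays within the pixel area of max_reso rounded down to the grid. For max_reso=(512, 768): max_area = (512 // 64) * (768 // 64) = 8 * 12 = 96 cells, the square bucket is int(sqrt(96)) * 64 = 9 * 64 = 576, and a 256-wide bucket may be up to min(1024, (96 // 4) * 64) = 1024 tall. Checking that arithmetic directly:

import math
max_area = (512 // 64) * (768 // 64)                # 96 cells
print(int(math.sqrt(max_area)) * 64)                # 576 -> the (576, 576) bucket
print(min(1024, (max_area // (256 // 64)) * 64))    # 1024 -> height for width 256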
+if __name__ == '__main__':
+    resos, aspect_ratios = make_bucket_resolutions((512, 768))
+    print(len(resos))
+    print(resos)
+    print(aspect_ratios)
+
+    ars = set()
+    for ar in aspect_ratios:
+        if ar in ars:
+            print("error! duplicate ar:", ar)
+        ars.add(ar)