Update model conversion util

This commit is contained in:
bmaltais 2022-12-05 11:13:41 -05:00
parent e8db30b9d1
commit 449a35368f
2 changed files with 158 additions and 22 deletions

View File

@ -6,7 +6,7 @@ import os
import torch import torch
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
import model_util as model_util import model_util
def convert(args): def convert(args):

View File

@ -1,9 +1,12 @@
# v1: split from train_db_fixed.py. # v1: split from train_db_fixed.py.
# v2: support safetensors
import math
import os
import torch import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from safetensors.torch import load_file, save_file
# DiffUsers版StableDiffusionのモデルパラメータ # DiffUsers版StableDiffusionのモデルパラメータ
NUM_TRAIN_TIMESTEPS = 1000 NUM_TRAIN_TIMESTEPS = 1000
@ -34,7 +37,7 @@ V2_UNET_PARAMS_CONTEXT_DIM = 1024
# region StableDiffusion->Diffusersの変換コード # region StableDiffusion->Diffusersの変換コード
# convert_original_stable_diffusion_to_diffusers をコピーしているASL 2.0 # convert_original_stable_diffusion_to_diffusers をコピーして修正しているASL 2.0
def shave_segments(path, n_shave_prefix_segments=1): def shave_segments(path, n_shave_prefix_segments=1):
@ -240,21 +243,21 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
# Retrieves the keys for the input blocks only # Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = { input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
for layer_id in range(num_input_blocks) for layer_id in range(num_input_blocks)
} }
# Retrieves the keys for the middle blocks only # Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = { middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
for layer_id in range(num_middle_blocks) for layer_id in range(num_middle_blocks)
} }
# Retrieves the keys for the output blocks only # Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = { output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
for layer_id in range(num_output_blocks) for layer_id in range(num_output_blocks)
} }
@ -329,14 +332,22 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
) )
if ["conv.weight", "conv.bias"] in output_block_list.values(): # オリジナル:
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) # if ["conv.weight", "conv.bias"] in output_block_list.values():
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
f"output_blocks.{i}.{index}.conv.weight"
] # biasとweightの順番に依存しないようにするもっといいやり方がありそうだが
for l in output_block_list.values():
l.sort()
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias" f"output_blocks.{i}.{index}.conv.bias"
] ]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
# Clear attentions as they have been attributed above. # Clear attentions as they have been attributed above.
if len(attentions) == 2: if len(attentions) == 2:
@ -617,7 +628,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
# region Diffusers->StableDiffusion の変換コード # region Diffusers->StableDiffusion の変換コード
# convert_diffusers_to_original_stable_diffusion をコピーしているASL 2.0 # convert_diffusers_to_original_stable_diffusion をコピーして修正しているASL 2.0
def conv_transformer_to_linear(checkpoint): def conv_transformer_to_linear(checkpoint):
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
@ -794,7 +805,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items(): for k, v in new_state_dict.items():
for weight_name in weights_to_convert: for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k: if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format") # print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v) new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict return new_state_dict
@ -802,6 +813,11 @@ def convert_vae_state_dict(vae_state_dict):
# endregion # endregion
# region 自作のモデル読み書き
def is_safetensors(path):
return os.path.splitext(path)[1].lower() == '.safetensors'
def load_checkpoint_with_text_encoder_conversion(ckpt_path): def load_checkpoint_with_text_encoder_conversion(ckpt_path):
# text encoderの格納形式が違うモデルに対応する ('text_model'がない) # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
@ -811,8 +827,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
] ]
checkpoint = torch.load(ckpt_path, map_location="cpu") if is_safetensors(ckpt_path):
state_dict = checkpoint["state_dict"] checkpoint = None
state_dict = load_file(ckpt_path, "cpu")
else:
checkpoint = torch.load(ckpt_path, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
checkpoint = None
key_reps = [] key_reps = []
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
@ -825,13 +849,12 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
state_dict[new_key] = state_dict[key] state_dict[new_key] = state_dict[key]
del state_dict[key] del state_dict[key]
return checkpoint return checkpoint, state_dict
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
state_dict = checkpoint["state_dict"]
if dtype is not None: if dtype is not None:
for k, v in state_dict.items(): for k, v in state_dict.items():
if type(v) is torch.Tensor: if type(v) is torch.Tensor:
@ -962,9 +985,14 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
if ckpt_path is not None: if ckpt_path is not None:
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
state_dict = checkpoint["state_dict"] if checkpoint is None: # safetensors または state_dictのckpt
strict = True checkpoint = {}
strict = False
else:
strict = True
if "state_dict" in state_dict:
del state_dict["state_dict"]
else: else:
# 新しく作る # 新しく作る
checkpoint = {} checkpoint = {}
@ -1009,7 +1037,11 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
new_ckpt['epoch'] = epochs new_ckpt['epoch'] = epochs
new_ckpt['global_step'] = steps new_ckpt['global_step'] = steps
torch.save(new_ckpt, output_file) if is_safetensors(output_file):
# TODO Tensor以外のdictの値を削除したほうがいいか
save_file(state_dict, output_file)
else:
torch.save(new_ckpt, output_file)
return key_count return key_count
@ -1028,3 +1060,107 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
requires_safety_checker=None, requires_safety_checker=None,
) )
pipeline.save_pretrained(output_dir) pipeline.save_pretrained(output_dir)
VAE_PREFIX = "first_stage_model."
def load_vae(vae_id, dtype):
print(f"load VAE: {vae_id}")
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
# Diffusers local/remote
try:
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
except EnvironmentError as e:
print(f"exception occurs in loading vae: {e}")
print("retry with subfolder='vae'")
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
return vae
# local
vae_config = create_vae_diffusers_config()
if vae_id.endswith(".bin"):
# SD 1.5 VAE on Huggingface
vae_sd = torch.load(vae_id, map_location="cpu")
converted_vae_checkpoint = vae_sd
else:
# StableDiffusion
vae_model = torch.load(vae_id, map_location="cpu")
vae_sd = vae_model['state_dict']
# vae only or full model
full_model = False
for vae_key in vae_sd:
if vae_key.startswith(VAE_PREFIX):
full_model = True
break
if not full_model:
sd = {}
for key, value in vae_sd.items():
sd[VAE_PREFIX + key] = value
vae_sd = sd
del sd
# Convert the VAE model.
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae
def get_epoch_ckpt_name(use_safetensors, epoch):
return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
def get_last_ckpt_name(use_safetensors):
return f"last" + (".safetensors" if use_safetensors else ".ckpt")
# endregion
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
max_width, max_height = max_reso
max_area = (max_width // divisible) * (max_height // divisible)
resos = set()
size = int(math.sqrt(max_area)) * divisible
resos.add((size, size))
size = min_size
while size <= max_size:
width = size
height = min(max_size, (max_area // (width // divisible)) * divisible)
resos.add((width, height))
resos.add((height, width))
# # make additional resos
# if width >= height and width - divisible >= min_size:
# resos.add((width - divisible, height))
# resos.add((height, width - divisible))
# if height >= width and height - divisible >= min_size:
# resos.add((width, height - divisible))
# resos.add((height - divisible, width))
size += divisible
resos = list(resos)
resos.sort()
aspect_ratios = [w / h for w, h in resos]
return resos, aspect_ratios
if __name__ == '__main__':
resos, aspect_ratios = make_bucket_resolutions((512, 768))
print(len(resos))
print(resos)
print(aspect_ratios)
ars = set()
for ar in aspect_ratios:
if ar in ars:
print("error! duplicate ar:", ar)
ars.add(ar)