Update model conversion util
This commit is contained in:
parent
e8db30b9d1
commit
449a35368f
@ -6,7 +6,7 @@ import os
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
import model_util as model_util
|
||||
import model_util
|
||||
|
||||
|
||||
def convert(args):
|
||||
|
@ -1,9 +1,12 @@
|
||||
# v1: split from train_db_fixed.py.
|
||||
# v2: support safetensors
|
||||
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
NUM_TRAIN_TIMESTEPS = 1000
|
||||
@ -34,7 +37,7 @@ V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||
|
||||
|
||||
# region StableDiffusion->Diffusersの変換コード
|
||||
# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)
|
||||
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
@ -240,21 +243,21 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
||||
for layer_id in range(num_output_blocks)
|
||||
}
|
||||
|
||||
@ -329,14 +332,22 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
# オリジナル:
|
||||
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
|
||||
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
||||
for l in output_block_list.values():
|
||||
l.sort()
|
||||
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.bias"
|
||||
]
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
|
||||
# Clear attentions as they have been attributed above.
|
||||
if len(attentions) == 2:
|
||||
@ -617,7 +628,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
|
||||
|
||||
# region Diffusers->StableDiffusion の変換コード
|
||||
# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)
|
||||
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
||||
|
||||
def conv_transformer_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
@ -794,7 +805,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
print(f"Reshaping {k} for SD format")
|
||||
# print(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
|
||||
return new_state_dict
|
||||
@ -802,6 +813,11 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
|
||||
# endregion
|
||||
|
||||
# region 自作のモデル読み書き
|
||||
|
||||
def is_safetensors(path):
|
||||
return os.path.splitext(path)[1].lower() == '.safetensors'
|
||||
|
||||
|
||||
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
||||
@ -811,8 +827,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
||||
]
|
||||
|
||||
if is_safetensors(ckpt_path):
|
||||
checkpoint = None
|
||||
state_dict = load_file(ckpt_path, "cpu")
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
if "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
checkpoint = None
|
||||
|
||||
key_reps = []
|
||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||
@ -825,13 +849,12 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
state_dict[new_key] = state_dict[key]
|
||||
del state_dict[key]
|
||||
|
||||
return checkpoint
|
||||
return checkpoint, state_dict
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
state_dict = checkpoint["state_dict"]
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
if dtype is not None:
|
||||
for k, v in state_dict.items():
|
||||
if type(v) is torch.Tensor:
|
||||
@ -962,9 +985,14 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
||||
if ckpt_path is not None:
|
||||
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
||||
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
state_dict = checkpoint["state_dict"]
|
||||
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
if checkpoint is None: # safetensors または state_dictのckpt
|
||||
checkpoint = {}
|
||||
strict = False
|
||||
else:
|
||||
strict = True
|
||||
if "state_dict" in state_dict:
|
||||
del state_dict["state_dict"]
|
||||
else:
|
||||
# 新しく作る
|
||||
checkpoint = {}
|
||||
@ -1009,6 +1037,10 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
||||
new_ckpt['epoch'] = epochs
|
||||
new_ckpt['global_step'] = steps
|
||||
|
||||
if is_safetensors(output_file):
|
||||
# TODO Tensor以外のdictの値を削除したほうがいいか
|
||||
save_file(state_dict, output_file)
|
||||
else:
|
||||
torch.save(new_ckpt, output_file)
|
||||
|
||||
return key_count
|
||||
@ -1028,3 +1060,107 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
||||
requires_safety_checker=None,
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
|
||||
|
||||
VAE_PREFIX = "first_stage_model."
|
||||
|
||||
|
||||
def load_vae(vae_id, dtype):
|
||||
print(f"load VAE: {vae_id}")
|
||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||
# Diffusers local/remote
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
||||
except EnvironmentError as e:
|
||||
print(f"exception occurs in loading vae: {e}")
|
||||
print("retry with subfolder='vae'")
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
||||
return vae
|
||||
|
||||
# local
|
||||
vae_config = create_vae_diffusers_config()
|
||||
|
||||
if vae_id.endswith(".bin"):
|
||||
# SD 1.5 VAE on Huggingface
|
||||
vae_sd = torch.load(vae_id, map_location="cpu")
|
||||
converted_vae_checkpoint = vae_sd
|
||||
else:
|
||||
# StableDiffusion
|
||||
vae_model = torch.load(vae_id, map_location="cpu")
|
||||
vae_sd = vae_model['state_dict']
|
||||
|
||||
# vae only or full model
|
||||
full_model = False
|
||||
for vae_key in vae_sd:
|
||||
if vae_key.startswith(VAE_PREFIX):
|
||||
full_model = True
|
||||
break
|
||||
if not full_model:
|
||||
sd = {}
|
||||
for key, value in vae_sd.items():
|
||||
sd[VAE_PREFIX + key] = value
|
||||
vae_sd = sd
|
||||
del sd
|
||||
|
||||
# Convert the VAE model.
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae
|
||||
|
||||
|
||||
def get_epoch_ckpt_name(use_safetensors, epoch):
|
||||
return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
|
||||
|
||||
|
||||
def get_last_ckpt_name(use_safetensors):
|
||||
return f"last" + (".safetensors" if use_safetensors else ".ckpt")
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
||||
max_width, max_height = max_reso
|
||||
max_area = (max_width // divisible) * (max_height // divisible)
|
||||
|
||||
resos = set()
|
||||
|
||||
size = int(math.sqrt(max_area)) * divisible
|
||||
resos.add((size, size))
|
||||
|
||||
size = min_size
|
||||
while size <= max_size:
|
||||
width = size
|
||||
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
||||
resos.add((width, height))
|
||||
resos.add((height, width))
|
||||
|
||||
# # make additional resos
|
||||
# if width >= height and width - divisible >= min_size:
|
||||
# resos.add((width - divisible, height))
|
||||
# resos.add((height, width - divisible))
|
||||
# if height >= width and height - divisible >= min_size:
|
||||
# resos.add((width, height - divisible))
|
||||
# resos.add((height - divisible, width))
|
||||
|
||||
size += divisible
|
||||
|
||||
resos = list(resos)
|
||||
resos.sort()
|
||||
|
||||
aspect_ratios = [w / h for w, h in resos]
|
||||
return resos, aspect_ratios
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
resos, aspect_ratios = make_bucket_resolutions((512, 768))
|
||||
print(len(resos))
|
||||
print(resos)
|
||||
print(aspect_ratios)
|
||||
|
||||
ars = set()
|
||||
for ar in aspect_ratios:
|
||||
if ar in ars:
|
||||
print("error! duplicate ar:", ar)
|
||||
ars.add(ar)
|
||||
|
Loading…
Reference in New Issue
Block a user