Update model conversion util
This commit is contained in:
parent
e8db30b9d1
commit
449a35368f
@ -6,7 +6,7 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
import model_util as model_util
|
import model_util
|
||||||
|
|
||||||
|
|
||||||
def convert(args):
|
def convert(args):
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
# v1: split from train_db_fixed.py.
|
# v1: split from train_db_fixed.py.
|
||||||
|
# v2: support safetensors
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
||||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||||
NUM_TRAIN_TIMESTEPS = 1000
|
NUM_TRAIN_TIMESTEPS = 1000
|
||||||
@ -34,7 +37,7 @@ V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
|||||||
|
|
||||||
|
|
||||||
# region StableDiffusion->Diffusersの変換コード
|
# region StableDiffusion->Diffusersの変換コード
|
||||||
# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)
|
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
||||||
|
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
@ -240,21 +243,21 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
|||||||
# Retrieves the keys for the input blocks only
|
# Retrieves the keys for the input blocks only
|
||||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||||
input_blocks = {
|
input_blocks = {
|
||||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
||||||
for layer_id in range(num_input_blocks)
|
for layer_id in range(num_input_blocks)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Retrieves the keys for the middle blocks only
|
# Retrieves the keys for the middle blocks only
|
||||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||||
middle_blocks = {
|
middle_blocks = {
|
||||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
||||||
for layer_id in range(num_middle_blocks)
|
for layer_id in range(num_middle_blocks)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Retrieves the keys for the output blocks only
|
# Retrieves the keys for the output blocks only
|
||||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||||
output_blocks = {
|
output_blocks = {
|
||||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
||||||
for layer_id in range(num_output_blocks)
|
for layer_id in range(num_output_blocks)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,14 +332,22 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
|||||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
# オリジナル:
|
||||||
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||||
f"output_blocks.{i}.{index}.conv.weight"
|
|
||||||
]
|
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
||||||
|
for l in output_block_list.values():
|
||||||
|
l.sort()
|
||||||
|
|
||||||
|
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||||
|
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
||||||
f"output_blocks.{i}.{index}.conv.bias"
|
f"output_blocks.{i}.{index}.conv.bias"
|
||||||
]
|
]
|
||||||
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||||
|
f"output_blocks.{i}.{index}.conv.weight"
|
||||||
|
]
|
||||||
|
|
||||||
# Clear attentions as they have been attributed above.
|
# Clear attentions as they have been attributed above.
|
||||||
if len(attentions) == 2:
|
if len(attentions) == 2:
|
||||||
@ -617,7 +628,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
|||||||
|
|
||||||
|
|
||||||
# region Diffusers->StableDiffusion の変換コード
|
# region Diffusers->StableDiffusion の変換コード
|
||||||
# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)
|
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
||||||
|
|
||||||
def conv_transformer_to_linear(checkpoint):
|
def conv_transformer_to_linear(checkpoint):
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
@ -794,7 +805,7 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
for k, v in new_state_dict.items():
|
for k, v in new_state_dict.items():
|
||||||
for weight_name in weights_to_convert:
|
for weight_name in weights_to_convert:
|
||||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||||
print(f"Reshaping {k} for SD format")
|
# print(f"Reshaping {k} for SD format")
|
||||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
@ -802,6 +813,11 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
# region 自作のモデル読み書き
|
||||||
|
|
||||||
|
def is_safetensors(path):
|
||||||
|
return os.path.splitext(path)[1].lower() == '.safetensors'
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||||
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
||||||
@ -811,8 +827,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
|||||||
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if is_safetensors(ckpt_path):
|
||||||
|
checkpoint = None
|
||||||
|
state_dict = load_file(ckpt_path, "cpu")
|
||||||
|
else:
|
||||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
if "state_dict" in checkpoint:
|
||||||
state_dict = checkpoint["state_dict"]
|
state_dict = checkpoint["state_dict"]
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
checkpoint = None
|
||||||
|
|
||||||
key_reps = []
|
key_reps = []
|
||||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||||
@ -825,13 +849,12 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
|||||||
state_dict[new_key] = state_dict[key]
|
state_dict[new_key] = state_dict[key]
|
||||||
del state_dict[key]
|
del state_dict[key]
|
||||||
|
|
||||||
return checkpoint
|
return checkpoint, state_dict
|
||||||
|
|
||||||
|
|
||||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||||
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||||
state_dict = checkpoint["state_dict"]
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if type(v) is torch.Tensor:
|
if type(v) is torch.Tensor:
|
||||||
@ -962,9 +985,14 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
|||||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
||||||
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||||
state_dict = checkpoint["state_dict"]
|
if checkpoint is None: # safetensors または state_dictのckpt
|
||||||
|
checkpoint = {}
|
||||||
|
strict = False
|
||||||
|
else:
|
||||||
strict = True
|
strict = True
|
||||||
|
if "state_dict" in state_dict:
|
||||||
|
del state_dict["state_dict"]
|
||||||
else:
|
else:
|
||||||
# 新しく作る
|
# 新しく作る
|
||||||
checkpoint = {}
|
checkpoint = {}
|
||||||
@ -1009,6 +1037,10 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
|||||||
new_ckpt['epoch'] = epochs
|
new_ckpt['epoch'] = epochs
|
||||||
new_ckpt['global_step'] = steps
|
new_ckpt['global_step'] = steps
|
||||||
|
|
||||||
|
if is_safetensors(output_file):
|
||||||
|
# TODO Tensor以外のdictの値を削除したほうがいいか
|
||||||
|
save_file(state_dict, output_file)
|
||||||
|
else:
|
||||||
torch.save(new_ckpt, output_file)
|
torch.save(new_ckpt, output_file)
|
||||||
|
|
||||||
return key_count
|
return key_count
|
||||||
@ -1028,3 +1060,107 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
|||||||
requires_safety_checker=None,
|
requires_safety_checker=None,
|
||||||
)
|
)
|
||||||
pipeline.save_pretrained(output_dir)
|
pipeline.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
VAE_PREFIX = "first_stage_model."
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae(vae_id, dtype):
|
||||||
|
print(f"load VAE: {vae_id}")
|
||||||
|
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||||
|
# Diffusers local/remote
|
||||||
|
try:
|
||||||
|
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
||||||
|
except EnvironmentError as e:
|
||||||
|
print(f"exception occurs in loading vae: {e}")
|
||||||
|
print("retry with subfolder='vae'")
|
||||||
|
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
||||||
|
return vae
|
||||||
|
|
||||||
|
# local
|
||||||
|
vae_config = create_vae_diffusers_config()
|
||||||
|
|
||||||
|
if vae_id.endswith(".bin"):
|
||||||
|
# SD 1.5 VAE on Huggingface
|
||||||
|
vae_sd = torch.load(vae_id, map_location="cpu")
|
||||||
|
converted_vae_checkpoint = vae_sd
|
||||||
|
else:
|
||||||
|
# StableDiffusion
|
||||||
|
vae_model = torch.load(vae_id, map_location="cpu")
|
||||||
|
vae_sd = vae_model['state_dict']
|
||||||
|
|
||||||
|
# vae only or full model
|
||||||
|
full_model = False
|
||||||
|
for vae_key in vae_sd:
|
||||||
|
if vae_key.startswith(VAE_PREFIX):
|
||||||
|
full_model = True
|
||||||
|
break
|
||||||
|
if not full_model:
|
||||||
|
sd = {}
|
||||||
|
for key, value in vae_sd.items():
|
||||||
|
sd[VAE_PREFIX + key] = value
|
||||||
|
vae_sd = sd
|
||||||
|
del sd
|
||||||
|
|
||||||
|
# Convert the VAE model.
|
||||||
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||||
|
|
||||||
|
vae = AutoencoderKL(**vae_config)
|
||||||
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
def get_epoch_ckpt_name(use_safetensors, epoch):
|
||||||
|
return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_ckpt_name(use_safetensors):
|
||||||
|
return f"last" + (".safetensors" if use_safetensors else ".ckpt")
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
||||||
|
max_width, max_height = max_reso
|
||||||
|
max_area = (max_width // divisible) * (max_height // divisible)
|
||||||
|
|
||||||
|
resos = set()
|
||||||
|
|
||||||
|
size = int(math.sqrt(max_area)) * divisible
|
||||||
|
resos.add((size, size))
|
||||||
|
|
||||||
|
size = min_size
|
||||||
|
while size <= max_size:
|
||||||
|
width = size
|
||||||
|
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
||||||
|
resos.add((width, height))
|
||||||
|
resos.add((height, width))
|
||||||
|
|
||||||
|
# # make additional resos
|
||||||
|
# if width >= height and width - divisible >= min_size:
|
||||||
|
# resos.add((width - divisible, height))
|
||||||
|
# resos.add((height, width - divisible))
|
||||||
|
# if height >= width and height - divisible >= min_size:
|
||||||
|
# resos.add((width, height - divisible))
|
||||||
|
# resos.add((height - divisible, width))
|
||||||
|
|
||||||
|
size += divisible
|
||||||
|
|
||||||
|
resos = list(resos)
|
||||||
|
resos.sort()
|
||||||
|
|
||||||
|
aspect_ratios = [w / h for w, h in resos]
|
||||||
|
return resos, aspect_ratios
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
resos, aspect_ratios = make_bucket_resolutions((512, 768))
|
||||||
|
print(len(resos))
|
||||||
|
print(resos)
|
||||||
|
print(aspect_ratios)
|
||||||
|
|
||||||
|
ars = set()
|
||||||
|
for ar in aspect_ratios:
|
||||||
|
if ar in ars:
|
||||||
|
print("error! duplicate ar:", ar)
|
||||||
|
ars.add(ar)
|
||||||
|
Loading…
Reference in New Issue
Block a user