Update finetune_py to V4

This commit is contained in:
bmaltais 2022-12-02 12:48:43 -05:00
parent 621dabcadf
commit 95a694a3fd
6 changed files with 422 additions and 116 deletions

View File

@ -71,6 +71,7 @@ def clean_caption(caption):
replaced = bef != caption
return caption
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
print(f"found {len(image_paths)} images.")
@ -95,16 +96,16 @@ def main(args):
return
tags = metadata[image_key].get('tags')
caption = metadata[image_key].get('caption')
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_path}")
return
else:
metadata[image_key]['tags'] = clean_tags(image_key, tags)
caption = metadata[image_key].get('caption')
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_path}")
return
metadata[image_key]['tags'] = clean_tags(image_key, tags)
metadata[image_key]['caption'] = clean_caption(caption)
else:
metadata[image_key]['caption'] = clean_caption(caption)
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")

View File

@ -1,5 +1,6 @@
# v2: select precision for saved checkpoint
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
# v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
# License:
@ -44,12 +45,15 @@ import fine_tuning_utils
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
# checkpointファイル名
LAST_CHECKPOINT_NAME = "last.ckpt"
LAST_STATE_NAME = "last-state"
LAST_DIFFUSERS_DIR_NAME = "last"
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
EPOCH_STATE_NAME = "epoch-{:06d}-state"
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
def collate_fn(examples):
@ -63,7 +67,7 @@ class FineTuningDataset(torch.utils.data.Dataset):
self.metadata = metadata
self.train_data_dir = train_data_dir
self.batch_size = batch_size
self.tokenizer = tokenizer
self.tokenizer: CLIPTokenizer = tokenizer
self.max_token_length = max_token_length
self.shuffle_caption = shuffle_caption
self.debug = debug
@ -159,17 +163,38 @@ class FineTuningDataset(torch.utils.data.Dataset):
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
if self.tokenizer_max_length > self.tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
iid = (input_ids[0].unsqueeze(0),
input_ids[i:i + self.tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0))
iid = torch.cat(iid)
iids_list.append(iid)
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
# v1
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
ids_chunk = (input_ids[0].unsqueeze(0),
input_ids[i:i + self.tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0))
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
input_ids[i:i + self.tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0)) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変えるx <EOS> なら結果的に変化なし)
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
ids_chunk[-1] = self.tokenizer.eos_token_id
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
if ids_chunk[1] == self.tokenizer.pad_token_id:
ids_chunk[1] = self.tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
input_ids_list.append(input_ids)
@ -192,15 +217,17 @@ def save_hypernetwork(output_file, hypernetwork):
def train(args):
fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training
# その他のオプション設定を確認する
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
# モデル形式のオプション設定を確認する
# v11からDiffUsersから直接落としてくるのもOKただし認証がいるやつは未対応、またv11からDiffUsersも途中保存に対応した
use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
if not use_stable_diffusion_format:
assert os.path.exists(
args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}"
assert not fine_tuning or (
args.save_every_n_epochs is None or use_stable_diffusion_format), "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります"
# 乱数系列を初期化する
if args.seed is not None:
set_seed(args.seed)
@ -215,18 +242,22 @@ def train(args):
# tokenizerを読み込む
print("prepare tokenizer")
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
if args.max_token_length is not None:
print(f"update token length in tokenizer: {args.max_token_length}")
print(f"update token length: {args.max_token_length}")
# datasetを用意する
print("prepare dataset")
train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size,
tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset)
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
print(f"Total images / 画像数: {train_dataset.images_count}")
if args.debug_dataset:
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
print(f"Total images / 画像数: {train_dataset.images_count}")
train_dataset.show_buckets()
i = 0
for example in train_dataset:
@ -251,14 +282,33 @@ def train(args):
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
save_dtype = None
if args.save_precision == "fp16":
save_dtype = torch.float16
elif args.save_precision == "bf16":
save_dtype = torch.bfloat16
elif args.save_precision == "float":
save_dtype = torch.float32
# モデルを読み込む
if use_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path)
text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(
args.v2, args.pretrained_model_name_or_path)
else:
print("load Diffusers pretrained models")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
# , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる
text_encoder = pipe.text_encoder
unet = pipe.unet
del pipe
# モデルに xformers とか memory efficient attention を組み込む
replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@ -279,21 +329,6 @@ def train(args):
print("apply hypernetwork")
hypernetwork.apply_to_diffusers(None, text_encoder, unet)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
save_dtype = None
if args.save_precision == "fp16":
save_dtype = torch.float16
elif args.save_precision == "bf16":
save_dtype = torch.bfloat16
elif args.save_precision == "float":
save_dtype = torch.float32
# 学習を準備する:モデルを適切な状態にする
training_models = []
if fine_tuning:
@ -351,7 +386,7 @@ def train(args):
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
"constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
# acceleratorがなんかよろしくやってくれるらしい
if fine_tuning:
@ -384,10 +419,14 @@ def train(args):
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
# v4で更新clip_sample=Falseに
# Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる
# 既存の1.4/1.5/2.0はすべてschdulerのconfigはクラス名を除いて同じ
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork")
@ -400,7 +439,7 @@ def train(args):
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(training_models[0]): # ここはこれでいいのか……?
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
latents = batch["latents"].to(accelerator.device)
latents = latents * 0.18215
b_size = latents.shape[0]
@ -418,15 +457,29 @@ def train(args):
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if args.max_token_length is not None:
# <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
for i in range(1, args.max_token_length, tokenizer.model_max_length):
sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
sts_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
encoder_hidden_states = torch.cat(sts_list, dim=1)
if args.v2:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
@ -442,7 +495,41 @@ def train(args):
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
if args.v_parameterization:
# v-parameterization training
# 11/29現在v predictionのコードがDiffusersにcommitされたがリリースされていないので独自コードを使う
# 実装の中身は同じ模様
# こうしたい:
# target = noise_scheduler.get_v(latents, noise, timesteps)
# StabilityAiのddpm.pyのコード
# elif self.parameterization == "v":
# target = self.get_v(x_start, noise, t)
# ...
# def get_v(self, x, noise, t):
# return (
# extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
# extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
# )
# scheduling_ddim.pyのコード
# elif self.config.prediction_type == "v_prediction":
# pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# # predict V
# model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
# これでいいかな?:
alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
beta_prod_t = 1 - alpha_prod_t
alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # broadcastされないらしいのでreshape
beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@ -460,7 +547,7 @@ def train(args):
progress_bar.update(1)
global_step += 1
current_loss = loss.detach().item() * b_size
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
@ -481,14 +568,20 @@ def train(args):
if args.save_every_n_epochs is not None:
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
print("saving check point.")
print("saving checkpoint.")
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
if fine_tuning:
fine_tuning_utils.save_stable_diffusion_checkpoint(
ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
if use_stable_diffusion_format:
fine_tuning_utils.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
else:
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
os.makedirs(out_dir, exist_ok=True)
fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype)
else:
save_hypernetwork(ckpt_file, accelerator.unwrap_model(hypernetwork))
@ -519,16 +612,14 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
fine_tuning_utils.save_stable_diffusion_checkpoint(
ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
args.v2, ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
else:
# Create the pipeline using using the trained modules and save it.
print(f"save trained model as Diffusers to {args.output_dir}")
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,
text_encoder=text_encoder,
)
pipeline.save_pretrained(args.output_dir)
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
os.makedirs(out_dir, exist_ok=True)
fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
args.pretrained_model_name_or_path, save_dtype)
else:
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
print(f"save trained model to {ckpt_file}")
@ -817,6 +908,10 @@ def replace_unet_cross_attn_to_xformers():
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
parser.add_argument("--v_parameterization", action='store_true',
help='enable v-parameterization training / v-parameterization学習を有効にする')
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル")
@ -832,7 +927,7 @@ if __name__ == '__main__':
parser.add_argument("--hypernetwork_weights", type=str, default=None,
help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重みHypernetworkの追加学習')
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存するStableDiffusion形式のモデルを読み込んだ場合のみ有効")
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None,
@ -857,13 +952,17 @@ if __name__ == '__main__':
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存するStableDiffusion形式での保存時のみ有効")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
parser.add_argument("--debug_dataset", action="store_true",
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
args = parser.parse_args()
train(args)

View File

@ -2,11 +2,13 @@ import math
import torch
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
# StableDiffusionのモデルパラメータ
# region checkpoint変換、読み込み、書き込み ###############################
# DiffUsers版StableDiffusionのモデルパラメータ
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120
@ -29,13 +31,15 @@ VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2
# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
# region conversion
# checkpoint変換など ###############################
# region StableDiffusion->Diffusersの変換コード
# convert_original_stable_diffusion_to_diffusers をコピーしているASL 2.0
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
@ -199,7 +203,16 @@ def conv_attn_to_linear(checkpoint):
checkpoint[key] = checkpoint[key][:, :, 0]
def convert_ldm_unet_checkpoint(checkpoint, config):
def linear_transformer_to_conv(checkpoint):
keys = list(checkpoint.keys())
tf_keys = ["proj_in.weight", "proj_out.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in tf_keys:
if checkpoint[key].ndim == 2:
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
@ -349,6 +362,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config):
new_checkpoint[new_path] = unet_state_dict[old_path]
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
if v2:
linear_transformer_to_conv(new_checkpoint)
return new_checkpoint
@ -459,7 +476,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint
def create_unet_diffusers_config():
def create_unet_diffusers_config(v2):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
@ -489,8 +506,8 @@ def create_unet_diffusers_config():
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM,
attention_head_dim=UNET_PARAMS_NUM_HEADS,
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
)
return config
@ -519,20 +536,82 @@ def create_vae_diffusers_config():
return config
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
def convert_ldm_clip_checkpoint_v1(checkpoint):
keys = list(checkpoint.keys())
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
return text_model_dict
text_model.load_state_dict(text_model_dict)
return text_model
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
# 嫌になるくらい違うぞ!
def convert_key(key):
if not key.startswith("cond_stage_model"):
return None
# common conversion
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
key = key.replace("cond_stage_model.model.", "text_model.")
if "resblocks" in key:
# resblocks conversion
key = key.replace(".resblocks.", ".layers.")
if ".ln_" in key:
key = key.replace(".ln_", ".layer_norm")
elif ".mlp." in key:
key = key.replace(".c_fc.", ".fc1.")
key = key.replace(".c_proj.", ".fc2.")
elif '.attn.out_proj' in key:
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
elif '.attn.in_proj' in key:
key = None # 特殊なので後で処理する
else:
raise ValueError(f"unexpected key in SD: {key}")
elif '.positional_embedding' in key:
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
elif '.text_projection' in key:
key = None # 使われない???
elif '.logit_scale' in key:
key = None # 使われない???
elif '.token_embedding' in key:
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
elif '.ln_final' in key:
key = key.replace(".ln_final", ".final_layer_norm")
return key
keys = list(checkpoint.keys())
new_sd = {}
for key in keys:
# remove resblocks 23
if '.resblocks.23.' in key:
continue
new_key = convert_key(key)
if new_key is None:
continue
new_sd[new_key] = checkpoint[key]
# attnの変換
for key in keys:
if '.resblocks.23.' in key:
continue
if '.resblocks' in key and '.attn.in_proj_' in key:
# 三つに分割
values = torch.chunk(checkpoint[key], 3)
key_suffix = ".weight" if "weight" in key else ".bias"
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
key_pfx = key_pfx.replace("_weight", "")
key_pfx = key_pfx.replace("_bias", "")
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# position_idsの追加
new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
return new_sd
# endregion
@ -540,7 +619,16 @@ def convert_ldm_clip_checkpoint(checkpoint):
# region Diffusers->StableDiffusion の変換コード
# convert_diffusers_to_original_stable_diffusion をコピーしているASL 2.0
def convert_unet_state_dict(unet_state_dict):
def conv_transformer_to_linear(checkpoint):
keys = list(checkpoint.keys())
tf_keys = ["proj_in.weight", "proj_out.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in tf_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
@ -629,12 +717,16 @@ def convert_unet_state_dict(unet_state_dict):
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
if v2:
conv_transformer_to_linear(new_state_dict)
return new_state_dict
# endregion
def load_checkpoint_with_conversion(ckpt_path):
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
TEXT_ENCODER_KEY_REPLACEMENTS = [
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
@ -659,52 +751,148 @@ def load_checkpoint_with_conversion(ckpt_path):
return checkpoint
def load_models_from_stable_diffusion_checkpoint(ckpt_path):
checkpoint = load_checkpoint_with_conversion(ckpt_path)
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
state_dict = checkpoint["state_dict"]
if dtype is not None:
for k, v in state_dict.items():
if type(v) is torch.Tensor:
state_dict[k] = v.to(dtype)
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config()
converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
unet_config = create_unet_diffusers_config(v2)
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted_unet_checkpoint)
info = unet.load_state_dict(converted_unet_checkpoint)
print("loading u-net:", info)
# Convert the VAE model.
vae_config = create_vae_diffusers_config()
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
info = vae.load_state_dict(converted_vae_checkpoint)
print("loadint vae:", info)
# convert text_model
text_model = convert_ldm_clip_checkpoint(state_dict)
if v2:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=23,
num_attention_heads=16,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=512,
torch_dtype="float32",
transformers_version="4.25.0.dev0",
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
else:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info)
return text_model, vae, unet
def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
def convert_key(key):
# position_idsの除去
if ".position_ids" in key:
return None
# common
key = key.replace("text_model.encoder.", "transformer.")
key = key.replace("text_model.", "")
if "layers" in key:
# resblocks conversion
key = key.replace(".layers.", ".resblocks.")
if ".layer_norm" in key:
key = key.replace(".layer_norm", ".ln_")
elif ".mlp." in key:
key = key.replace(".fc1.", ".c_fc.")
key = key.replace(".fc2.", ".c_proj.")
elif '.self_attn.out_proj' in key:
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
elif '.self_attn.' in key:
key = None # 特殊なので後で処理する
else:
raise ValueError(f"unexpected key in DiffUsers model: {key}")
elif '.position_embedding' in key:
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif '.token_embedding' in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif 'final_layer_norm' in key:
key = key.replace("final_layer_norm", "ln_final")
return key
keys = list(checkpoint.keys())
new_sd = {}
for key in keys:
new_key = convert_key(key)
if new_key is None:
continue
new_sd[new_key] = checkpoint[key]
# attnの変換
for key in keys:
if 'layers' in key and 'q_proj' in key:
# 三つを結合
key_q = key
key_k = key.replace("q_proj", "k_proj")
key_v = key.replace("q_proj", "v_proj")
value_q = checkpoint[key_q]
value_k = checkpoint[key_k]
value_v = checkpoint[key_v]
value = torch.cat([value_q, value_k, value_v])
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
new_sd[new_key] = value
return new_sd
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
checkpoint = load_checkpoint_with_conversion(ckpt_path)
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
state_dict = checkpoint["state_dict"]
def assign_new_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
assert key in state_dict, f"Illegal key in save SD: {key}"
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet.state_dict())
for k, v in unet_state_dict.items():
key = "model.diffusion_model." + k
assert key in state_dict, f"Illegal key in save SD: {key}"
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
assign_new_sd("model.diffusion_model.", unet_state_dict)
# Convert the text encoder model
text_enc_dict = text_encoder.state_dict() # 変換不要
for k, v in text_enc_dict.items():
key = "cond_stage_model.transformer." + k
assert key in state_dict, f"Illegal key in save SD: {key}"
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if v2:
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict())
assign_new_sd("cond_stage_model.model.", text_enc_dict)
else:
text_enc_dict = text_encoder.state_dict()
assign_new_sd("cond_stage_model.transformer.", text_enc_dict)
# Put together new checkpoint
new_ckpt = {'state_dict': state_dict}
@ -718,6 +906,21 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
new_ckpt['global_step'] = steps
torch.save(new_ckpt, output_file)
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype):
pipeline = StableDiffusionPipeline(
unet=unet,
text_encoder=text_encoder,
vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"),
scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
safety_checker=None,
feature_extractor=None,
requires_safety_checker=None,
)
pipeline.save_pretrained(output_dir)
# endregion

View File

@ -10,7 +10,7 @@ from tqdm import tqdm
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")
if args.in_json is not None:

View File

@ -10,7 +10,7 @@ from tqdm import tqdm
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")
if args.in_json is not None:

View File

@ -36,7 +36,7 @@ def get_latents(vae, images, weight_dtype):
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
@ -49,13 +49,11 @@ def main(args):
# モデル形式のオプション設定を確認する
use_stable_diffusion_format = os.path.isfile(args.model_name_or_path)
if not use_stable_diffusion_format:
assert os.path.exists(args.model_name_or_path), f"no model / モデルがありません : {args.model_name_or_path}"
# モデルを読み込む
if use_stable_diffusion_format:
print("load StableDiffusion checkpoint")
_, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.model_name_or_path)
_, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_name_or_path)
else:
print("load Diffusers pretrained models")
vae = AutoencoderKL.from_pretrained(args.model_name_or_path, subfolder="vae")
@ -73,7 +71,8 @@ def main(args):
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(max_reso)
bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(
max_reso, args.min_bucket_reso, args.max_bucket_reso)
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
@ -162,9 +161,13 @@ if __name__ == '__main__':
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")