diff --git a/diffusers_fine_tuning/clean_captions_and_tags.py b/diffusers_fine_tuning/clean_captions_and_tags.py index edf557a..76ede34 100644 --- a/diffusers_fine_tuning/clean_captions_and_tags.py +++ b/diffusers_fine_tuning/clean_captions_and_tags.py @@ -71,6 +71,7 @@ def clean_caption(caption): replaced = bef != caption return caption + def main(args): image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) print(f"found {len(image_paths)} images.") @@ -95,16 +96,16 @@ def main(args): return tags = metadata[image_key].get('tags') - caption = metadata[image_key].get('caption') if tags is None: print(f"image does not have tags / メタデータにタグがありません: {image_path}") - return + else: + metadata[image_key]['tags'] = clean_tags(image_key, tags) + + caption = metadata[image_key].get('caption') if caption is None: print(f"image does not have caption / メタデータにキャプションがありません: {image_path}") - return - - metadata[image_key]['tags'] = clean_tags(image_key, tags) - metadata[image_key]['caption'] = clean_caption(caption) + else: + metadata[image_key]['caption'] = clean_caption(caption) # metadataを書き出して終わり print(f"writing metadata: {args.out_json}") diff --git a/diffusers_fine_tuning/fine_tune.py b/diffusers_fine_tuning/fine_tune.py index e433fec..22bb2fd 100644 --- a/diffusers_fine_tuning/fine_tune.py +++ b/diffusers_fine_tuning/fine_tune.py @@ -1,5 +1,6 @@ # v2: select precision for saved checkpoint # v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset) +# v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model # このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします # License: @@ -44,12 +45,15 @@ import fine_tuning_utils # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う # checkpointファイル名 LAST_CHECKPOINT_NAME = "last.ckpt" LAST_STATE_NAME = "last-state" +LAST_DIFFUSERS_DIR_NAME = "last" EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt" EPOCH_STATE_NAME = "epoch-{:06d}-state" +EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" def collate_fn(examples): @@ -63,7 +67,7 @@ class FineTuningDataset(torch.utils.data.Dataset): self.metadata = metadata self.train_data_dir = train_data_dir self.batch_size = batch_size - self.tokenizer = tokenizer + self.tokenizer: CLIPTokenizer = tokenizer self.max_token_length = max_token_length self.shuffle_caption = shuffle_caption self.debug = debug @@ -159,17 +163,38 @@ class FineTuningDataset(torch.utils.data.Dataset): input_ids = self.tokenizer(caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に if self.tokenizer_max_length > self.tokenizer.model_max_length: input_ids = input_ids.squeeze(0) iids_list = [] - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - iid = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - iid = torch.cat(iid) - iids_list.append(iid) + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = (input_ids[0].unsqueeze(0), + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): + ids_chunk = (input_ids[0].unsqueeze(0), # BOS + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + input_ids = torch.stack(iids_list) # 3,77 input_ids_list.append(input_ids) @@ -192,15 +217,17 @@ def save_hypernetwork(output_file, hypernetwork): def train(args): fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training + # その他のオプション設定を確認する + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + # モデル形式のオプション設定を確認する + # v11からDiffUsersから直接落としてくるのもOK(ただし認証がいるやつは未対応)、またv11からDiffUsersも途中保存に対応した use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) - if not use_stable_diffusion_format: - assert os.path.exists( - args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}" - - assert not fine_tuning or ( - args.save_every_n_epochs is None or use_stable_diffusion_format), "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります" + # 乱数系列を初期化する if args.seed is not None: set_seed(args.seed) @@ -215,18 +242,22 @@ def train(args): # tokenizerを読み込む print("prepare tokenizer") - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + if args.max_token_length is not None: - print(f"update token length in tokenizer: {args.max_token_length}") + print(f"update token length: {args.max_token_length}") # datasetを用意する print("prepare dataset") train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size, tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset) + print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") + print(f"Total images / 画像数: {train_dataset.images_count}") if args.debug_dataset: - print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") - print(f"Total images / 画像数: {train_dataset.images_count}") train_dataset.show_buckets() i = 0 for example in train_dataset: @@ -251,14 +282,33 @@ def train(args): accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + # モデルを読み込む if use_stable_diffusion_format: print("load StableDiffusion checkpoint") - text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path) + text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint( + args.v2, args.pretrained_model_name_or_path) else: print("load Diffusers pretrained models") - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) + # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる + text_encoder = pipe.text_encoder + unet = pipe.unet + del pipe # モデルに xformers とか memory efficient attention を組み込む replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -279,21 +329,6 @@ def train(args): print("apply hypernetwork") hypernetwork.apply_to_diffusers(None, text_encoder, unet) - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 - # 学習を準備する:モデルを適切な状態にする training_models = [] if fine_tuning: @@ -351,7 +386,7 @@ def train(args): # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( - "constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) # acceleratorがなんかよろしくやってくれるらしい if fine_tuning: @@ -384,10 +419,14 @@ def train(args): print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps") + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) + # v4で更新:clip_sample=Falseに + # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる + # 既存の1.4/1.5/2.0はすべてschdulerのconfigは(クラス名を除いて)同じ + noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", + num_train_timesteps=1000, clip_sample=False) if accelerator.is_main_process: accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork") @@ -400,7 +439,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(training_models[0]): # ここはこれでいいのか……? + with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく latents = batch["latents"].to(accelerator.device) latents = latents * 0.18215 b_size = latents.shape[0] @@ -418,15 +457,29 @@ def train(args): encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + # bs*3, 77, 768 or 1024 encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) if args.max_token_length is not None: - # ... の三連を ... へ戻す - sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)] - for i in range(1, args.max_token_length, tokenizer.model_max_length): - sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) - sts_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) - encoder_hidden_states = torch.cat(sts_list, dim=1) + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -442,7 +495,41 @@ def train(args): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + if args.v_parameterization: + # v-parameterization training + + # 11/29現在v predictionのコードがDiffusersにcommitされたがリリースされていないので独自コードを使う + # 実装の中身は同じ模様 + + # こうしたい: + # target = noise_scheduler.get_v(latents, noise, timesteps) + + # StabilityAiのddpm.pyのコード: + # elif self.parameterization == "v": + # target = self.get_v(x_start, noise, t) + # ... + # def get_v(self, x, noise, t): + # return ( + # extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + # extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + # ) + + # scheduling_ddim.pyのコード: + # elif self.config.prediction_type == "v_prediction": + # pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # # predict V + # model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # これでいいかな?: + alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps] + beta_prod_t = 1 - alpha_prod_t + alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # broadcastされないらしいのでreshape + beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1)) + target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: @@ -460,7 +547,7 @@ def train(args): progress_bar.update(1) global_step += 1 - current_loss = loss.detach().item() * b_size + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} accelerator.log(logs, step=global_step) @@ -481,14 +568,20 @@ def train(args): if args.save_every_n_epochs is not None: if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving check point.") + print("saving checkpoint.") os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1)) if fine_tuning: - fine_tuning_utils.save_stable_diffusion_checkpoint( - ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet), - args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype) + if use_stable_diffusion_format: + fine_tuning_utils.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet), + args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype) + else: + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) + os.makedirs(out_dir, exist_ok=True) + fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype) else: save_hypernetwork(ckpt_file, accelerator.unwrap_model(hypernetwork)) @@ -519,16 +612,14 @@ def train(args): ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME) print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") fine_tuning_utils.save_stable_diffusion_checkpoint( - ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype) + args.v2, ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype) else: # Create the pipeline using using the trained modules and save it. print(f"save trained model as Diffusers to {args.output_dir}") - pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=unet, - text_encoder=text_encoder, - ) - pipeline.save_pretrained(args.output_dir) + out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) + os.makedirs(out_dir, exist_ok=True) + fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + args.pretrained_model_name_or_path, save_dtype) else: ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME) print(f"save trained model to {ckpt_file}") @@ -817,6 +908,10 @@ def replace_unet_cross_attn_to_xformers(): if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + parser.add_argument("--v_parameterization", action='store_true', + help='enable v-parameterization training / v-parameterization学習を有効にする') parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル") @@ -832,7 +927,7 @@ if __name__ == '__main__': parser.add_argument("--hypernetwork_weights", type=str, default=None, help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)') parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存する(StableDiffusion形式のモデルを読み込んだ場合のみ有効)") + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, @@ -857,13 +952,17 @@ if __name__ == '__main__': parser.add_argument("--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") + choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") parser.add_argument("--clip_skip", type=int, default=None, help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") parser.add_argument("--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") parser.add_argument("--logging_dir", type=str, default=None, help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") args = parser.parse_args() train(args) diff --git a/diffusers_fine_tuning/fine_tuning_utils.py b/diffusers_fine_tuning/fine_tuning_utils.py index a64b74c..f12cb90 100644 --- a/diffusers_fine_tuning/fine_tuning_utils.py +++ b/diffusers_fine_tuning/fine_tuning_utils.py @@ -2,11 +2,13 @@ import math import torch from transformers import CLIPTextModel from diffusers import AutoencoderKL, UNet2DConditionModel +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -# StableDiffusionのモデルパラメータ +# region checkpoint変換、読み込み、書き込み ############################### + +# DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 BETA_START = 0.00085 BETA_END = 0.0120 @@ -29,13 +31,15 @@ VAE_PARAMS_CH = 128 VAE_PARAMS_CH_MULT = [1, 2, 4, 4] VAE_PARAMS_NUM_RES_BLOCKS = 2 +# V2 +V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] +V2_UNET_PARAMS_CONTEXT_DIM = 1024 -# region conversion -# checkpoint変換など ############################### # region StableDiffusion->Diffusersの変換コード # convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0) + def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. @@ -199,7 +203,16 @@ def conv_attn_to_linear(checkpoint): checkpoint[key] = checkpoint[key][:, :, 0] -def convert_ldm_unet_checkpoint(checkpoint, config): +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + + +def convert_ldm_unet_checkpoint(v2, checkpoint, config): """ Takes a state dict and a config, and returns a converted checkpoint. """ @@ -349,6 +362,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config): new_checkpoint[new_path] = unet_state_dict[old_path] + # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する + if v2: + linear_transformer_to_conv(new_checkpoint) + return new_checkpoint @@ -459,7 +476,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config): return new_checkpoint -def create_unet_diffusers_config(): +def create_unet_diffusers_config(v2): """ Creates a config for the diffusers based on the config of the LDM model. """ @@ -489,8 +506,8 @@ def create_unet_diffusers_config(): up_block_types=tuple(up_block_types), block_out_channels=tuple(block_out_channels), layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, ) return config @@ -519,20 +536,82 @@ def create_vae_diffusers_config(): return config -def convert_ldm_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - +def convert_ldm_clip_checkpoint_v1(checkpoint): keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] + return text_model_dict - text_model.load_state_dict(text_model_dict) - return text_model +def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif '.attn.out_proj' in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif '.attn.in_proj' in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif '.positional_embedding' in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif '.text_projection' in key: + key = None # 使われない??? + elif '.logit_scale' in key: + key = None # 使われない??? + elif '.token_embedding' in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif '.ln_final' in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if '.resblocks.23.' in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if '.resblocks.23.' in key: + continue + if '.resblocks' in key and '.attn.in_proj_' in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # position_idsの追加 + new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) + return new_sd # endregion @@ -540,7 +619,16 @@ def convert_ldm_clip_checkpoint(checkpoint): # region Diffusers->StableDiffusion の変換コード # convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0) -def convert_unet_state_dict(unet_state_dict): +def conv_transformer_to_linear(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + + +def convert_unet_state_dict_to_sd(v2, unet_state_dict): unet_conversion_map = [ # (stable-diffusion, HF Diffusers) ("time_embed.0.weight", "time_embedding.linear_1.weight"), @@ -629,12 +717,16 @@ def convert_unet_state_dict(unet_state_dict): v = v.replace(hf_part, sd_part) mapping[k] = v new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + return new_state_dict # endregion -def load_checkpoint_with_conversion(ckpt_path): +def load_checkpoint_with_text_encoder_conversion(ckpt_path): # text encoderの格納形式が違うモデルに対応する ('text_model'がない) TEXT_ENCODER_KEY_REPLACEMENTS = [ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), @@ -659,52 +751,148 @@ def load_checkpoint_with_conversion(ckpt_path): return checkpoint -def load_models_from_stable_diffusion_checkpoint(ckpt_path): - checkpoint = load_checkpoint_with_conversion(ckpt_path) +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): + checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) state_dict = checkpoint["state_dict"] + if dtype is not None: + for k, v in state_dict.items(): + if type(v) is torch.Tensor: + state_dict[k] = v.to(dtype) # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config() - converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config) + unet_config = create_unet_diffusers_config(v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) unet = UNet2DConditionModel(**unet_config) - unet.load_state_dict(converted_unet_checkpoint) + info = unet.load_state_dict(converted_unet_checkpoint) + print("loading u-net:", info) # Convert the VAE model. vae_config = create_vae_diffusers_config() converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) + info = vae.load_state_dict(converted_vae_checkpoint) + print("loadint vae:", info) # convert text_model - text_model = convert_ldm_clip_checkpoint(state_dict) + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("loading text encoder:", info) return text_model, vae, unet -def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None): +def convert_text_encoder_state_dict_to_sd_v2(checkpoint): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif '.self_attn.out_proj' in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif '.self_attn.' in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif '.position_embedding' in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif '.token_embedding' in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif 'final_layer_norm' in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if 'layers' in key and 'q_proj' in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + return new_sd + + +def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None): # VAEがメモリ上にないので、もう一度VAEを含めて読み込む - checkpoint = load_checkpoint_with_conversion(ckpt_path) + checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) state_dict = checkpoint["state_dict"] + def assign_new_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + # Convert the UNet model - unet_state_dict = convert_unet_state_dict(unet.state_dict()) - for k, v in unet_state_dict.items(): - key = "model.diffusion_model." + k - assert key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + assign_new_sd("model.diffusion_model.", unet_state_dict) # Convert the text encoder model - text_enc_dict = text_encoder.state_dict() # 変換不要 - for k, v in text_enc_dict.items(): - key = "cond_stage_model.transformer." + k - assert key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + if v2: + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict()) + assign_new_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + assign_new_sd("cond_stage_model.transformer.", text_enc_dict) # Put together new checkpoint new_ckpt = {'state_dict': state_dict} @@ -718,6 +906,21 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, new_ckpt['global_step'] = steps torch.save(new_ckpt, output_file) + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype): + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"), + scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"), + tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"), + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir) + # endregion diff --git a/diffusers_fine_tuning/merge_captions_to_metadata.py b/diffusers_fine_tuning/merge_captions_to_metadata.py index a50d2bd..fc8d124 100644 --- a/diffusers_fine_tuning/merge_captions_to_metadata.py +++ b/diffusers_fine_tuning/merge_captions_to_metadata.py @@ -10,7 +10,7 @@ from tqdm import tqdm def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") if args.in_json is not None: diff --git a/diffusers_fine_tuning/merge_dd_tags_to_metadata.py b/diffusers_fine_tuning/merge_dd_tags_to_metadata.py index 6436e6a..84b92a0 100644 --- a/diffusers_fine_tuning/merge_dd_tags_to_metadata.py +++ b/diffusers_fine_tuning/merge_dd_tags_to_metadata.py @@ -10,7 +10,7 @@ from tqdm import tqdm def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") if args.in_json is not None: diff --git a/diffusers_fine_tuning/prepare_buckets_latents.py b/diffusers_fine_tuning/prepare_buckets_latents.py index 864205a..5420fed 100644 --- a/diffusers_fine_tuning/prepare_buckets_latents.py +++ b/diffusers_fine_tuning/prepare_buckets_latents.py @@ -36,7 +36,7 @@ def get_latents(vae, images, weight_dtype): def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): @@ -49,13 +49,11 @@ def main(args): # モデル形式のオプション設定を確認する use_stable_diffusion_format = os.path.isfile(args.model_name_or_path) - if not use_stable_diffusion_format: - assert os.path.exists(args.model_name_or_path), f"no model / モデルがありません : {args.model_name_or_path}" # モデルを読み込む if use_stable_diffusion_format: print("load StableDiffusion checkpoint") - _, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.model_name_or_path) + _, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_name_or_path) else: print("load Diffusers pretrained models") vae = AutoencoderKL.from_pretrained(args.model_name_or_path, subfolder="vae") @@ -73,7 +71,8 @@ def main(args): max_reso = tuple([int(t) for t in args.max_resolution.split(',')]) assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" - bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(max_reso) + bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions( + max_reso, args.min_bucket_reso, args.max_bucket_reso) # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する bucket_aspect_ratios = np.array(bucket_aspect_ratios) @@ -162,9 +161,13 @@ if __name__ == '__main__': parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--max_resolution", type=str, default="512,512", help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") parser.add_argument("--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")