diff --git a/diffusers_fine_tuning/README.md b/diffusers_fine_tuning/README.md
index 993537e..20b95f9 100644
--- a/diffusers_fine_tuning/README.md
+++ b/diffusers_fine_tuning/README.md
@@ -1,3 +1,10 @@
 # Diffusers Fine Tuning
-This subfolder provide all the required toold to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
+This subfolder provides all the required tools to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
+
+## Releases
+
+11/23 (v3):
+- Added the WD14Tagger tagging script.
+- Added a log output function to fine_tune.py and fixed the double shuffling of data.
+- Fixed a misspelled option name in each script (caption_extention → caption_extension; the old spelling keeps working for the time being).
diff --git a/diffusers_fine_tuning/clean_captions_and_tags.py b/diffusers_fine_tuning/clean_captions_and_tags.py
index 91503a7..edf557a 100644
--- a/diffusers_fine_tuning/clean_captions_and_tags.py
+++ b/diffusers_fine_tuning/clean_captions_and_tags.py
@@ -13,11 +13,13 @@ def clean_tags(image_key, tags):
   # replace '_' to ' '
   tags = tags.replace('_', ' ')
 
-  # remove rating
+  # remove rating: DeepDanbooru only
   tokens = tags.split(", rating")
   if len(tokens) == 1:
-    print("no rating:")
-    print(f"{image_key} {tags}")
+    # this is the normal case for WD14 tagger output, so don't print a message
+    # print("no rating:")
+    # print(f"{image_key} {tags}")
+    pass
   else:
     if len(tokens) > 2:
       print("multiple ratings:")
diff --git a/diffusers_fine_tuning/fine_tune.py b/diffusers_fine_tuning/fine_tune.py
index 1cc35f4..e433fec 100644
--- a/diffusers_fine_tuning/fine_tune.py
+++ b/diffusers_fine_tuning/fine_tune.py
@@ -1,5 +1,5 @@
 # v2: select precision for saved checkpoint
-
+# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
 
 # this script is licensed under Apache License 2.0, same as train_dreambooth.py
 # License:
@@ -22,12 +22,12 @@
 import argparse
-import itertools
 import math
 import os
 import random
 import json
 import importlib
+import time
 
 from tqdm import tqdm
 import torch
@@ -159,7 +159,7 @@ class FineTuningDataset(torch.utils.data.Dataset):
     input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
                                max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
 
     # if the caption exceeds the 77-token limit, the encoding comes back as e.g. 227 tokens in total, so convert it into three chained 77-token runs
     # 1111's implementation seems to split on "," and the like, but keep it simple for now
     if self.tokenizer_max_length > self.tokenizer.model_max_length:
       input_ids = input_ids.squeeze(0)
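To make the comment in the hunk above concrete: a 227-token encoding (begin token + 225 content tokens + end token) is re-packed into three 77-token sequences, each re-wrapped with its own begin/end tokens. This is a minimal sketch of that idea only, assuming a plain 1-D tensor of token ids; the actual dataset code differs in details:

```python
import torch

def split_long_ids(input_ids: torch.Tensor, model_max_length: int = 77) -> list:
  # input_ids: 1-D tensor "<begin> t_1 ... t_225 <end>" (length 227 in the example above)
  bos, eos = input_ids[0:1], input_ids[-1:]
  chunks = []
  for i in range(1, input_ids.shape[0] - 1, model_max_length - 2):
    body = input_ids[i:i + model_max_length - 2]   # 75 content tokens
    chunks.append(torch.cat([bos, body, eos]))     # re-wrap to 77 tokens
  return chunks

print([c.shape for c in split_long_ids(torch.arange(227))])  # three tensors of length 77
```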
" でトータル227とかになっているので、"..."の三連に変換する # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に if self.tokenizer_max_length > self.tokenizer.model_max_length: input_ids = input_ids.squeeze(0) @@ -242,7 +242,14 @@ def train(args): # acceleratorを準備する print("prepare accelerator") - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision) + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime()) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir) # モデルを読み込む if use_stable_diffusion_format: @@ -304,7 +311,7 @@ def train(args): text_encoder.requires_grad_(False) # text encoderは学習しない text_encoder.eval() else: - unet.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない unet.requires_grad_(False) unet.eval() text_encoder.to(accelerator.device, dtype=weight_dtype) @@ -314,7 +321,10 @@ def train(args): for m in training_models: m.requires_grad_(True) - params_to_optimize = itertools.chain(*[m.parameters() for m in training_models]) + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") @@ -337,11 +347,11 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( - args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps) + "constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) # acceleratorがなんかよろしくやってくれるらしい if fine_tuning: @@ -390,7 +400,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(training_models[0]): + with accelerator.accumulate(training_models[0]): # ここはこれでいいのか……? latents = batch["latents"].to(accelerator.device) latents = latents * 0.18215 b_size = latents.shape[0] @@ -411,7 +421,7 @@ def train(args): encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) if args.max_token_length is not None: - # ... の三連を ... へ戻す + # ... の三連を ... 
@@ -337,11 +347,11 @@
   # number of DataLoader worker processes (0 would mean the main process)
   n_workers = min(8, os.cpu_count() - 1)      # cpu_count - 1, capped at 8
   train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
 
   # prepare the lr scheduler
   lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps)
+      "constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
   # the accelerator apparently takes care of the rest for us
   if fine_tuning:
@@ -390,7 +400,7 @@
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
-      with accelerator.accumulate(training_models[0]):
+      with accelerator.accumulate(training_models[0]):  # is this really the right thing to do here...?
         latents = batch["latents"].to(accelerator.device)
         latents = latents * 0.18215
         b_size = latents.shape[0]
@@ -411,7 +421,7 @@
         encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
 
         if args.max_token_length is not None:
           # join the three chained runs back into a single sequence
           sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
           for i in range(1, args.max_token_length, tokenizer.model_max_length):
             sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
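This loop mirrors the dataset-side split: the text encoder is run on each 77-token window, and the windows' hidden states are stitched back into one sequence by keeping the very first begin-token state, each window's 75 content states, and (presumably, in the code the hunk cuts off before) the final end-token state. A minimal sketch under the same assumed sizes as before (227 tokens, 77-token windows); the final end-token append is an assumption, and the training code differs in details:

```python
import torch

def reassemble(hidden_states: torch.Tensor, model_max_length: int = 77,
               max_token_length: int = 227) -> torch.Tensor:
  # hidden_states: (batch, 3 * model_max_length, dim), from three encoder windows
  chunks = [hidden_states[:, 0].unsqueeze(1)]                    # first begin-token state
  for i in range(1, max_token_length, model_max_length):
    chunks.append(hidden_states[:, i:i + model_max_length - 2])  # 75 content states per window
  chunks.append(hidden_states[:, -1].unsqueeze(1))               # last end-token state (assumed)
  return torch.cat(chunks, dim=1)                                # (batch, 227, dim)

print(reassemble(torch.zeros(2, 231, 768)).shape)  # torch.Size([2, 227, 768])
```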
action="store_true", help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") @@ -87,4 +89,9 @@ if __name__ == '__main__': parser.add_argument("--debug", action="store_true", help="debug mode") args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + main(args) diff --git a/diffusers_fine_tuning/merge_captions_to_metadata.py b/diffusers_fine_tuning/merge_captions_to_metadata.py index cfc97e1..a50d2bd 100644 --- a/diffusers_fine_tuning/merge_captions_to_metadata.py +++ b/diffusers_fine_tuning/merge_captions_to_metadata.py @@ -24,7 +24,7 @@ def main(args): print("merge caption texts to metadata json.") for image_path in tqdm(image_paths): - caption_path = os.path.splitext(image_path)[0] + args.caption_extention + caption_path = os.path.splitext(image_path)[0] + args.caption_extension with open(caption_path, "rt", encoding='utf-8') as f: caption = f.readlines()[0].strip() @@ -54,8 +54,15 @@ if __name__ == '__main__': parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") - parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption file / 読み込むキャプションファイルの拡張子") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") parser.add_argument("--debug", action="store_true", help="debug mode") args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + main(args) diff --git a/diffusers_fine_tuning/tag_images_by_wd14_tagger.py b/diffusers_fine_tuning/tag_images_by_wd14_tagger.py new file mode 100644 index 0000000..66d3a34 --- /dev/null +++ b/diffusers_fine_tuning/tag_images_by_wd14_tagger.py @@ -0,0 +1,107 @@ +# このスクリプトのライセンスは、Apache License 2.0とします +# (c) 2022 Kohya S. 
diff --git a/diffusers_fine_tuning/tag_images_by_wd14_tagger.py b/diffusers_fine_tuning/tag_images_by_wd14_tagger.py
new file mode 100644
index 0000000..66d3a34
--- /dev/null
+++ b/diffusers_fine_tuning/tag_images_by_wd14_tagger.py
@@ -0,0 +1,107 @@
+# this script is licensed under Apache License 2.0
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import csv
+import glob
+import os
+import json
+
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+from tensorflow.keras.models import load_model
+from Utils import dbimutils
+
+
+# from wd14 tagger
+IMAGE_SIZE = 448
+
+
+def main(args):
+  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
+      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  print(f"found {len(image_paths)} images.")
+
+  print("loading model and labels")
+  model = load_model(args.model)
+
+  # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
+  # read the csv by hand so we don't have to add pandas as a dependency
+  with open(args.tag_csv, "r", encoding="utf-8") as f:
+    reader = csv.reader(f)
+    l = [row for row in reader]
+    header = l[0]             # tag_id,name,category,count
+    rows = l[1:]
+  assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
+
+  tags = [row[1] for row in rows[1:] if row[2] == '0']      # category 0 means a general tag
+
+  # run inference
+  def run_batch(path_imgs):
+    imgs = np.array([im for _, im in path_imgs])
+
+    probs = model(imgs, training=False)
+    probs = probs.numpy()
+
+    for (image_path, _), prob in zip(path_imgs, probs):
+      # the first 4 entries are ratings, so skip them
+      # # First 4 labels are actually ratings: pick one with argmax
+      # ratings_names = label_names[:4]
+      # rating_index = ratings_names["probs"].argmax()
+      # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
+
+      # everything after that is a tag, so keep any whose confidence clears the threshold
+      # Everything else is tags: pick any where prediction confidence > threshold
+      tag_text = ""
+      for i, p in enumerate(prob[4:]):    # a vectorized numpy version would be nicer, but there aren't many tags, so a loop will do
+        if p >= args.thresh:
+          tag_text += ", " + tags[i]
+
+      if len(tag_text) > 0:
+        tag_text = tag_text[2:]           # drop the leading ", "
+
+      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
+        f.write(tag_text + '\n')
+        if args.debug:
+          print(image_path, tag_text)
+
+  b_imgs = []
+  for image_path in tqdm(image_paths):
+    img = dbimutils.smart_imread(image_path)
+    img = dbimutils.smart_24bit(img)
+    img = dbimutils.make_square(img, IMAGE_SIZE)
+    img = dbimutils.smart_resize(img, IMAGE_SIZE)
+    img = img.astype(np.float32)
+    b_imgs.append((image_path, img))
+
+    if len(b_imgs) >= args.batch_size:
+      run_batch(b_imgs)
+      b_imgs.clear()
+
+  if len(b_imgs) > 0:
+    run_batch(b_imgs)
+
+  print("done!")
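The selection rule in `run_batch` reduces to: skip the first four probabilities (the rating slots), keep every remaining tag whose confidence is at least `--thresh`, and comma-join the survivors. The same rule in isolation, with toy numbers:

```python
import numpy as np

tags = ["1girl", "solo", "long hair"]                      # toy tag list (aligned after the 4 rating slots)
prob = np.array([0.9, 0.2, 0.8, 0.1, 0.99, 0.30, 0.72])   # 4 rating probs + 3 tag probs
thresh = 0.35
picked = [t for t, p in zip(tags, prob[4:]) if p >= thresh]
print(", ".join(picked))                                   # 1girl, long hair
```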
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+  parser.add_argument("--model", type=str, default="networks/ViTB16_11_03_2022_07h05m53s",
+                      help="model path to load / 読み込むモデルファイル")
+  parser.add_argument("--tag_csv", type=str, default="2022_0000_0899_6549/selected_tags.csv",
+                      help="csv file for tags / タグ一覧のCSVファイル")
+  parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
+  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+  parser.add_argument("--caption_extention", type=str, default=None,
+                      help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
+  parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
+  parser.add_argument("--debug", action="store_true", help="debug mode")
+
+  args = parser.parse_args()
+
+  # accept the old, misspelled option name for backward compatibility
+  if args.caption_extention is not None:
+    args.caption_extension = args.caption_extention
+
+  main(args)
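All three scripts end with the same backward-compatibility idiom: both spellings are registered, and the old misspelled option, when given, overrides the corrected one. The idiom in isolation, runnable as-is:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--caption_extention", type=str, default=None)    # old, misspelled name
parser.add_argument("--caption_extension", type=str, default=".txt")  # corrected name

args = parser.parse_args(["--caption_extention", ".caption"])
if args.caption_extention is not None:
  args.caption_extension = args.caption_extention
print(args.caption_extension)   # .caption -- the legacy spelling still takes effect
```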