Update diffusers fine tuning code

This commit is contained in:
bmaltais 2022-11-25 21:30:18 -05:00
parent dd241f2142
commit 91eff4da17
6 changed files with 172 additions and 22 deletions

View File

@@ -1,3 +1,10 @@
# Diffusers Fine Tuning
This subfolder provide all the required toold to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
This subfolder provides all the required tools to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
## Releases
11/23 (v3):
- Added WD14Tagger tagging script.
- Added log output to fine_tune.py; also fixed the double shuffling of data.
- Fixed misspelled option names in each script (caption_extention→caption_extension; the old spelling keeps working for now).

View File

@@ -13,11 +13,13 @@ def clean_tags(image_key, tags):
# replace '_' with ' '
tags = tags.replace('_', ' ')
# remove rating
# remove rating: DeepDanbooru only
tokens = tags.split(", rating")
if len(tokens) == 1:
print("no rating:")
print(f"{image_key} {tags}")
# WD14 tagger output takes this branch, so don't print a message
# print("no rating:")
# print(f"{image_key} {tags}")
pass
else:
if len(tokens) > 2:
print("multiple ratings:")

View File

@@ -1,5 +1,5 @@
# v2: select precision for saved checkpoint
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
# The license of this script is Apache License 2.0, the same as train_dreambooth.py
# License:
@@ -22,12 +22,12 @@
import argparse
import itertools
import math
import os
import random
import json
import importlib
import time
from tqdm import tqdm
import torch
@@ -159,7 +159,7 @@ class FineTuningDataset(torch.utils.data.Dataset):
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
# with 77 tokens or more it becomes "<CLS> .... <EOS> <EOS> <EOS>" with a total length like 227, so convert it into three runs of "<CLS>...<EOS>"
# with 77 tokens or more it becomes "<BOS> .... <EOS> <EOS> <EOS>" with a total length like 227, so convert it into three runs of "<BOS>...<EOS>"
# 1111's version seems to split on ',' and the like, but keep it simple for now
if self.tokenizer_max_length > self.tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
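A minimal sketch of the chunking the comment above describes (the function and variable names are illustrative, not the script's): a caption tokenized to 227 ids, i.e. <BOS> + 3×75 tokens + <EOS>, is rewrapped into three 77-token windows, each given its own <BOS>/<EOS> pair.

```python
import torch

def chunk_input_ids(input_ids: torch.Tensor, model_max_length: int = 77) -> torch.Tensor:
    # input_ids: 1-D tensor "<BOS> t1 ... t225 <EOS>" of length 2 + 3*75 = 227
    bos, eos = input_ids[0], input_ids[-1]
    body = input_ids[1:-1]                 # the 225 content tokens
    window = model_max_length - 2          # 75 content tokens per chunk
    chunks = []
    for i in range(0, len(body), window):
        part = body[i:i + window]
        chunks.append(torch.cat([bos.unsqueeze(0), part, eos.unsqueeze(0)]))
    return torch.stack(chunks)             # shape (3, 77)

ids = torch.arange(227)                    # stand-in for the tokenizer output
print(chunk_input_ids(ids).shape)          # torch.Size([3, 77])
```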
@@ -242,7 +242,14 @@ def train(args):
# prepare the accelerator
print("prepare accelerator")
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
# load the model
if use_stable_diffusion_format:
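For context, a standalone sketch of how the logging wired up above is consumed (the run name "finetune" and the loop values are illustrative; the diff does not show the init_trackers call, which accelerate needs before accelerator.log has any effect):

```python
import time
from accelerate import Accelerator

# log_with/logging_dir as in the diff; a timestamped run directory keeps runs apart
logging_dir = "logs/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(log_with="tensorboard", logging_dir=logging_dir)
accelerator.init_trackers("finetune")      # illustrative run name; required before log()

for global_step in range(1, 11):
    loss = 1.0 / global_step               # stand-in for the training loss
    accelerator.log({"loss": loss}, step=global_step)

accelerator.end_training()                 # flush and close the trackers
```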
@@ -304,7 +311,7 @@
text_encoder.requires_grad_(False) # the text encoder is not trained
text_encoder.eval()
else:
unet.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device) # , dtype=weight_dtype) # training fails if dtype is specified
unet.requires_grad_(False)
unet.eval()
text_encoder.to(accelerator.device, dtype=weight_dtype)
@@ -314,7 +321,10 @@
for m in training_models:
m.requires_grad_(True)
params_to_optimize = itertools.chain(*[m.parameters() for m in training_models])
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
# prepare the classes needed for training
print("prepare optimizer, data loader etc.")
@@ -337,11 +347,11 @@
# with the DataLoader, num_workers=0 means the main process does the loading
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1, capped at 8
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# prepare the lr scheduler
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps)
"constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
# the accelerator apparently takes care of the rest for us
if fine_tuning:
@@ -390,7 +400,7 @@ def train(args):
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(training_models[0]):
with accelerator.accumulate(training_models[0]): # is this really right here...?
latents = batch["latents"].to(accelerator.device)
latents = latents * 0.18215
b_size = latents.shape[0]
@@ -411,7 +421,7 @@
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if args.max_token_length is not None:
# convert the three runs of <CLS>...<EOS> back into a single <CLS>...<EOS>
# convert the three runs of <BOS>...<EOS> back into a single <BOS>...<EOS>
sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
for i in range(1, args.max_token_length, tokenizer.model_max_length):
sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
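A minimal sketch of this inverse operation (batch size, hidden dim, and variable names are illustrative): hidden states for three <BOS>...<EOS> windows of 77 are stitched back into one sequence with a single <BOS> and <EOS>, dropping the per-window boundary tokens.

```python
import torch

b, n_chunks, max_len, dim = 2, 3, 77, 768
states = torch.randn(b, n_chunks * max_len, dim)        # (2, 231, 768), already reshaped

pieces = [states[:, 0].unsqueeze(1)]                    # keep the first <BOS>
for i in range(1, states.shape[1], max_len):
    pieces.append(states[:, i:i + max_len - 2])         # window body, minus its <BOS>/<EOS>
pieces.append(states[:, -1].unsqueeze(1))               # keep the final <EOS>
states = torch.cat(pieces, dim=1)

print(states.shape)                                     # torch.Size([2, 227, 768])
```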
@@ -436,7 +446,9 @@
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = itertools.chain(*[m.parameters() for m in training_models])
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
optimizer.step()
@@ -449,15 +461,22 @@
global_step += 1
current_loss = loss.detach().item() * b_size
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
# accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
@@ -843,7 +862,8 @@ if __name__ == '__main__':
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
parser.add_argument("--debug_dataset", action="store_true",
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
args = parser.parse_args()
train(args)

View File

@@ -48,7 +48,7 @@ def main(args):
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extention, "wt", encoding='utf-8') as f:
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
@@ -76,7 +76,9 @@ if __name__ == '__main__':
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("caption_weights", type=str,
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
@@ -87,4 +89,9 @@ if __name__ == '__main__':
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()
# restore the value from the misspelled option
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
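The backward-compatibility pattern used here, as a self-contained sketch (the flag names match the script; the parse_args argument list is illustrative): the misspelled legacy flag defaults to None, and a value passed under the old name is copied onto the correct option before main runs.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--caption_extention", type=str, default=None,
                    help="misspelled legacy option, kept for compatibility")
parser.add_argument("--caption_extension", type=str, default=".caption",
                    help="correctly spelled option holding the real default")

args = parser.parse_args(["--caption_extention", ".txt"])   # user still passes the old name

# restore the misspelled option: a value given under the old name wins
if args.caption_extention is not None:
    args.caption_extension = args.caption_extention

print(args.caption_extension)   # .txt -> the legacy flag still takes effect
```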

View File

@@ -24,7 +24,7 @@ def main(args):
print("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = os.path.splitext(image_path)[0] + args.caption_extention
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
with open(caption_path, "rt", encoding='utf-8') as f:
caption = f.readlines()[0].strip()
@@ -54,8 +54,15 @@ if __name__ == '__main__':
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()
# restore the value from the misspelled option
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)

View File

@@ -0,0 +1,107 @@
# The license of this script is Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss
import argparse
import csv
import glob
import os
import json
from PIL import Image
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from Utils import dbimutils
# from wd14 tagger
IMAGE_SIZE = 448
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")
print("loading model and labels")
model = load_model(args.model)
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# read it ourselves because we don't want to add more dependencies
with open(args.tag_csv, "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
tags = [row[1] for row in rows[1:] if row[2] == '0'] # category 0 means ordinary tags only
# run inference
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# the first four are ratings, so ignore them
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
# Everything else is tags: pick any where prediction confidence > threshold
tag_text = ""
for i, p in enumerate(prob[4:]): # numpy would be smarter, but the count is small, so just loop
if p >= args.thresh:
tag_text += ", " + tags[i]
if len(tag_text) > 0:
tag_text = tag_text[2:] # strip the leading ", "
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(tag_text + '\n')
if args.debug:
print(image_path, tag_text)
b_imgs = []
for image_path in tqdm(image_paths):
img = dbimutils.smart_imread(image_path)
img = dbimutils.smart_24bit(img)
img = dbimutils.make_square(img, IMAGE_SIZE)
img = dbimutils.smart_resize(img, IMAGE_SIZE)
img = img.astype(np.float32)
b_imgs.append((image_path, img))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--model", type=str, default="networks/ViTB16_11_03_2022_07h05m53s",
help="model path to load / 読み込むモデルファイル")
parser.add_argument("--tag_csv", type=str, default="2022_0000_0899_6549/selected_tags.csv",
help="csv file for tags / タグ一覧のCSVファイル")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()
# restore the value from the misspelled option
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
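A minimal sketch of the thresholding logic in run_batch, with made-up numbers (the tag list and probabilities are illustrative): the first four model outputs are rating classes and are skipped, and each later output keeps its tag when the confidence reaches the threshold.

```python
import numpy as np

tags = ["1girl", "solo", "smile"]            # illustrative tag list, in csv order
prob = np.array([0.9, 0.05, 0.03, 0.02,      # four rating outputs, ignored
                 0.97, 0.80, 0.10])          # tag outputs aligned with `tags`
thresh = 0.35

kept = [t for t, p in zip(tags, prob[4:]) if p >= thresh]
print(", ".join(kept))                       # 1girl, solo
```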