Update diffusers fine tuning code

bmaltais 2022-11-25 21:30:18 -05:00
parent dd241f2142
commit 91eff4da17
6 changed files with 172 additions and 22 deletions

View File

@@ -1,3 +1,10 @@
 # Diffusers Fine Tuning
-This subfolder provide all the required toold to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
+This subfolder provides all the required tools to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
+
+## Releases
+
+11/23 (v3):
+- Added the WD14Tagger tagging script.
+- Added a log output function to fine_tune.py; also fixed the double shuffling of data.
+- Fixed a misspelled option in each script (the old spelling caption_extention keeps working for backward compatibility).
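The misspelling fix keeps old command lines working by registering both spellings and copying the old one over the new; a minimal sketch of the pattern the scripts in this commit use:

import argparse

parser = argparse.ArgumentParser()
# old, misspelled option kept so existing invocations don't break
parser.add_argument("--caption_extention", type=str, default=None)
# correctly spelled option carrying the real default
parser.add_argument("--caption_extension", type=str, default=".caption")
args = parser.parse_args()

# if the old spelling was supplied, it takes effect
if args.caption_extention is not None:
  args.caption_extension = args.caption_extention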

View File

@@ -13,11 +13,13 @@ def clean_tags(image_key, tags):
   # replace '_' with ' '
   tags = tags.replace('_', ' ')

-  # remove rating
+  # remove rating: deepdanbooru only
   tokens = tags.split(", rating")
   if len(tokens) == 1:
-    print("no rating:")
-    print(f"{image_key} {tags}")
+    # this is the WD14 tagger case, so don't print a message
+    # print("no rating:")
+    # print(f"{image_key} {tags}")
+    pass
   else:
     if len(tokens) > 2:
       print("multiple ratings:")

View File

@@ -1,5 +1,5 @@
 # v2: select precision for saved checkpoint
+# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is done in the dataset)

 # This script is licensed under Apache License 2.0, the same as train_dreambooth.py
 # License:
@@ -22,12 +22,12 @@
 import argparse
-import itertools
 import math
 import os
 import random
 import json
 import importlib
+import time

 from tqdm import tqdm
 import torch
@@ -159,7 +159,7 @@ class FineTuningDataset(torch.utils.data.Dataset):
     input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
                                max_length=self.tokenizer_max_length, return_tensors="pt").input_ids

-    # at 77+ tokens this becomes "<CLS> .... <EOS> <EOS> <EOS>" (around 227 in total), so convert it to three runs of "<CLS>...<EOS>"
+    # at 77+ tokens this becomes "<BOS> .... <EOS> <EOS> <EOS>" (around 227 in total), so convert it to three runs of "<BOS>...<EOS>"
     # AUTOMATIC1111's version seems to split on ',' and such, but keep it simple here for now
     if self.tokenizer_max_length > self.tokenizer.model_max_length:
       input_ids = input_ids.squeeze(0)
@@ -242,7 +242,14 @@
   # prepare the accelerator
   print("prepare accelerator")
-  accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
+  if args.logging_dir is None:
+    log_with = None
+    logging_dir = None
+  else:
+    log_with = "tensorboard"
+    logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
+  accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
+                            mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)

   # load the model
   if use_stable_diffusion_format:
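For reference, accelerate's tracking API expects init_trackers to be called once before accelerator.log is used; fine_tune.py presumably does the equivalent outside the hunks shown here. A minimal, self-contained sketch of the pattern, assuming an accelerate release from this era (newer versions renamed logging_dir to project_dir), with a made-up project name and values:

from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", logging_dir="logs/20221125213018")
accelerator.init_trackers("finetune")        # creates the TensorBoard writer

for step in range(3):                        # stand-in for the training loop
  accelerator.log({"loss": 0.1 * (3 - step)}, step=step)

accelerator.end_training()                   # flush and close the trackers

The resulting event files can then be inspected with tensorboard --logdir pointed at the parent log directory.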
@@ -304,7 +311,7 @@
     text_encoder.requires_grad_(False)  # the text encoder is not trained
     text_encoder.eval()
   else:
-    unet.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device)  # , dtype=weight_dtype)  # training fails if a dtype is specified here
     unet.requires_grad_(False)
     unet.eval()
     text_encoder.to(accelerator.device, dtype=weight_dtype)
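The inline note ("training fails if a dtype is specified") is terse; one plausible reading, offered as an assumption rather than taken from the commit, is the usual mixed-precision pitfall: weights and gradients stored in float16 cannot absorb small updates, because the update rounds away. A toy demonstration:

import torch

w = torch.nn.Parameter(torch.ones(1, dtype=torch.float16))
w.data += torch.tensor([1e-4], dtype=torch.float16)
print(w)  # tensor([1.], dtype=torch.float16) -- the 1e-4 step rounded away
# fp16 spacing around 1.0 is 2**-10 (~0.001), so updates below ~0.0005 vanish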
@@ -314,7 +321,10 @@
   for m in training_models:
     m.requires_grad_(True)
-  params_to_optimize = itertools.chain(*[m.parameters() for m in training_models])
+  params = []
+  for m in training_models:
+    params.extend(m.parameters())
+  params_to_optimize = params

   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")
@@ -337,11 +347,11 @@
   # with num_workers=0 the DataLoader runs in the main process
   n_workers = min(8, os.cpu_count() - 1)  # cpu_count - 1, capped at 8
   train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

   # prepare the lr scheduler
   lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps)
+      "constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)

   # accelerator supposedly takes care of the rest for us
   if fine_tuning:
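The shuffle=False change pairs with the "double shuffling" fix from the release notes: the dataset already randomises its bucketed batches itself, so letting the DataLoader shuffle as well would reorder them a second time. A minimal sketch of the division of labour, with a made-up stand-in for FineTuningDataset (the shuffle_buckets name is hypothetical):

import random
import torch

class BucketedDataset(torch.utils.data.Dataset):
  # stand-in: the dataset owns the batch order and reshuffles it per epoch
  def __init__(self, items):
    self.items = list(items)

  def shuffle_buckets(self):
    random.shuffle(self.items)

  def __len__(self):
    return len(self.items)

  def __getitem__(self, idx):
    return self.items[idx]

dataset = BucketedDataset(range(10))
dataset.shuffle_buckets()  # epoch-level shuffling happens here...
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)  # ...not here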
@@ -390,7 +400,7 @@
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
-      with accelerator.accumulate(training_models[0]):
+      with accelerator.accumulate(training_models[0]):  # is this really right here...?
        latents = batch["latents"].to(accelerator.device)
        latents = latents * 0.18215
        b_size = latents.shape[0]
@@ -411,7 +421,7 @@
        encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

        if args.max_token_length is not None:
-         # stitch the three runs of <CLS>...<EOS> back into a single <CLS>...<EOS>
+         # stitch the three runs of <BOS>...<EOS> back into a single <BOS>...<EOS>
          sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
          for i in range(1, args.max_token_length, tokenizer.model_max_length):
            sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
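As a worked example of the index arithmetic, assume max_token_length=225 and tokenizer.model_max_length=77: the three encoded chunks occupy positions 0-76, 77-153 and 154-230 of the reshaped tensor, the loop visits i = 1, 78, 155, and each slice keeps the 75 body positions of one chunk. The final <EOS> append is assumed here, since the hunk ends before it:

import torch

b_size, dim = 2, 768
h = torch.randn(b_size, 231, dim)        # 3 chunks x 77 positions

sts_list = [h[:, 0].unsqueeze(1)]        # the very first <BOS>
for i in range(1, 225, 77):              # i = 1, 78, 155
  sts_list.append(h[:, i:i + 75])        # chunk body without its <BOS>/<EOS>
sts_list.append(h[:, -1].unsqueeze(1))   # closing <EOS> (assumed, not shown in the hunk)
h = torch.cat(sts_list, dim=1)
print(h.shape)                           # torch.Size([2, 227, 768]) -- the "227" mentioned above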
@@ -436,7 +446,9 @@
        accelerator.backward(loss)
        if accelerator.sync_gradients:
-         params_to_clip = itertools.chain(*[m.parameters() for m in training_models])
+         params_to_clip = []
+         for m in training_models:
+           params_to_clip.extend(m.parameters())
          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)

        optimizer.step()
@@ -449,15 +461,22 @@
        global_step += 1

      current_loss = loss.detach().item() * b_size
+      if args.logging_dir is not None:
+        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        accelerator.log(logs, step=global_step)
+
      loss_total += current_loss
      avr_loss = loss_total / (step+1)
      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
      progress_bar.set_postfix(**logs)
-      # accelerator.log(logs, step=global_step)

      if global_step >= args.max_train_steps:
        break

+    if args.logging_dir is not None:
+      logs = {"epoch_loss": loss_total / len(train_dataloader)}
+      accelerator.log(logs, step=epoch+1)
+
     accelerator.wait_for_everyone()

     if args.save_every_n_epochs is not None:
@@ -843,7 +862,8 @@ if __name__ == '__main__':
                      help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
  parser.add_argument("--debug_dataset", action="store_true",
                      help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
-  parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup")
-  parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")
+  parser.add_argument("--logging_dir", type=str, default=None,
+                      help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")

  args = parser.parse_args()
  train(args)

View File

@@ -48,7 +48,7 @@ def main(args):
    captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)

    for (image_path, _), caption in zip(path_imgs, captions):
-      with open(os.path.splitext(image_path)[0] + args.caption_extention, "wt", encoding='utf-8') as f:
+      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
        f.write(caption + "\n")
        if args.debug:
          print(image_path, caption)
@@ -76,7 +76,9 @@ if __name__ == '__main__':
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("caption_weights", type=str,
                      help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
-  parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption file / 出力されるキャプションファイルの拡張子")
+  parser.add_argument("--caption_extention", type=str, default=None,
+                      help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
+  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
  parser.add_argument("--beam_search", action="store_true",
                      help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
@@ -87,4 +89,9 @@
  parser.add_argument("--debug", action="store_true", help="debug mode")

  args = parser.parse_args()
+
+  # restore the misspelled option for backward compatibility
+  if args.caption_extention is not None:
+    args.caption_extension = args.caption_extention
+
  main(args)

View File

@@ -24,7 +24,7 @@ def main(args):
  print("merge caption texts to metadata json.")
  for image_path in tqdm(image_paths):
-    caption_path = os.path.splitext(image_path)[0] + args.caption_extention
+    caption_path = os.path.splitext(image_path)[0] + args.caption_extension
    with open(caption_path, "rt", encoding='utf-8') as f:
      caption = f.readlines()[0].strip()
@@ -54,8 +54,15 @@ if __name__ == '__main__':
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
  parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
-  parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption file / 読み込むキャプションファイルの拡張子")
+  parser.add_argument("--caption_extention", type=str, default=None,
+                      help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
+  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
  parser.add_argument("--debug", action="store_true", help="debug mode")

  args = parser.parse_args()
+
+  # restore the misspelled option for backward compatibility
+  if args.caption_extention is not None:
+    args.caption_extension = args.caption_extention
+
  main(args)

View File

@@ -0,0 +1,107 @@
# This script is licensed under Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import csv
import glob
import os
import json

from PIL import Image
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from Utils import dbimutils

# from wd14 tagger
IMAGE_SIZE = 448


def main(args):
  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  print("loading model and labels")
  model = load_model(args.model)

  # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
  # read the csv by hand to avoid adding more dependencies
  with open(args.tag_csv, "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    l = [row for row in reader]
    header = l[0]  # tag_id,name,category,count
    rows = l[1:]
  assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"

  # category 0 means general tags only; note that rows already excludes the header,
  # so iterate over rows (not rows[1:], which would skip the first tag)
  tags = [row[1] for row in rows if row[2] == '0']

  # run inference
  def run_batch(path_imgs):
    imgs = np.array([im for _, im in path_imgs])

    probs = model(imgs, training=False)
    probs = probs.numpy()

    for (image_path, _), prob in zip(path_imgs, probs):
      # the first 4 outputs are ratings, so skip them
      # # First 4 labels are actually ratings: pick one with argmax
      # ratings_names = label_names[:4]
      # rating_index = ratings_names["probs"].argmax()
      # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]

      # the rest are tags: add any whose confidence exceeds the threshold
      # Everything else is tags: pick any where prediction confidence > threshold
      tag_text = ""
      for i, p in enumerate(prob[4:]):  # numpy would be faster, but the count is small enough for a plain loop
        if p >= args.thresh:
          tag_text += ", " + tags[i]

      if len(tag_text) > 0:
        tag_text = tag_text[2:]  # strip the leading ", "

      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
        f.write(tag_text + '\n')
        if args.debug:
          print(image_path, tag_text)

  b_imgs = []
  for image_path in tqdm(image_paths):
    img = dbimutils.smart_imread(image_path)
    img = dbimutils.smart_24bit(img)
    img = dbimutils.make_square(img, IMAGE_SIZE)
    img = dbimutils.smart_resize(img, IMAGE_SIZE)
    img = img.astype(np.float32)
    b_imgs.append((image_path, img))

    if len(b_imgs) >= args.batch_size:
      run_batch(b_imgs)
      b_imgs.clear()

  if len(b_imgs) > 0:
    run_batch(b_imgs)

  print("done!")


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("--model", type=str, default="networks/ViTB16_11_03_2022_07h05m53s",
                      help="model path to load / 読み込むモデルファイル")
  parser.add_argument("--tag_csv", type=str, default="2022_0000_0899_6549/selected_tags.csv",
                      help="csv file for tags / タグ一覧のCSVファイル")
  parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
  parser.add_argument("--caption_extention", type=str, default=None,
                      help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
  parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
  parser.add_argument("--debug", action="store_true", help="debug mode")

  args = parser.parse_args()

  # restore the misspelled option for backward compatibility
  if args.caption_extention is not None:
    args.caption_extension = args.caption_extention

  main(args)
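A small sketch of what run_batch does with one probability vector, with made-up numbers and tag names (the real tag list comes from selected_tags.csv):

import numpy as np

prob = np.array([0.9, 0.05, 0.03, 0.02,   # first 4 outputs: ratings, skipped
                 0.80, 0.10, 0.40])       # per-tag confidences
tags = ["1girl", "outdoors", "smile"]     # aligned with prob[4:]

thresh = 0.35
tag_text = ", ".join(t for t, p in zip(tags, prob[4:]) if p >= thresh)
print(tag_text)  # 1girl, smile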