Merge pull request #105 from bmaltais/dev

v20.6.0
This commit is contained in:
bmaltais 2023-02-04 08:37:25 -05:00 committed by GitHub
commit 2ed93b7a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 944 additions and 396 deletions

View File

@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。bf1
cd sd-scripts cd sd-scripts
git pull git pull
.\venv\Scripts\activate .\venv\Scripts\activate
pip install --upgrade -r <requirement file name> pip install --upgrade -r requirements.txt
``` ```
コマンドが成功すれば新しいバージョンが使用できます。 コマンドが成功すれば新しいバージョンが使用できます。

View File

@ -143,6 +143,23 @@ Then redo the installation instruction within the kohya_ss venv.
## Change history ## Change history
* 2023/02/03
- Increase max LoRA rank (dim) size to 1024.
- Update finetune preprocessing scripts.
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
- The default weights of ``tag_images_by_wd14_tagger.py`` is now ``SmilingWolf/wd-v1-4-convnext-tagger-v2``. You can specify another model id from ``SmilingWolf`` by ``--repo_id`` option. Thanks to SmilingWolf for the great work.
- To change the weight, remove ``wd14_tagger_model`` folder, and run the script again.
- ``--max_data_loader_n_workers`` option is added to each script. This option uses the DataLoader for data loading to speed up loading, 20%~30% faster.
- Please specify 2 or 4, depends on the number of CPU cores.
- ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``, only works with ``--full_path``.
- ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning.
- ``requirements.txt`` is updated. If you use this script, [please update the libraries](https://github.com/kohya-ss/sd-scripts#upgrade).
- Usage is almost the same as ``make_captions.py``, but batch size should be smaller.
- ``--remove_words`` option removes as much text as possible (such as ``the word "XXXX" on it``).
- ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images with existing npz files are ignored by this option.
- ``clean_captions_and_tags.py`` is updated to remove duplicated or conflicting tags, e.g. ``shirt`` is removed when ``white shirt`` exists. if ``black hair`` is with ``red hair``, both are removed.
- Tag frequency is added to the metadata in ``train_network.py``. Thanks to space-nuko!
- __All tags and number of occurrences of the tag are recorded.__ If you do not want it, disable metadata storing with ``--no_metadata`` option.
* 2023/01/30 (v20.5.2): * 2023/01/30 (v20.5.2):
- Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev! - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
- Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py`` - Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``

View File

@ -5,13 +5,32 @@ import argparse
import glob import glob
import os import os
import json import json
import re
from tqdm import tqdm from tqdm import tqdm
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
PATTERNS_REMOVE_IN_MULTI = [
PATTERN_HAIR_LENGTH,
PATTERN_HAIR_CUT,
re.compile(r', [\w\-]+ eyes, '),
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
# 複数の髪型定義がある場合は削除する
re.compile(
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
]
def clean_tags(image_key, tags): def clean_tags(image_key, tags):
# replace '_' to ' ' # replace '_' to ' '
tags = tags.replace('^_^', '^@@@^')
tags = tags.replace('_', ' ') tags = tags.replace('_', ' ')
tags = tags.replace('^@@@^', '^_^')
# remove rating: deepdanbooruのみ # remove rating: deepdanbooruのみ
tokens = tags.split(", rating") tokens = tags.split(", rating")
@ -26,6 +45,37 @@ def clean_tags(image_key, tags):
print(f"{image_key} {tags}") print(f"{image_key} {tags}")
tags = tokens[0] tags = tokens[0]
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
# 複数の人物がいる場合は髪色等のタグを削除する
if 'girls' in tags or 'boys' in tags:
for pat in PATTERNS_REMOVE_IN_MULTI:
found = pat.findall(tags)
if len(found) > 1: # 二つ以上、タグがある
tags = pat.sub("", tags)
# 髪の特殊対応
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
if srch_hair_len:
org = srch_hair_len.group()
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
found = PATTERN_HAIR.findall(tags)
if len(found) > 1:
tags = PATTERN_HAIR.sub("", tags)
if srch_hair_len:
tags = tags.replace(", @@@, ", org) # 戻す
# white shirtとshirtみたいな重複タグの削除
found = PATTERN_WORD.findall(tags)
for word in found:
if re.search(f", ((\w+) )+{word}, ", tags):
tags = tags.replace(f", {word}, ", "")
tags = tags.replace(", , ", ", ")
assert tags.startswith(", ") and tags.endswith(", ")
tags = tags[2:-2]
return tags return tags
@ -88,13 +138,23 @@ def main(args):
if tags is None: if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_key}") print(f"image does not have tags / メタデータにタグがありません: {image_key}")
else: else:
metadata[image_key]['tags'] = clean_tags(image_key, tags) org = tags
tags = clean_tags(image_key, tags)
metadata[image_key]['tags'] = tags
if args.debug and org != tags:
print("FROM: " + org)
print("TO: " + tags)
caption = metadata[image_key].get('caption') caption = metadata[image_key].get('caption')
if caption is None: if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
else: else:
metadata[image_key]['caption'] = clean_caption(caption) org = caption
caption = clean_caption(caption)
metadata[image_key]['caption'] = caption
if args.debug and org != caption:
print("FROM: " + org)
print("TO: " + caption)
# metadataを書き出して終わり # metadataを書き出して終わり
print(f"writing metadata: {args.out_json}") print(f"writing metadata: {args.out_json}")
@ -108,6 +168,7 @@ if __name__ == '__main__':
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--debug", action="store_true", help="debug mode")
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
if len(unknown) == 1: if len(unknown) == 1:

View File

@ -11,18 +11,59 @@ import torch
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder from blip.blip import blip_decoder
# from Salesforce_BLIP.models.blip import blip_decoder import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGE_SIZE = 384
# 正方形でいいのか? という気がするがソースがそうなので
IMAGE_TRANSFORM = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
# 共通化したいが微妙に処理が異なる……
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args): def main(args):
# fix the seed for reproducibility # fix the seed for reproducibility
seed = args.seed # + utils.get_rank() seed = args.seed # + utils.get_rank()
torch.manual_seed(seed) torch.manual_seed(seed)
np.random.seed(seed) np.random.seed(seed)
random.seed(seed) random.seed(seed)
if not os.path.exists("blip"): if not os.path.exists("blip"):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
@ -31,24 +72,15 @@ def main(args):
os.chdir('finetune') os.chdir('finetune')
print(f"load images from {args.train_data_dir}") print(f"load images from {args.train_data_dir}")
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ image_paths = train_util.glob_images(args.train_data_dir)
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}") print(f"loading BLIP caption: {args.caption_weights}")
image_size = 384 model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
model.eval() model.eval()
model = model.to(DEVICE) model = model.to(DEVICE)
print("BLIP loaded") print("BLIP loaded")
# 正方形でいいのか? という気がするがソースがそうなので
transform = transforms.Compose([
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
# captioningする # captioningする
def run_batch(path_imgs): def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
@ -66,18 +98,35 @@ def main(args):
if args.debug: if args.debug:
print(image_path, caption) print(image_path, caption)
b_imgs = [] # 読み込みの高速化のためにDataLoaderを使うオプション
for image_path in tqdm(image_paths, smoothing=0.0): if args.max_data_loader_n_workers is not None:
raw_image = Image.open(image_path) dataset = ImageLoadingTransformDataset(image_paths)
if raw_image.mode != "RGB": data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
print(f"convert image mode {raw_image.mode} to RGB: {image_path}") num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
raw_image = raw_image.convert("RGB") else:
data = [[(None, ip)] for ip in image_paths]
image = transform(raw_image) b_imgs = []
b_imgs.append((image_path, image)) for data_entry in tqdm(data, smoothing=0.0):
if len(b_imgs) >= args.batch_size: for data in data_entry:
run_batch(b_imgs) if data is None:
b_imgs.clear() continue
img_tensor, image_path = data
if img_tensor is None:
try:
raw_image = Image.open(image_path)
if raw_image.mode != 'RGB':
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, img_tensor))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0: if len(b_imgs) > 0:
run_batch(b_imgs) run_batch(b_imgs)
@ -95,6 +144,8 @@ if __name__ == '__main__':
parser.add_argument("--beam_search", action="store_true", parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling") help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる") parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")

View File

@ -0,0 +1,145 @@
import argparse
import os
import re
from PIL import Image
from tqdm import tqdm
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
re.compile(r'with the words "'),
re.compile(r'word \w+ on it'),
re.compile(r'that says the word \w+ on it'),
re.compile('that says\'the word "( on it)?'),
]
# 誤検知しまくりの with the word xxxx を消す
def remove_words(captions, debug):
removed_caps = []
for caption in captions:
cap = caption
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
removed_caps.append(cap)
return removed_caps
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
# input_idsがバッチサイズと同じ件数である必要があるバッチサイズはこの関数から参照できないので外から渡す
# ここより上で置き換えようとするとすごく大変
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
if input_ids.size()[0] != curr_batch_size[0]:
input_ids = input_ids.repeat(curr_batch_size[0], 1)
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
# captioningする
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
if args.remove_words:
captions = remove_words(captions, args.debug)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument("--remove_words", action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()
main(args)

View File

@ -1,26 +1,24 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse import argparse
import glob
import os
import json import json
from pathlib import Path
from typing import List
from tqdm import tqdm from tqdm import tqdm
import library.train_util as train_util
def main(args): def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
if args.in_json is None and os.path.isfile(args.out_json): if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json args.in_json = args.out_json
if args.in_json is not None: if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}") print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
metadata = json.load(f)
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else: else:
print("new metadata will be created / 新しいメタデータファイルが作成されます") print("new metadata will be created / 新しいメタデータファイルが作成されます")
@ -28,11 +26,10 @@ def main(args):
print("merge caption texts to metadata json.") print("merge caption texts to metadata json.")
for image_path in tqdm(image_paths): for image_path in tqdm(image_paths):
caption_path = os.path.splitext(image_path)[0] + args.caption_extension caption_path = image_path.with_suffix(args.caption_extension)
with open(caption_path, "rt", encoding='utf-8') as f: caption = caption_path.read_text(encoding='utf-8').strip()
caption = f.readlines()[0].strip()
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata: if image_key not in metadata:
metadata[image_key] = {} metadata[image_key] = {}
@ -42,8 +39,7 @@ def main(args):
# metadataを書き出して終わり # metadataを書き出して終わり
print(f"writing metadata: {args.out_json}") print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f: Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
json.dump(metadata, f, indent=2)
print("done!") print("done!")
@ -51,12 +47,15 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む") parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--caption_extention", type=str, default=None, parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true", parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,26 +1,24 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse import argparse
import glob
import os
import json import json
from pathlib import Path
from typing import List
from tqdm import tqdm from tqdm import tqdm
import library.train_util as train_util
def main(args): def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
if args.in_json is None and os.path.isfile(args.out_json): if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json args.in_json = args.out_json
if args.in_json is not None: if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}") print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
metadata = json.load(f)
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else: else:
print("new metadata will be created / 新しいメタデータファイルが作成されます") print("new metadata will be created / 新しいメタデータファイルが作成されます")
@ -28,11 +26,10 @@ def main(args):
print("merge tags to metadata json.") print("merge tags to metadata json.")
for image_path in tqdm(image_paths): for image_path in tqdm(image_paths):
tags_path = os.path.splitext(image_path)[0] + '.txt' tags_path = image_path.with_suffix(args.caption_extension)
with open(tags_path, "rt", encoding='utf-8') as f: tags = tags_path.read_text(encoding='utf-8').strip()
tags = f.readlines()[0].strip()
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata: if image_key not in metadata:
metadata[image_key] = {} metadata[image_key] = {}
@ -42,8 +39,8 @@ def main(args):
# metadataを書き出して終わり # metadataを書き出して終わり
print(f"writing metadata: {args.out_json}") print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f: Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
json.dump(metadata, f, indent=2)
print("done!") print("done!")
@ -51,9 +48,14 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む") parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true", parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--caption_extension", type=str, default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags") parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,20 +1,16 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse import argparse
import glob
import os import os
import json import json
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
from diffusers import AutoencoderKL
from PIL import Image from PIL import Image
import cv2 import cv2
import torch import torch
from torchvision import transforms from torchvision import transforms
import library.model_util as model_util import library.model_util as model_util
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose(
) )
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def get_latents(vae, images, weight_dtype): def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images] img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors) img_tensors = torch.stack(img_tensors)
@ -35,9 +41,18 @@ def get_latents(vae, images, weight_dtype):
return latents return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
else:
base_name = image_key
if flip:
base_name += '_flip'
return os.path.join(data_dir, base_name)
def main(args): def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ image_paths = train_util.glob_images(args.train_data_dir)
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json): if os.path.exists(args.in_json):
@ -70,15 +85,56 @@ def main(args):
buckets_imgs = [[] for _ in range(len(bucket_resos))] buckets_imgs = [[] for _ in range(len(bucket_resos))]
bucket_counts = [0 for _ in range(len(bucket_resos))] bucket_counts = [0 for _ in range(len(bucket_resos))]
img_ar_errors = [] img_ar_errors = []
for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
def process_batch(is_last):
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
for (image_key, _, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
np.savez(npz_file_name, latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
np.savez(npz_file_name, latent)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
for data_entry in tqdm(data, smoothing=0.0):
if data_entry[0] is None:
continue
img_tensor, image_path = data_entry[0]
if img_tensor is not None:
image = transforms.functional.to_pil_image(img_tensor)
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata: if image_key not in metadata:
metadata[image_key] = {} metadata[image_key] = {}
image = Image.open(image_path) # 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変
if image.mode != 'RGB':
image = image.convert("RGB")
aspect_ratio = image.width / image.height aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin() bucket_id = np.abs(ar_errors).argmin()
@ -102,6 +158,25 @@ def main(args):
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
dat = np.load(npz_file)['arr_0']
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
continue
# 画像をリサイズしてトリミングする # 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で…… # PILにinter_areaがないのでcv2で……
image = np.array(image) image = np.array(image)
@ -123,25 +198,10 @@ def main(args):
metadata[image_key]['train_resolution'] = reso metadata[image_key]['train_resolution'] = reso
# バッチを推論するか判定して推論する # バッチを推論するか判定して推論する
is_last = i == len(image_paths) - 1 process_batch(False)
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
for (image_key, reso, _), latent in zip(bucket, latents): # 残りを処理する
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key process_batch(True)
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
bucket.clear()
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)): for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
print(f"bucket {i} {reso}: {count}") print(f"bucket {i} {reso}: {count}")
@ -162,8 +222,10 @@ if __name__ == '__main__':
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_resolution", type=str, default="512,512", parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
@ -174,6 +236,8 @@ if __name__ == '__main__':
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--flip_aug", action="store_true", parser.add_argument("--flip_aug", action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -1,180 +0,0 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from tqdm import tqdm
import numpy as np
from diffusers import AutoencoderKL
from PIL import Image
import cv2
import torch
from torchvision import transforms
import library.model_util as model_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
return latents
def main(args):
# image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
# glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
# print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
max_reso, args.min_bucket_reso, args.max_bucket_reso)
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
buckets_imgs = [[] for _ in range(len(bucket_resos))]
bucket_counts = [0 for _ in range(len(bucket_resos))]
img_ar_errors = []
for i, image_path in enumerate(tqdm(metadata, smoothing=0.0)):
image_key = image_path
if image_key not in metadata:
metadata[image_key] = {}
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
reso = bucket_resos[bucket_id]
ar_error = ar_errors[bucket_id]
img_ar_errors.append(abs(ar_error))
# どのサイズにリサイズするか→トリミングする方向で
if ar_error <= 0: # 横が長い→縦を合わせる
scale = reso[1] / image.height
else:
scale = reso[0] / image.width
resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
# print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
# bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
assert resized_size[0] == reso[0] or resized_size[1] == reso[
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]]
elif resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug
# cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
# バッチへ追加
buckets_imgs[bucket_id].append((image_key, reso, image))
bucket_counts[bucket_id] += 1
metadata[image_key]['train_resolution'] = reso
# バッチを推論するか判定して推論する
is_last = i == len(metadata) - 1
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
np.savez(os.path.join(os.path.dirname(image_key), npz_file_name), latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
np.savez(os.path.join(os.path.dirname(image_key), npz_file_name + '_flip'), latent)
bucket.clear()
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
print(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--flip_aug", action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
args = parser.parse_args()
main(args)

View File

@ -1,6 +1,3 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse import argparse
import csv import csv
import glob import glob
@ -12,32 +9,87 @@ from tqdm import tqdm
import numpy as np import numpy as np
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import torch
import library.train_util as train_util
# from wd14 tagger # from wd14 tagger
IMAGE_SIZE = 448 IMAGE_SIZE = 448
WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger' # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables" SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1] CSV_FILE = FILES[-1]
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
image = image.astype(np.float32)
return image
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args): def main(args):
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時 # depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download: if not os.path.exists(args.model_dir) or args.force_download:
print("downloading wd14 tagger model from hf_hub") print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES: for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES: for file in SUB_DIR_FILES:
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
args.model_dir, SUB_DIR), force_download=True, force_filename=file) args.model_dir, SUB_DIR), force_download=True, force_filename=file)
else:
print("using existing wd14 tagger model")
# 画像を読み込む # 画像を読み込む
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ image_paths = train_util.glob_images(args.train_data_dir)
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
print("loading model and labels") print("loading model and labels")
@ -72,7 +124,7 @@ def main(args):
# Everything else is tags: pick any where prediction confidence > threshold # Everything else is tags: pick any where prediction confidence > threshold
tag_text = "" tag_text = ""
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
if p >= args.thresh: if p >= args.thresh and i < len(tags):
tag_text += ", " + tags[i] tag_text += ", " + tags[i]
if len(tag_text) > 0: if len(tag_text) > 0:
@ -83,34 +135,37 @@ def main(args):
if args.debug: if args.debug:
print(image_path, tag_text) print(image_path, tag_text)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = [] b_imgs = []
for image_path in tqdm(image_paths, smoothing=0.0): for data_entry in tqdm(data, smoothing=0.0):
img = Image.open(image_path) # cv2は日本語ファイル名で死ぬのとモード変換したいのでpillowで開く for data in data_entry:
if img.mode != 'RGB': if data is None:
img = img.convert("RGB") continue
img = np.array(img)
img = img[:, :, ::-1] # RGB->BGR
# pad to square image, image_path = data
size = max(img.shape[0:2]) if image is not None:
pad_x = size - img.shape[1] image = image.detach().numpy()
pad_y = size - img.shape[0] else:
pad_l = pad_x // 2 try:
pad_t = pad_y // 2 image = Image.open(image_path)
img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) if image.mode != 'RGB':
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 if len(b_imgs) >= args.batch_size:
img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) run_batch(b_imgs)
# cv2.imshow("img", img) b_imgs.clear()
# cv2.waitKey()
# cv2.destroyAllWindows()
img = img.astype(np.float32)
b_imgs.append((image_path, img))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0: if len(b_imgs) > 0:
run_batch(b_imgs) run_batch(b_imgs)
@ -121,7 +176,7 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO, parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
@ -129,6 +184,8 @@ if __name__ == '__main__':
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--caption_extention", type=str, default=None, parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")

View File

@ -292,7 +292,7 @@ def train_model(
subprocess.run(run_cmd) subprocess.run(run_cmd)
image_num = len( image_num = len(
[f for f in os.listdir(image_folder) if f.endswith('.npz')] [f for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')]
) )
print(f'image_num = {image_num}') print(f'image_num = {image_num}')

View File

@ -470,6 +470,9 @@ class PipelineLike():
self.scheduler = scheduler self.scheduler = scheduler
self.safety_checker = None self.safety_checker = None
# Textual Inversion
self.token_replacements = {}
# CLIP guidance # CLIP guidance
self.clip_guidance_scale = clip_guidance_scale self.clip_guidance_scale = clip_guidance_scale
self.clip_image_guidance_scale = clip_image_guidance_scale self.clip_image_guidance_scale = clip_image_guidance_scale
@ -484,6 +487,19 @@ class PipelineLike():
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
# Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
def replace_token(self, tokens):
new_tokens = []
for token in tokens:
if token in self.token_replacements:
new_tokens.extend(self.token_replacements[token])
else:
new_tokens.append(token)
return new_tokens
# region xformersとか使う部分独自に書き換えるので関係なし # region xformersとか使う部分独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self): def enable_xformers_memory_efficient_attention(self):
r""" r"""
@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
for word, weight in texts_and_weights: for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token # tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1] token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token)
text_token += token text_token += token
# copy the weight by length of token # copy the weight by length of token
text_weight += [weight] * len(token) text_weight += [weight] * len(token)
@ -1826,12 +1845,12 @@ def main(args):
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else: else:
print("load Diffusers pretrained models") print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = pipe.text_encoder text_encoder = loading_pipe.text_encoder
vae = pipe.vae vae = loading_pipe.vae
unet = pipe.unet unet = loading_pipe.unet
tokenizer = pipe.tokenizer tokenizer = loading_pipe.tokenizer
del pipe del loading_pipe
# VAEを読み込む # VAEを読み込む
if args.vae is not None: if args.vae is not None:
@ -2039,6 +2058,44 @@ def main(args):
if args.diffusers_xformers: if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
token_ids_embeds = []
for embeds_file in args.textual_inversion_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
embeds = next(iter(data.values()))
if type(embeds) != torch.Tensor:
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
num_vectors_per_token = embeds.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings)
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
token_ids_embeds.append((token_ids, embeds))
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds:
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
# promptを取得する # promptを取得する
if args.from_file is not None: if args.from_file is not None:
print(f"reading prompts from {args.from_file}") print(f"reading prompts from {args.from_file}")
@ -2157,8 +2214,8 @@ def main(args):
os.makedirs(args.outdir, exist_ok=True) os.makedirs(args.outdir, exist_ok=True)
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for iter in range(args.n_iter): for gen_iter in range(args.n_iter):
print(f"iteration {iter+1}/{args.n_iter}") print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7fffffff) iter_seed = random.randint(0, 0x7fffffff)
# バッチ処理の関数 # バッチ処理の関数
@ -2527,6 +2584,8 @@ if __name__ == '__main__':
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_args", type=str, default=None, nargs='*', parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数') help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
parser.add_argument("--max_embeddings_multiples", type=int, default=None, parser.add_argument("--max_embeddings_multiples", type=int, default=None,
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')

View File

@ -19,7 +19,7 @@ def UI(username, password):
print('Load CSS...') print('Load CSS...')
css += file.read() + '\n' css += file.read() + '\n'
interface = gr.Blocks(css=css) interface = gr.Blocks(css=css, title="Kohya_ss GUI")
with interface: with interface:
with gr.Tab('Dreambooth'): with gr.Tab('Dreambooth'):

View File

@ -109,11 +109,11 @@ def gradio_extract_lora_tab():
) )
with gr.Row(): with gr.Row():
dim = gr.Slider( dim = gr.Slider(
minimum=1, minimum=4,
maximum=128, maximum=1024,
label='Network Dimension', label='Network Dimension',
value=8, value=128,
step=1, step=4,
interactive=True, interactive=True,
) )
v2 = gr.Checkbox(label='v2', value=False, interactive=True) v2 = gr.Checkbox(label='v2', value=False, interactive=True)

126
library/git_caption_gui.py Normal file
View File

@ -0,0 +1,126 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_folder_path, add_pre_postfix
def caption_images(
train_data_dir,
caption_ext,
batch_size,
max_data_loader_n_workers,
max_length,
model_id,
prefix,
postfix,
):
# Check for images_dir_input
if train_data_dir == '':
msgbox('Image folder is missing...')
return
if caption_ext == '':
msgbox('Please provide an extension for the caption files.')
return
print(f'GIT captioning files in {train_data_dir}...')
run_cmd = f'.\\venv\\Scripts\\python.exe "finetune/make_captions.py"'
if not model_id == '':
run_cmd += f' --model_id="{model_id}"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
run_cmd += f' --max_length="{int(max_length)}"'
if caption_ext != '':
run_cmd += f' --caption_extension="{caption_ext}"'
run_cmd += f' "{train_data_dir}"'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
# Add prefix and postfix
add_pre_postfix(
folder=train_data_dir,
caption_file_ext=caption_ext,
prefix=prefix,
postfix=postfix,
)
print('...captioning done')
###
# Gradio UI
###
def gradio_git_caption_gui_tab():
with gr.Tab('GIT Captioning'):
gr.Markdown(
'This utility will use GIT to caption files for each images in a folder.'
)
with gr.Row():
train_data_dir = gr.Textbox(
label='Image folder to caption',
placeholder='Directory containing the images to caption',
interactive=True,
)
button_train_data_dir_input = gr.Button(
'📂', elem_id='open_folder_small'
)
button_train_data_dir_input.click(
get_folder_path, outputs=train_data_dir
)
with gr.Row():
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extention for caption file. eg: .caption, .txt',
value='.txt',
interactive=True,
)
prefix = gr.Textbox(
label='Prefix to add to BLIP caption',
placeholder='(Optional)',
interactive=True,
)
postfix = gr.Textbox(
label='Postfix to add to BLIP caption',
placeholder='(Optional)',
interactive=True,
)
batch_size = gr.Number(
value=1, label='Batch size', interactive=True
)
with gr.Row():
max_data_loader_n_workers = gr.Number(
value=2, label='Number of workers', interactive=True
)
max_length = gr.Number(
value=75, label='Max length', interactive=True
)
model_id = gr.Textbox(
label="Model",
placeholder="(Optional) model id for GIT in Hugging Face", interactive=True
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
train_data_dir,
caption_ext,
batch_size,
max_data_loader_n_workers,
max_length,
model_id,
prefix,
postfix,
],
)

View File

@ -45,6 +45,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset # region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
class ImageInfo(): class ImageInfo():
@ -87,6 +88,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.enable_bucket = False self.enable_bucket = False
self.min_bucket_reso = None self.min_bucket_reso = None
self.max_bucket_reso = None self.max_bucket_reso = None
self.tag_frequency = {}
self.bucket_info = None self.bucket_info = None
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@ -115,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {} self.replacements = {}
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
def disable_token_padding(self): def disable_token_padding(self):
self.token_padding_disabled = True self.token_padding_disabled = True
@ -140,7 +152,7 @@ class BaseDataset(torch.utils.data.Dataset):
if type(str_to) == list: if type(str_to) == list:
caption = random.choice(str_to) caption = random.choice(str_to)
else: else:
caption = str_to caption = str_to
else: else:
caption = caption.replace(str_from, str_to) caption = caption.replace(str_from, str_to)
@ -240,13 +252,14 @@ class BaseDataset(torch.utils.data.Dataset):
print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む") print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)} self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") # only show bucket info if there is an actual image in it
if len(img_keys) > 0:
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
img_ar_errors = np.array(img_ar_errors) img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors)) mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}") print(f"mean ar error (without repeats): {mean_img_ar_error}")
# 参照用indexを作る # 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = [] self.buckets_indices: list(BucketBatchIndex) = []
@ -545,6 +558,8 @@ class DreamBoothDataset(BaseDataset):
cap_for_img = read_caption(img_path) cap_for_img = read_caption(img_path)
captions.append(caption_by_folder if cap_for_img is None else cap_for_img) captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
return n_repeats, img_paths, captions return n_repeats, img_paths, captions
print("prepare train images.") print("prepare train images.")
@ -553,10 +568,13 @@ class DreamBoothDataset(BaseDataset):
for dir in train_dirs: for dir in train_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
num_train_images += n_repeats * len(img_paths) num_train_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path) info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info) self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.") print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images self.num_train_images = num_train_images
@ -570,9 +588,11 @@ class DreamBoothDataset(BaseDataset):
for dir in reg_dirs: for dir in reg_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
num_reg_images += n_repeats * len(img_paths) num_reg_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path) info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info) reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.") print(f"{num_reg_images} reg images.")
@ -617,6 +637,7 @@ class FineTuningDataset(BaseDataset):
self.train_data_dir = train_data_dir self.train_data_dir = train_data_dir
self.batch_size = batch_size self.batch_size = batch_size
tags_list = []
for image_key, img_md in metadata.items(): for image_key, img_md in metadata.items():
# path情報を作る # path情報を作る
if os.path.exists(image_key): if os.path.exists(image_key):
@ -633,6 +654,7 @@ class FineTuningDataset(BaseDataset):
caption = tags caption = tags
elif tags is not None and len(tags) > 0: elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
@ -646,7 +668,8 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0 self.num_reg_images = 0
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
# check existence of all npz files # check existence of all npz files
if not self.color_aug: if not self.color_aug:
@ -667,6 +690,8 @@ class FineTuningDataset(BaseDataset):
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
elif not npz_all: elif not npz_all:
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
if self.flip_aug:
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
for image_info in self.image_data.values(): for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None image_info.latents_npz = image_info.latents_npz_flipped = None
@ -756,15 +781,30 @@ def debug_dataset(train_dataset, show_input_ids=False):
break break
def glob_images(dir, base): def glob_images(directory, base="*"):
img_paths = [] img_paths = []
for ext in IMAGE_EXTENSIONS: for ext in IMAGE_EXTENSIONS:
if base == '*': if base == '*':
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext))) img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else: else:
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext)))) img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
# img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort()
return img_paths return img_paths
def glob_images_pathlib(dir_path, recursive):
image_paths = []
if recursive:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.rglob('*' + ext))
else:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob('*' + ext))
# image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort()
return image_paths
# endregion # endregion
@ -1495,5 +1535,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# endregion
# region 前処理用
class ImageLoadingDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor_pil = transforms.functional.pil_to_tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor_pil, img_path)
# endregion # endregion

View File

@ -9,6 +9,7 @@ import argparse
from library.basic_caption_gui import gradio_basic_caption_gui_tab from library.basic_caption_gui import gradio_basic_caption_gui_tab
from library.convert_model_gui import gradio_convert_model_tab from library.convert_model_gui import gradio_convert_model_tab
from library.blip_caption_gui import gradio_blip_caption_gui_tab from library.blip_caption_gui import gradio_blip_caption_gui_tab
from library.git_caption_gui import gradio_git_caption_gui_tab
from library.wd14_caption_gui import gradio_wd14_caption_gui_tab from library.wd14_caption_gui import gradio_wd14_caption_gui_tab
@ -23,6 +24,7 @@ def utilities_tab(
with gr.Tab('Captioning'): with gr.Tab('Captioning'):
gradio_basic_caption_gui_tab() gradio_basic_caption_gui_tab()
gradio_blip_caption_gui_tab() gradio_blip_caption_gui_tab()
gradio_git_caption_gui_tab()
gradio_wd14_caption_gui_tab() gradio_wd14_caption_gui_tab()
gradio_convert_model_tab() gradio_convert_model_tab()

View File

@ -291,11 +291,11 @@ def train_model(
if unet_lr == '': if unet_lr == '':
unet_lr = 0 unet_lr = 0
if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0): # if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
msgbox( # msgbox(
'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided' # 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
) # )
return # return
# Get a list of all subfolders in train_data_dir # Get a list of all subfolders in train_data_dir
subfolders = [ subfolders = [
@ -383,15 +383,26 @@ def train_model(
if not float(prior_loss_weight) == 1.0: if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora' run_cmd += f' --network_module=networks.lora'
if not float(text_encoder_lr) == 0:
run_cmd += f' --text_encoder_lr={text_encoder_lr}' if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
run_cmd += f' --unet_lr={unet_lr}'
elif not (float(text_encoder_lr) == 0):
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
run_cmd += f' --network_train_text_encoder_only'
else:
run_cmd += f' --unet_lr={unet_lr}'
run_cmd += f' --network_train_unet_only'
else: else:
run_cmd += f' --network_train_unet_only' if float(text_encoder_lr) == 0:
if not float(unet_lr) == 0: msgbox(
run_cmd += f' --unet_lr={unet_lr}' 'Please input learning rate values.'
else: )
run_cmd += f' --network_train_text_encoder_only' return
run_cmd += f' --network_dim={network_dim}' run_cmd += f' --network_dim={network_dim}'
if not lora_network_weights == '': if not lora_network_weights == '':
run_cmd += f' --network_weights="{lora_network_weights}"' run_cmd += f' --network_weights="{lora_network_weights}"'
if int(gradient_accumulation_steps) > 1: if int(gradient_accumulation_steps) > 1:
@ -400,6 +411,8 @@ def train_model(
run_cmd += f' --output_name="{output_name}"' run_cmd += f' --output_name="{output_name}"'
if not lr_scheduler_num_cycles == '': if not lr_scheduler_num_cycles == '':
run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"' run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
else:
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
if not lr_scheduler_power == '': if not lr_scheduler_power == '':
run_cmd += f' --output_name="{lr_scheduler_power}"' run_cmd += f' --output_name="{lr_scheduler_power}"'
@ -612,19 +625,19 @@ def lora_tab(
placeholder='Optional', placeholder='Optional',
) )
network_dim = gr.Slider( network_dim = gr.Slider(
minimum=1, minimum=4,
maximum=128, maximum=1024,
label='Network Rank (Dimension)', label='Network Rank (Dimension)',
value=8, value=8,
step=1, step=4,
interactive=True, interactive=True,
) )
network_alpha = gr.Slider( network_alpha = gr.Slider(
minimum=1, minimum=4,
maximum=128, maximum=1024,
label='Network Alpha', label='Network Alpha',
value=1, value=1,
step=1, step=4,
interactive=True, interactive=True,
) )
with gr.Row(): with gr.Row():

View File

@ -1,5 +1,5 @@
accelerate==0.15.0 accelerate==0.15.0
transformers==4.25.1 transformers==4.26.0
ftfy ftfy
albumentations albumentations
opencv-python opencv-python
@ -9,7 +9,7 @@ pytorch_lightning
bitsandbytes==0.35.0 bitsandbytes==0.35.0
tensorboard tensorboard
safetensors==0.2.6 safetensors==0.2.6
gradio==3.15.0 gradio==3.16.2
altair altair
easygui easygui
tk tk

View File

@ -0,0 +1,66 @@
import os
import cv2
import argparse
import shutil
import math
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Iterate through all files in src_img_folder
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp
if not filename.endswith(('.png', '.jpg', '.webp')):
# Copy the file to the destination folder if not png, jpg or webp
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height))
# Calculate the new height and width that are divisible by divisible_by
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]}")
def main():
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
parser.add_argument('--max_resolution', type=str, help='Maximum resolution in the format "512x512"', default="512x512")
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=2)
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
if __name__ == '__main__':
main()

View File

@ -1,3 +1,6 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from typing import Optional, Union
import importlib import importlib
import argparse import argparse
import gc import gc
@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
# Which is a newer release of diffusers than currently packaged with sd-scripts # Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
from typing import Optional, Union
from torch.optim import Optimizer
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
def get_scheduler_fix( def get_scheduler_fix(
name: Union[str, SchedulerType], name: Union[str, SchedulerType],
@ -52,53 +52,53 @@ def get_scheduler_fix(
num_cycles: int = 1, num_cycles: int = 1,
power: float = 1.0, power: float = 1.0,
): ):
""" """
Unified API to get any scheduler from its name. Unified API to get any scheduler from its name.
Args: Args:
name (`str` or `SchedulerType`): name (`str` or `SchedulerType`):
The name of the scheduler to use. The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`): optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training. The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*): num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it. optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*): num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it. optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*): num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0): power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1): last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training. The index of the last epoch when resuming training.
""" """
name = SchedulerType(name) name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT: if name == SchedulerType.CONSTANT:
return schedule_func(optimizer) return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps` # All other schedulers require `num_warmup_steps`
if num_warmup_steps is None: if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP: if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps` # All other schedulers require `num_training_steps`
if num_training_steps is None: if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS: if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func( return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
) )
if name == SchedulerType.POLYNOMIAL: if name == SchedulerType.POLYNOMIAL:
return schedule_func( return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
) )
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def train(args): def train(args):
@ -135,7 +135,7 @@ def train(args):
train_util.debug_dataset(train_dataset) train_util.debug_dataset(train_dataset)
return return
if len(train_dataset) == 0: if len(train_dataset) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります")
return return
# acceleratorを準備する # acceleratorを準備する
@ -224,7 +224,7 @@ def train(args):
# lr schedulerを用意する # lr schedulerを用意する
# lr_scheduler = diffusers.optimization.get_scheduler( # lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix( lr_scheduler = get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
@ -335,6 +335,7 @@ def train(args):
"ss_keep_tokens": args.keep_tokens, "ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
"ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment # will not be updated after training "ss_training_comment": args.training_comment # will not be updated after training
} }