Merge pull request #105 from bmaltais/dev

v20.6.0
This commit is contained in:
bmaltais 2023-02-04 08:37:25 -05:00 committed by GitHub
commit 2ed93b7a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 944 additions and 396 deletions

View File

@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。bf1
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r <requirement file name>
pip install --upgrade -r requirements.txt
```
コマンドが成功すれば新しいバージョンが使用できます。

View File

@ -143,6 +143,23 @@ Then redo the installation instruction within the kohya_ss venv.
## Change history
* 2023/02/03
- Increase max LoRA rank (dim) size to 1024.
- Update finetune preprocessing scripts.
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
- The default weights of ``tag_images_by_wd14_tagger.py`` is now ``SmilingWolf/wd-v1-4-convnext-tagger-v2``. You can specify another model id from ``SmilingWolf`` by ``--repo_id`` option. Thanks to SmilingWolf for the great work.
- To change the weight, remove ``wd14_tagger_model`` folder, and run the script again.
- ``--max_data_loader_n_workers`` option is added to each script. This option uses the DataLoader for data loading to speed up loading, 20%~30% faster.
- Please specify 2 or 4, depends on the number of CPU cores.
- ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``, only works with ``--full_path``.
- ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning.
- ``requirements.txt`` is updated. If you use this script, [please update the libraries](https://github.com/kohya-ss/sd-scripts#upgrade).
- Usage is almost the same as ``make_captions.py``, but batch size should be smaller.
- ``--remove_words`` option removes as much text as possible (such as ``the word "XXXX" on it``).
- ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images with existing npz files are ignored by this option.
- ``clean_captions_and_tags.py`` is updated to remove duplicated or conflicting tags, e.g. ``shirt`` is removed when ``white shirt`` exists. if ``black hair`` is with ``red hair``, both are removed.
- Tag frequency is added to the metadata in ``train_network.py``. Thanks to space-nuko!
- __All tags and number of occurrences of the tag are recorded.__ If you do not want it, disable metadata storing with ``--no_metadata`` option.
* 2023/01/30 (v20.5.2):
- Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
- Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``

View File

@ -5,13 +5,32 @@ import argparse
import glob
import os
import json
import re
from tqdm import tqdm
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
PATTERNS_REMOVE_IN_MULTI = [
PATTERN_HAIR_LENGTH,
PATTERN_HAIR_CUT,
re.compile(r', [\w\-]+ eyes, '),
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
# 複数の髪型定義がある場合は削除する
re.compile(
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
]
def clean_tags(image_key, tags):
# replace '_' to ' '
tags = tags.replace('^_^', '^@@@^')
tags = tags.replace('_', ' ')
tags = tags.replace('^@@@^', '^_^')
# remove rating: deepdanbooruのみ
tokens = tags.split(", rating")
@ -26,6 +45,37 @@ def clean_tags(image_key, tags):
print(f"{image_key} {tags}")
tags = tokens[0]
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
# 複数の人物がいる場合は髪色等のタグを削除する
if 'girls' in tags or 'boys' in tags:
for pat in PATTERNS_REMOVE_IN_MULTI:
found = pat.findall(tags)
if len(found) > 1: # 二つ以上、タグがある
tags = pat.sub("", tags)
# 髪の特殊対応
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
if srch_hair_len:
org = srch_hair_len.group()
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
found = PATTERN_HAIR.findall(tags)
if len(found) > 1:
tags = PATTERN_HAIR.sub("", tags)
if srch_hair_len:
tags = tags.replace(", @@@, ", org) # 戻す
# white shirtとshirtみたいな重複タグの削除
found = PATTERN_WORD.findall(tags)
for word in found:
if re.search(f", ((\w+) )+{word}, ", tags):
tags = tags.replace(f", {word}, ", "")
tags = tags.replace(", , ", ", ")
assert tags.startswith(", ") and tags.endswith(", ")
tags = tags[2:-2]
return tags
@ -88,13 +138,23 @@ def main(args):
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
else:
metadata[image_key]['tags'] = clean_tags(image_key, tags)
org = tags
tags = clean_tags(image_key, tags)
metadata[image_key]['tags'] = tags
if args.debug and org != tags:
print("FROM: " + org)
print("TO: " + tags)
caption = metadata[image_key].get('caption')
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
else:
metadata[image_key]['caption'] = clean_caption(caption)
org = caption
caption = clean_caption(caption)
metadata[image_key]['caption'] = caption
if args.debug and org != caption:
print("FROM: " + org)
print("TO: " + caption)
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
@ -108,6 +168,7 @@ if __name__ == '__main__':
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--debug", action="store_true", help="debug mode")
args, unknown = parser.parse_known_args()
if len(unknown) == 1:

View File

@ -11,14 +11,55 @@ import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder
# from Salesforce_BLIP.models.blip import blip_decoder
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGE_SIZE = 384
# 正方形でいいのか? という気がするがソースがそうなので
IMAGE_TRANSFORM = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
# 共通化したいが微妙に処理が異なる……
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# fix the seed for reproducibility
seed = args.seed # + utils.get_rank()
seed = args.seed # + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
@ -31,24 +72,15 @@ def main(args):
os.chdir('finetune')
print(f"load images from {args.train_data_dir}")
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}")
image_size = 384
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
# 正方形でいいのか? という気がするがソースがそうなので
transform = transforms.Compose([
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
# captioningする
def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
@ -66,18 +98,35 @@ def main(args):
if args.debug:
print(image_path, caption)
b_imgs = []
for image_path in tqdm(image_paths, smoothing=0.0):
raw_image = Image.open(image_path)
if raw_image.mode != "RGB":
print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
raw_image = raw_image.convert("RGB")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingTransformDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
image = transform(raw_image)
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
img_tensor, image_path = data
if img_tensor is None:
try:
raw_image = Image.open(image_path)
if raw_image.mode != 'RGB':
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, img_tensor))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
@ -95,6 +144,8 @@ if __name__ == '__main__':
parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")

View File

@ -0,0 +1,145 @@
import argparse
import os
import re
from PIL import Image
from tqdm import tqdm
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
re.compile(r'with the words "'),
re.compile(r'word \w+ on it'),
re.compile(r'that says the word \w+ on it'),
re.compile('that says\'the word "( on it)?'),
]
# 誤検知しまくりの with the word xxxx を消す
def remove_words(captions, debug):
removed_caps = []
for caption in captions:
cap = caption
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
removed_caps.append(cap)
return removed_caps
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
# input_idsがバッチサイズと同じ件数である必要があるバッチサイズはこの関数から参照できないので外から渡す
# ここより上で置き換えようとするとすごく大変
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
if input_ids.size()[0] != curr_batch_size[0]:
input_ids = input_ids.repeat(curr_batch_size[0], 1)
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
# captioningする
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
if args.remove_words:
captions = remove_words(captions, args.debug)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument("--remove_words", action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()
main(args)

View File

@ -1,26 +1,24 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
if args.in_json is None and os.path.isfile(args.out_json):
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
@ -28,11 +26,10 @@ def main(args):
print("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
with open(caption_path, "rt", encoding='utf-8') as f:
caption = f.readlines()[0].strip()
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
@ -42,8 +39,7 @@ def main(args):
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
@ -51,12 +47,15 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()

View File

@ -1,26 +1,24 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
if args.in_json is None and os.path.isfile(args.out_json):
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
@ -28,11 +26,10 @@ def main(args):
print("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = os.path.splitext(image_path)[0] + '.txt'
with open(tags_path, "rt", encoding='utf-8') as f:
tags = f.readlines()[0].strip()
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
@ -42,8 +39,8 @@ def main(args):
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
@ -51,9 +48,14 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--caption_extension", type=str, default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
args = parser.parse_args()

View File

@ -1,20 +1,16 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from tqdm import tqdm
import numpy as np
from diffusers import AutoencoderKL
from PIL import Image
import cv2
import torch
from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose(
)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
@ -35,9 +41,18 @@ def get_latents(vae, images, weight_dtype):
return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
else:
base_name = image_key
if flip:
base_name += '_flip'
return os.path.join(data_dir, base_name)
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
@ -70,15 +85,56 @@ def main(args):
buckets_imgs = [[] for _ in range(len(bucket_resos))]
bucket_counts = [0 for _ in range(len(bucket_resos))]
img_ar_errors = []
for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
def process_batch(is_last):
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
for (image_key, _, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
np.savez(npz_file_name, latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
np.savez(npz_file_name, latent)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
for data_entry in tqdm(data, smoothing=0.0):
if data_entry[0] is None:
continue
img_tensor, image_path = data_entry[0]
if img_tensor is not None:
image = transforms.functional.to_pil_image(img_tensor)
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata:
metadata[image_key] = {}
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
# 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変
aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
@ -102,6 +158,25 @@ def main(args):
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
dat = np.load(npz_file)['arr_0']
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
continue
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
@ -123,25 +198,10 @@ def main(args):
metadata[image_key]['train_resolution'] = reso
# バッチを推論するか判定して推論する
is_last = i == len(image_paths) - 1
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
process_batch(False)
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
bucket.clear()
# 残りを処理する
process_batch(True)
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
print(f"bucket {i} {reso}: {count}")
@ -162,8 +222,10 @@ if __name__ == '__main__':
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
@ -174,6 +236,8 @@ if __name__ == '__main__':
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--flip_aug", action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
args = parser.parse_args()
main(args)

View File

@ -1,180 +0,0 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from tqdm import tqdm
import numpy as np
from diffusers import AutoencoderKL
from PIL import Image
import cv2
import torch
from torchvision import transforms
import library.model_util as model_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
return latents
def main(args):
# image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
# glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
# print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
max_reso, args.min_bucket_reso, args.max_bucket_reso)
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
buckets_imgs = [[] for _ in range(len(bucket_resos))]
bucket_counts = [0 for _ in range(len(bucket_resos))]
img_ar_errors = []
for i, image_path in enumerate(tqdm(metadata, smoothing=0.0)):
image_key = image_path
if image_key not in metadata:
metadata[image_key] = {}
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
reso = bucket_resos[bucket_id]
ar_error = ar_errors[bucket_id]
img_ar_errors.append(abs(ar_error))
# どのサイズにリサイズするか→トリミングする方向で
if ar_error <= 0: # 横が長い→縦を合わせる
scale = reso[1] / image.height
else:
scale = reso[0] / image.width
resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
# print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
# bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
assert resized_size[0] == reso[0] or resized_size[1] == reso[
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]]
elif resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug
# cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
# バッチへ追加
buckets_imgs[bucket_id].append((image_key, reso, image))
bucket_counts[bucket_id] += 1
metadata[image_key]['train_resolution'] = reso
# バッチを推論するか判定して推論する
is_last = i == len(metadata) - 1
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
np.savez(os.path.join(os.path.dirname(image_key), npz_file_name), latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
np.savez(os.path.join(os.path.dirname(image_key), npz_file_name + '_flip'), latent)
bucket.clear()
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
print(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--flip_aug", action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
args = parser.parse_args()
main(args)

View File

@ -1,6 +1,3 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import csv
import glob
@ -12,32 +9,87 @@ from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
import library.train_util as train_util
# from wd14 tagger
IMAGE_SIZE = 448
WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger'
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
image = image.astype(np.float32)
return image
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print("downloading wd14 tagger model from hf_hub")
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
else:
print("using existing wd14 tagger model")
# 画像を読み込む
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
print("loading model and labels")
@ -72,7 +124,7 @@ def main(args):
# Everything else is tags: pick any where prediction confidence > threshold
tag_text = ""
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
if p >= args.thresh:
if p >= args.thresh and i < len(tags):
tag_text += ", " + tags[i]
if len(tag_text) > 0:
@ -83,34 +135,37 @@ def main(args):
if args.debug:
print(image_path, tag_text)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for image_path in tqdm(image_paths, smoothing=0.0):
img = Image.open(image_path) # cv2は日本語ファイル名で死ぬのとモード変換したいのでpillowで開く
if img.mode != 'RGB':
img = img.convert("RGB")
img = np.array(img)
img = img[:, :, ::-1] # RGB->BGR
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
# pad to square
size = max(img.shape[0:2])
pad_x = size - img.shape[1]
pad_y = size - img.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
# cv2.imshow("img", img)
# cv2.waitKey()
# cv2.destroyAllWindows()
img = img.astype(np.float32)
b_imgs.append((image_path, img))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
@ -121,7 +176,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO,
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
@ -129,6 +184,8 @@ if __name__ == '__main__':
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")

View File

@ -292,7 +292,7 @@ def train_model(
subprocess.run(run_cmd)
image_num = len(
[f for f in os.listdir(image_folder) if f.endswith('.npz')]
[f for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')]
)
print(f'image_num = {image_num}')

View File

@ -470,6 +470,9 @@ class PipelineLike():
self.scheduler = scheduler
self.safety_checker = None
# Textual Inversion
self.token_replacements = {}
# CLIP guidance
self.clip_guidance_scale = clip_guidance_scale
self.clip_image_guidance_scale = clip_image_guidance_scale
@ -484,6 +487,19 @@ class PipelineLike():
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
# Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
def replace_token(self, tokens):
new_tokens = []
for token in tokens:
if token in self.token_replacements:
new_tokens.extend(self.token_replacements[token])
else:
new_tokens.append(token)
return new_tokens
# region xformersとか使う部分独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self):
r"""
@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token)
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
@ -1826,12 +1845,12 @@ def main(args):
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
tokenizer = pipe.tokenizer
del pipe
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = loading_pipe.text_encoder
vae = loading_pipe.vae
unet = loading_pipe.unet
tokenizer = loading_pipe.tokenizer
del loading_pipe
# VAEを読み込む
if args.vae is not None:
@ -2039,6 +2058,44 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
token_ids_embeds = []
for embeds_file in args.textual_inversion_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
embeds = next(iter(data.values()))
if type(embeds) != torch.Tensor:
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
num_vectors_per_token = embeds.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings)
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
token_ids_embeds.append((token_ids, embeds))
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds:
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
# promptを取得する
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
@ -2157,8 +2214,8 @@ def main(args):
os.makedirs(args.outdir, exist_ok=True)
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for iter in range(args.n_iter):
print(f"iteration {iter+1}/{args.n_iter}")
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7fffffff)
# バッチ処理の関数
@ -2527,6 +2584,8 @@ if __name__ == '__main__':
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')

View File

@ -19,7 +19,7 @@ def UI(username, password):
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
interface = gr.Blocks(css=css, title="Kohya_ss GUI")
with interface:
with gr.Tab('Dreambooth'):

View File

@ -109,11 +109,11 @@ def gradio_extract_lora_tab():
)
with gr.Row():
dim = gr.Slider(
minimum=1,
maximum=128,
minimum=4,
maximum=1024,
label='Network Dimension',
value=8,
step=1,
value=128,
step=4,
interactive=True,
)
v2 = gr.Checkbox(label='v2', value=False, interactive=True)

126
library/git_caption_gui.py Normal file
View File

@ -0,0 +1,126 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_folder_path, add_pre_postfix
def caption_images(
train_data_dir,
caption_ext,
batch_size,
max_data_loader_n_workers,
max_length,
model_id,
prefix,
postfix,
):
# Check for images_dir_input
if train_data_dir == '':
msgbox('Image folder is missing...')
return
if caption_ext == '':
msgbox('Please provide an extension for the caption files.')
return
print(f'GIT captioning files in {train_data_dir}...')
run_cmd = f'.\\venv\\Scripts\\python.exe "finetune/make_captions.py"'
if not model_id == '':
run_cmd += f' --model_id="{model_id}"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
run_cmd += f' --max_length="{int(max_length)}"'
if caption_ext != '':
run_cmd += f' --caption_extension="{caption_ext}"'
run_cmd += f' "{train_data_dir}"'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
# Add prefix and postfix
add_pre_postfix(
folder=train_data_dir,
caption_file_ext=caption_ext,
prefix=prefix,
postfix=postfix,
)
print('...captioning done')
###
# Gradio UI
###
def gradio_git_caption_gui_tab():
with gr.Tab('GIT Captioning'):
gr.Markdown(
'This utility will use GIT to caption files for each images in a folder.'
)
with gr.Row():
train_data_dir = gr.Textbox(
label='Image folder to caption',
placeholder='Directory containing the images to caption',
interactive=True,
)
button_train_data_dir_input = gr.Button(
'📂', elem_id='open_folder_small'
)
button_train_data_dir_input.click(
get_folder_path, outputs=train_data_dir
)
with gr.Row():
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extention for caption file. eg: .caption, .txt',
value='.txt',
interactive=True,
)
prefix = gr.Textbox(
label='Prefix to add to BLIP caption',
placeholder='(Optional)',
interactive=True,
)
postfix = gr.Textbox(
label='Postfix to add to BLIP caption',
placeholder='(Optional)',
interactive=True,
)
batch_size = gr.Number(
value=1, label='Batch size', interactive=True
)
with gr.Row():
max_data_loader_n_workers = gr.Number(
value=2, label='Number of workers', interactive=True
)
max_length = gr.Number(
value=75, label='Max length', interactive=True
)
model_id = gr.Textbox(
label="Model",
placeholder="(Optional) model id for GIT in Hugging Face", interactive=True
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
train_data_dir,
caption_ext,
batch_size,
max_data_loader_n_workers,
max_length,
model_id,
prefix,
postfix,
],
)

View File

@ -45,6 +45,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
class ImageInfo():
@ -87,6 +88,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.enable_bucket = False
self.min_bucket_reso = None
self.max_bucket_reso = None
self.tag_frequency = {}
self.bucket_info = None
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@ -115,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
def disable_token_padding(self):
self.token_padding_disabled = True
@ -240,14 +252,15 @@ class BaseDataset(torch.utils.data.Dataset):
print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
# only show bucket info if there is an actual image in it
if len(img_keys) > 0:
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}")
# 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = []
for bucket_index, bucket in enumerate(self.buckets):
@ -545,6 +558,8 @@ class DreamBoothDataset(BaseDataset):
cap_for_img = read_caption(img_path)
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
return n_repeats, img_paths, captions
print("prepare train images.")
@ -553,10 +568,13 @@ class DreamBoothDataset(BaseDataset):
for dir in train_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
num_train_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
@ -570,9 +588,11 @@ class DreamBoothDataset(BaseDataset):
for dir in reg_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
num_reg_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.")
@ -617,6 +637,7 @@ class FineTuningDataset(BaseDataset):
self.train_data_dir = train_data_dir
self.batch_size = batch_size
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
if os.path.exists(image_key):
@ -633,6 +654,7 @@ class FineTuningDataset(BaseDataset):
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
@ -646,7 +668,8 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
# check existence of all npz files
if not self.color_aug:
@ -667,6 +690,8 @@ class FineTuningDataset(BaseDataset):
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
elif not npz_all:
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
if self.flip_aug:
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
@ -756,15 +781,30 @@ def debug_dataset(train_dataset, show_input_ids=False):
break
def glob_images(dir, base):
def glob_images(directory, base="*"):
img_paths = []
for ext in IMAGE_EXTENSIONS:
if base == '*':
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext)))
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext))))
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
# img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort()
return img_paths
def glob_images_pathlib(dir_path, recursive):
image_paths = []
if recursive:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.rglob('*' + ext))
else:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob('*' + ext))
# image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort()
return image_paths
# endregion
@ -1495,5 +1535,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# endregion
# region 前処理用
class ImageLoadingDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor_pil = transforms.functional.pil_to_tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor_pil, img_path)
# endregion

View File

@ -9,6 +9,7 @@ import argparse
from library.basic_caption_gui import gradio_basic_caption_gui_tab
from library.convert_model_gui import gradio_convert_model_tab
from library.blip_caption_gui import gradio_blip_caption_gui_tab
from library.git_caption_gui import gradio_git_caption_gui_tab
from library.wd14_caption_gui import gradio_wd14_caption_gui_tab
@ -23,6 +24,7 @@ def utilities_tab(
with gr.Tab('Captioning'):
gradio_basic_caption_gui_tab()
gradio_blip_caption_gui_tab()
gradio_git_caption_gui_tab()
gradio_wd14_caption_gui_tab()
gradio_convert_model_tab()

View File

@ -291,11 +291,11 @@ def train_model(
if unet_lr == '':
unet_lr = 0
if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
msgbox(
'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
)
return
# if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
# msgbox(
# 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
# )
# return
# Get a list of all subfolders in train_data_dir
subfolders = [
@ -383,15 +383,26 @@ def train_model(
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora'
if not float(text_encoder_lr) == 0:
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
run_cmd += f' --unet_lr={unet_lr}'
elif not (float(text_encoder_lr) == 0):
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
run_cmd += f' --network_train_text_encoder_only'
else:
run_cmd += f' --unet_lr={unet_lr}'
run_cmd += f' --network_train_unet_only'
else:
run_cmd += f' --network_train_unet_only'
if not float(unet_lr) == 0:
run_cmd += f' --unet_lr={unet_lr}'
else:
run_cmd += f' --network_train_text_encoder_only'
if float(text_encoder_lr) == 0:
msgbox(
'Please input learning rate values.'
)
return
run_cmd += f' --network_dim={network_dim}'
if not lora_network_weights == '':
run_cmd += f' --network_weights="{lora_network_weights}"'
if int(gradient_accumulation_steps) > 1:
@ -400,6 +411,8 @@ def train_model(
run_cmd += f' --output_name="{output_name}"'
if not lr_scheduler_num_cycles == '':
run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
else:
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
if not lr_scheduler_power == '':
run_cmd += f' --output_name="{lr_scheduler_power}"'
@ -612,19 +625,19 @@ def lora_tab(
placeholder='Optional',
)
network_dim = gr.Slider(
minimum=1,
maximum=128,
minimum=4,
maximum=1024,
label='Network Rank (Dimension)',
value=8,
step=1,
step=4,
interactive=True,
)
network_alpha = gr.Slider(
minimum=1,
maximum=128,
minimum=4,
maximum=1024,
label='Network Alpha',
value=1,
step=1,
step=4,
interactive=True,
)
with gr.Row():

View File

@ -1,5 +1,5 @@
accelerate==0.15.0
transformers==4.25.1
transformers==4.26.0
ftfy
albumentations
opencv-python
@ -9,7 +9,7 @@ pytorch_lightning
bitsandbytes==0.35.0
tensorboard
safetensors==0.2.6
gradio==3.15.0
gradio==3.16.2
altair
easygui
tk

View File

@ -0,0 +1,66 @@
import os
import cv2
import argparse
import shutil
import math
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Iterate through all files in src_img_folder
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp
if not filename.endswith(('.png', '.jpg', '.webp')):
# Copy the file to the destination folder if not png, jpg or webp
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height))
# Calculate the new height and width that are divisible by divisible_by
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]}")
def main():
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
parser.add_argument('--max_resolution', type=str, help='Maximum resolution in the format "512x512"', default="512x512")
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=2)
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
if __name__ == '__main__':
main()

View File

@ -1,3 +1,6 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from typing import Optional, Union
import importlib
import argparse
import gc
@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
from typing import Optional, Union
from torch.optim import Optimizer
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
def get_scheduler_fix(
name: Union[str, SchedulerType],
@ -52,53 +52,53 @@ def get_scheduler_fix(
num_cycles: int = 1,
power: float = 1.0,
):
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
)
if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
)
if name == SchedulerType.POLYNOMIAL:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
)
if name == SchedulerType.POLYNOMIAL:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def train(args):
@ -135,7 +135,7 @@ def train(args):
train_util.debug_dataset(train_dataset)
return
if len(train_dataset) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります")
return
# acceleratorを準備する
@ -335,6 +335,7 @@ def train(args):
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment # will not be updated after training
}