Update to latest kohya_ss sd-script code
This commit is contained in:
parent
c8f4c9d6e8
commit
20e62af1a6
@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。(bf1
|
|||||||
cd sd-scripts
|
cd sd-scripts
|
||||||
git pull
|
git pull
|
||||||
.\venv\Scripts\activate
|
.\venv\Scripts\activate
|
||||||
pip install --upgrade -r <requirement file name>
|
pip install --upgrade -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
コマンドが成功すれば新しいバージョンが使用できます。
|
コマンドが成功すれば新しいバージョンが使用できます。
|
||||||
|
17
README.md
17
README.md
@ -143,6 +143,23 @@ Then redo the installation instruction within the kohya_ss venv.
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 2023/02/03
|
||||||
|
- Increase max LoRA rank (dim) size to 1024.
|
||||||
|
- Update finetune preprocessing scripts.
|
||||||
|
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
|
||||||
|
- The default weights of ``tag_images_by_wd14_tagger.py`` are now ``SmilingWolf/wd-v1-4-convnext-tagger-v2``. You can specify another model id from ``SmilingWolf`` with the ``--repo_id`` option. Thanks to SmilingWolf for the great work.
|
||||||
|
- To change the weights, remove the ``wd14_tagger_model`` folder and run the script again.
|
||||||
|
- ``--max_data_loader_n_workers`` option is added to each script. This option loads the data with a DataLoader, which makes loading about 20%~30% faster.
|
||||||
|
- Please specify 2 or 4, depending on the number of CPU cores.
|
||||||
|
- ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``. It only works together with ``--full_path``.
|
||||||
|
- ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning.
|
||||||
|
- ``requirements.txt`` is updated. If you use this script, [please update the libraries](https://github.com/kohya-ss/sd-scripts#upgrade).
|
||||||
|
- Usage is almost the same as ``make_captions.py``, but batch size should be smaller.
|
||||||
|
- ``--remove_words`` option removes as much text as possible (such as ``the word "XXXX" on it``).
|
||||||
|
- ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images with existing npz files are ignored by this option.
|
||||||
|
- ``clean_captions_and_tags.py`` is updated to remove duplicated or conflicting tags, e.g. ``shirt`` is removed when ``white shirt`` exists. If ``black hair`` appears together with ``red hair``, both are removed.
|
||||||
|
- Tag frequency is added to the metadata in ``train_network.py``. Thanks to space-nuko!
|
||||||
|
- __All tags and the number of occurrences of each tag are recorded.__ If you do not want this, disable metadata storing with the ``--no_metadata`` option.
|
||||||
* 2023/01/30 (v20.5.2):
|
* 2023/01/30 (v20.5.2):
|
||||||
- Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
|
- Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
|
||||||
- Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
|
- Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
|
||||||
|
@ -5,13 +5,32 @@ import argparse
|
|||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
|
||||||
|
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
|
||||||
|
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
|
||||||
|
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
|
||||||
|
|
||||||
|
# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
|
||||||
|
PATTERNS_REMOVE_IN_MULTI = [
|
||||||
|
PATTERN_HAIR_LENGTH,
|
||||||
|
PATTERN_HAIR_CUT,
|
||||||
|
re.compile(r', [\w\-]+ eyes, '),
|
||||||
|
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
|
||||||
|
# 複数の髪型定義がある場合は削除する
|
||||||
|
re.compile(
|
||||||
|
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def clean_tags(image_key, tags):
|
def clean_tags(image_key, tags):
|
||||||
# replace '_' to ' '
|
# replace '_' to ' '
|
||||||
|
tags = tags.replace('^_^', '^@@@^')
|
||||||
tags = tags.replace('_', ' ')
|
tags = tags.replace('_', ' ')
|
||||||
|
tags = tags.replace('^@@@^', '^_^')
|
||||||
|
|
||||||
# remove rating: deepdanbooruのみ
|
# remove rating: deepdanbooruのみ
|
||||||
tokens = tags.split(", rating")
|
tokens = tags.split(", rating")
|
||||||
@ -26,6 +45,37 @@ def clean_tags(image_key, tags):
|
|||||||
print(f"{image_key} {tags}")
|
print(f"{image_key} {tags}")
|
||||||
tags = tokens[0]
|
tags = tokens[0]
|
||||||
|
|
||||||
|
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
|
||||||
|
|
||||||
|
# 複数の人物がいる場合は髪色等のタグを削除する
|
||||||
|
if 'girls' in tags or 'boys' in tags:
|
||||||
|
for pat in PATTERNS_REMOVE_IN_MULTI:
|
||||||
|
found = pat.findall(tags)
|
||||||
|
if len(found) > 1: # 二つ以上、タグがある
|
||||||
|
tags = pat.sub("", tags)
|
||||||
|
|
||||||
|
# 髪の特殊対応
|
||||||
|
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
|
||||||
|
if srch_hair_len:
|
||||||
|
org = srch_hair_len.group()
|
||||||
|
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
|
||||||
|
|
||||||
|
found = PATTERN_HAIR.findall(tags)
|
||||||
|
if len(found) > 1:
|
||||||
|
tags = PATTERN_HAIR.sub("", tags)
|
||||||
|
|
||||||
|
if srch_hair_len:
|
||||||
|
tags = tags.replace(", @@@, ", org) # 戻す
|
||||||
|
|
||||||
|
# white shirtとshirtみたいな重複タグの削除
|
||||||
|
found = PATTERN_WORD.findall(tags)
|
||||||
|
for word in found:
|
||||||
|
if re.search(f", ((\w+) )+{word}, ", tags):
|
||||||
|
tags = tags.replace(f", {word}, ", "")
|
||||||
|
|
||||||
|
tags = tags.replace(", , ", ", ")
|
||||||
|
assert tags.startswith(", ") and tags.endswith(", ")
|
||||||
|
tags = tags[2:-2]
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
@ -88,13 +138,23 @@ def main(args):
|
|||||||
if tags is None:
|
if tags is None:
|
||||||
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
||||||
else:
|
else:
|
||||||
metadata[image_key]['tags'] = clean_tags(image_key, tags)
|
org = tags
|
||||||
|
tags = clean_tags(image_key, tags)
|
||||||
|
metadata[image_key]['tags'] = tags
|
||||||
|
if args.debug and org != tags:
|
||||||
|
print("FROM: " + org)
|
||||||
|
print("TO: " + tags)
|
||||||
|
|
||||||
caption = metadata[image_key].get('caption')
|
caption = metadata[image_key].get('caption')
|
||||||
if caption is None:
|
if caption is None:
|
||||||
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
||||||
else:
|
else:
|
||||||
metadata[image_key]['caption'] = clean_caption(caption)
|
org = caption
|
||||||
|
caption = clean_caption(caption)
|
||||||
|
metadata[image_key]['caption'] = caption
|
||||||
|
if args.debug and org != caption:
|
||||||
|
print("FROM: " + org)
|
||||||
|
print("TO: " + caption)
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
print(f"writing metadata: {args.out_json}")
|
||||||
@ -108,6 +168,7 @@ if __name__ == '__main__':
|
|||||||
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||||
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
if len(unknown) == 1:
|
if len(unknown) == 1:
|
||||||
|
@ -11,11 +11,52 @@ import torch
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
from blip.blip import blip_decoder
|
from blip.blip import blip_decoder
|
||||||
# from Salesforce_BLIP.models.blip import blip_decoder
|
import library.train_util as train_util
|
||||||
|
|
||||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_SIZE = 384
|
||||||
|
|
||||||
|
# 正方形でいいのか? という気がするがソースがそうなので
|
||||||
|
IMAGE_TRANSFORM = transforms.Compose([
|
||||||
|
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||||
|
])
|
||||||
|
|
||||||
|
# 共通化したいが微妙に処理が異なる……
|
||||||
|
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
    """Dataset that opens image files and runs them through ``IMAGE_TRANSFORM``.

    Each item is a ``(transformed_image, image_path)`` tuple. When an image
    cannot be opened or transformed, the error is printed and the item is
    ``None`` so that a collate function can drop it.
    """

    def __init__(self, image_paths):
        # Sequence of image file paths, indexed by __getitem__.
        self.images = image_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = self.images[idx]
        try:
            # The image is transformed here (instead of returning the PIL
            # image) so that the DataLoader will accept the item.
            transformed = IMAGE_TRANSFORM(Image.open(path).convert("RGB"))
        except Exception as e:
            # Best effort: report the unreadable file and signal it with None.
            print(f"Could not load image path / 画像を読み込めません: {path}, error: {e}")
            return None
        return (transformed, path)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn_remove_corrupted(batch):
    """Collate function that drops corrupted examples from a batch.

    The dataset is expected to return ``None`` for an example it could not
    load. Those ``None`` entries are removed; every other entry is kept in
    its original order (falsy values other than ``None`` are kept too).

    Args:
        batch: list of samples produced by the dataset for one batch.

    Returns:
        A new list containing the samples of ``batch`` that are not ``None``.
    """
    return [sample for sample in batch if sample is not None]
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
# fix the seed for reproducibility
|
# fix the seed for reproducibility
|
||||||
seed = args.seed # + utils.get_rank()
|
seed = args.seed # + utils.get_rank()
|
||||||
@ -31,24 +72,15 @@ def main(args):
|
|||||||
os.chdir('finetune')
|
os.chdir('finetune')
|
||||||
|
|
||||||
print(f"load images from {args.train_data_dir}")
|
print(f"load images from {args.train_data_dir}")
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
image_paths = train_util.glob_images(args.train_data_dir)
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
print(f"loading BLIP caption: {args.caption_weights}")
|
print(f"loading BLIP caption: {args.caption_weights}")
|
||||||
image_size = 384
|
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
|
||||||
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model = model.to(DEVICE)
|
model = model.to(DEVICE)
|
||||||
print("BLIP loaded")
|
print("BLIP loaded")
|
||||||
|
|
||||||
# 正方形でいいのか? という気がするがソースがそうなので
|
|
||||||
transform = transforms.Compose([
|
|
||||||
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
|
||||||
])
|
|
||||||
|
|
||||||
# captioningする
|
# captioningする
|
||||||
def run_batch(path_imgs):
|
def run_batch(path_imgs):
|
||||||
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
||||||
@ -66,15 +98,32 @@ def main(args):
|
|||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_path, caption)
|
print(image_path, caption)
|
||||||
|
|
||||||
b_imgs = []
|
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||||
for image_path in tqdm(image_paths, smoothing=0.0):
|
if args.max_data_loader_n_workers is not None:
|
||||||
raw_image = Image.open(image_path)
|
dataset = ImageLoadingTransformDataset(image_paths)
|
||||||
if raw_image.mode != "RGB":
|
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||||
print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
|
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||||
raw_image = raw_image.convert("RGB")
|
else:
|
||||||
|
data = [[(None, ip)] for ip in image_paths]
|
||||||
|
|
||||||
image = transform(raw_image)
|
b_imgs = []
|
||||||
b_imgs.append((image_path, image))
|
for data_entry in tqdm(data, smoothing=0.0):
|
||||||
|
for data in data_entry:
|
||||||
|
if data is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_tensor, image_path = data
|
||||||
|
if img_tensor is None:
|
||||||
|
try:
|
||||||
|
raw_image = Image.open(image_path)
|
||||||
|
if raw_image.mode != 'RGB':
|
||||||
|
raw_image = raw_image.convert("RGB")
|
||||||
|
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
b_imgs.append((image_path, img_tensor))
|
||||||
if len(b_imgs) >= args.batch_size:
|
if len(b_imgs) >= args.batch_size:
|
||||||
run_batch(b_imgs)
|
run_batch(b_imgs)
|
||||||
b_imgs.clear()
|
b_imgs.clear()
|
||||||
@ -95,6 +144,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--beam_search", action="store_true",
|
parser.add_argument("--beam_search", action="store_true",
|
||||||
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||||
|
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||||
|
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||||
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||||
|
145
finetune/make_captions_by_git.py
Normal file
145
finetune/make_captions_by_git.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||||
|
from transformers.generation.utils import GenerationMixin
|
||||||
|
|
||||||
|
import library.train_util as train_util
|
||||||
|
|
||||||
|
|
||||||
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
# Phrases describing text in the picture ("with the word XXXX on it",
# "that says ...") are false detections most of the time, so remove_words()
# deletes every match of the patterns below. The patterns are applied in order.
PATTERN_REPLACE = [
    re.compile(source)
    for source in (
        r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?',
        r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?',
        r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?",
        r'with the number \d+ on (it|\w+ \w+)',
        r'with the words "',
        r'word \w+ on it',
        r'that says the word \w+ on it',
        'that says\'the word "( on it)?',
    )
]


def remove_words(captions, debug):
    """Delete every ``PATTERN_REPLACE`` match from each caption.

    Args:
        captions: iterable of caption strings.
        debug: when true, print the caption before and after for every
            caption that was changed.

    Returns:
        A new list with the processed captions, in the same order. Text
        around a removed phrase is left untouched (including spaces).
    """
    def scrub(text):
        # Apply the patterns one after another, each on the previous result.
        for pattern in PATTERN_REPLACE:
            text = pattern.sub("", text)
        return text

    cleaned = []
    for caption in captions:
        scrubbed = scrub(caption)
        if debug and scrubbed != caption:
            print(caption)
            print(scrubbed)
        cleaned.append(scrubbed)
    return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn_remove_corrupted(batch):
    """Collate function that drops corrupted examples from a batch.

    The dataset is expected to return ``None`` for an example it could not
    load. Those ``None`` entries are removed; every other entry is kept in
    its original order (falsy values other than ``None`` are kept too).

    Args:
        batch: list of samples produced by the dataset for one batch.

    Returns:
        A new list containing the samples of ``batch`` that are not ``None``.
    """
    return [sample for sample in batch if sample is not None]
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Caption every image found in ``args.train_data_dir`` with a GIT model.

    For each image a caption file is written next to it, using the image
    path with its extension replaced by ``args.caption_extension``.

    Reads from ``args``: ``train_data_dir``, ``model_id``, ``batch_size``,
    ``max_length``, ``max_data_loader_n_workers``, ``remove_words``,
    ``caption_extension`` and ``debug``.

    Side effect: ``GenerationMixin._prepare_input_ids_for_generation`` is
    replaced by a patched version and is not restored afterwards.
    """
    # Patch GenerationMixin so that GIT works with a batch size greater than 1
    # (written for transformers 4.26.0).
    original_prepare = GenerationMixin._prepare_input_ids_for_generation
    # Kept in a one-element list so run_batch can swap the value: the last
    # batch of the loop may hold fewer than batch_size images.
    active_batch_size = [args.batch_size]

    # input_ids must have as many rows as the batch. The batch size is not
    # visible from inside this method, so it is passed in from outside.
    # Patching at any higher level would be much harder.
    def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
        input_ids = original_prepare(self, bos_token_id, encoder_outputs)
        if input_ids.size()[0] == active_batch_size[0]:
            return input_ids
        return input_ids.repeat(active_batch_size[0], 1)

    GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch

    print(f"load images from {args.train_data_dir}")
    image_paths = train_util.glob_images(args.train_data_dir)
    print(f"found {len(image_paths)} images.")

    # Ideally the model would be downloaded explicitly instead of relying on the cache.
    print(f"loading GIT: {args.model_id}")
    git_processor = AutoProcessor.from_pretrained(args.model_id)
    git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
    print("GIT loaded")

    # Caption one batch of (image_path, image) pairs and write the caption files.
    def run_batch(path_imgs):
        active_batch_size[0] = len(path_imgs)
        pil_images = [im for _, im in path_imgs]

        inputs = git_processor(images=pil_images, return_tensors="pt").to(DEVICE)  # images are PIL images
        generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
        captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)

        if args.remove_words:
            captions = remove_words(captions, args.debug)

        for (image_path, _), caption in zip(path_imgs, captions):
            caption_path = os.path.splitext(image_path)[0] + args.caption_extension
            with open(caption_path, "wt", encoding='utf-8') as f:
                f.write(caption + "\n")
                if args.debug:
                    print(image_path, caption)

    # Optionally read the images through a DataLoader to speed up loading.
    if args.max_data_loader_n_workers is None:
        # No DataLoader: one-element "batches" whose image is loaded in the loop below.
        batches = [[(None, ip)] for ip in image_paths]
    else:
        # NOTE(review): presumably ImageLoadingDataset yields (image, path) pairs
        # or None for unreadable files -- confirm in library.train_util.
        dataset = train_util.ImageLoadingDataset(image_paths)
        batches = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.max_data_loader_n_workers,
                                              collate_fn=collate_fn_remove_corrupted, drop_last=False)

    pending = []
    for batch in tqdm(batches, smoothing=0.0):
        for sample in batch:
            if sample is None:
                continue

            image, image_path = sample
            if image is None:
                try:
                    image = Image.open(image_path)
                    if image.mode != 'RGB':
                        image = image.convert("RGB")
                except Exception as e:
                    # Best effort: report the unreadable file and go on with the next one.
                    print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                    continue

            pending.append((image_path, image))
            if len(pending) >= args.batch_size:
                run_batch(pending)
                pending.clear()

    # Flush the images left over from the last, incomplete batch.
    if pending:
        run_batch(pending)

    print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir",
        type=str,
        help="directory for train images / 学習画像データのディレクトリ",
    )
    parser.add_argument(
        "--caption_extension",
        type=str,
        default=".caption",
        help="extension of caption file / 出力されるキャプションファイルの拡張子",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="microsoft/git-large-textcaps",
        help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size in inference / 推論時のバッチサイズ",
    )
    parser.add_argument(
        "--max_data_loader_n_workers",
        type=int,
        default=None,
        help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=50,
        help="max length of caption / captionの最大長",
    )
    parser.add_argument(
        "--remove_words",
        action="store_true",
        help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="debug mode",
    )

    args = parser.parse_args()
    main(args)
|
@ -1,26 +1,24 @@
|
|||||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
|
||||||
# (c) 2022 Kohya S. @kohya_ss
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import library.train_util as train_util
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
|
||||||
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
|
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if args.in_json is None and os.path.isfile(args.out_json):
|
if args.in_json is None and Path(args.out_json).is_file():
|
||||||
args.in_json = args.out_json
|
args.in_json = args.out_json
|
||||||
|
|
||||||
if args.in_json is not None:
|
if args.in_json is not None:
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
print(f"loading existing metadata: {args.in_json}")
|
||||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
||||||
metadata = json.load(f)
|
|
||||||
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
||||||
else:
|
else:
|
||||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||||
@ -28,11 +26,10 @@ def main(args):
|
|||||||
|
|
||||||
print("merge caption texts to metadata json.")
|
print("merge caption texts to metadata json.")
|
||||||
for image_path in tqdm(image_paths):
|
for image_path in tqdm(image_paths):
|
||||||
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
|
caption_path = image_path.with_suffix(args.caption_extension)
|
||||||
with open(caption_path, "rt", encoding='utf-8') as f:
|
caption = caption_path.read_text(encoding='utf-8').strip()
|
||||||
caption = f.readlines()[0].strip()
|
|
||||||
|
|
||||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
image_key = str(image_path) if args.full_path else image_path.stem
|
||||||
if image_key not in metadata:
|
if image_key not in metadata:
|
||||||
metadata[image_key] = {}
|
metadata[image_key] = {}
|
||||||
|
|
||||||
@ -42,8 +39,7 @@ def main(args):
|
|||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
print(f"writing metadata: {args.out_json}")
|
||||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
||||||
json.dump(metadata, f, indent=2)
|
|
||||||
print("done!")
|
print("done!")
|
||||||
|
|
||||||
|
|
||||||
@ -51,12 +47,15 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||||
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
parser.add_argument("--in_json", type=str,
|
||||||
|
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
||||||
parser.add_argument("--caption_extention", type=str, default=None,
|
parser.add_argument("--caption_extention", type=str, default=None,
|
||||||
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
|
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
|
||||||
parser.add_argument("--full_path", action="store_true",
|
parser.add_argument("--full_path", action="store_true",
|
||||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
||||||
|
parser.add_argument("--recursive", action="store_true",
|
||||||
|
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -1,26 +1,24 @@
|
|||||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
|
||||||
# (c) 2022 Kohya S. @kohya_ss
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import library.train_util as train_util
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
|
||||||
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
|
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if args.in_json is None and os.path.isfile(args.out_json):
|
if args.in_json is None and Path(args.out_json).is_file():
|
||||||
args.in_json = args.out_json
|
args.in_json = args.out_json
|
||||||
|
|
||||||
if args.in_json is not None:
|
if args.in_json is not None:
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
print(f"loading existing metadata: {args.in_json}")
|
||||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
||||||
metadata = json.load(f)
|
|
||||||
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
||||||
else:
|
else:
|
||||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||||
@ -28,11 +26,10 @@ def main(args):
|
|||||||
|
|
||||||
print("merge tags to metadata json.")
|
print("merge tags to metadata json.")
|
||||||
for image_path in tqdm(image_paths):
|
for image_path in tqdm(image_paths):
|
||||||
tags_path = os.path.splitext(image_path)[0] + '.txt'
|
tags_path = image_path.with_suffix(args.caption_extension)
|
||||||
with open(tags_path, "rt", encoding='utf-8') as f:
|
tags = tags_path.read_text(encoding='utf-8').strip()
|
||||||
tags = f.readlines()[0].strip()
|
|
||||||
|
|
||||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
image_key = str(image_path) if args.full_path else image_path.stem
|
||||||
if image_key not in metadata:
|
if image_key not in metadata:
|
||||||
metadata[image_key] = {}
|
metadata[image_key] = {}
|
||||||
|
|
||||||
@ -42,8 +39,8 @@ def main(args):
|
|||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
print(f"writing metadata: {args.out_json}")
|
||||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
||||||
json.dump(metadata, f, indent=2)
|
|
||||||
print("done!")
|
print("done!")
|
||||||
|
|
||||||
|
|
||||||
@ -51,9 +48,14 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||||
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
parser.add_argument("--in_json", type=str,
|
||||||
|
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
||||||
parser.add_argument("--full_path", action="store_true",
|
parser.add_argument("--full_path", action="store_true",
|
||||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
||||||
|
parser.add_argument("--recursive", action="store_true",
|
||||||
|
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
||||||
|
parser.add_argument("--caption_extension", type=str, default=".txt",
|
||||||
|
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -1,20 +1,16 @@
|
|||||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
|
||||||
# (c) 2022 Kohya S. @kohya_ss
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import AutoencoderKL
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
import library.train_util as train_util
|
||||||
|
|
||||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn_remove_corrupted(batch):
|
||||||
|
"""Collate function that allows to remove corrupted examples in the
|
||||||
|
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||||
|
The 'None's in the batch are removed.
|
||||||
|
"""
|
||||||
|
# Filter out all the Nones (corrupted examples)
|
||||||
|
batch = list(filter(lambda x: x is not None, batch))
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def get_latents(vae, images, weight_dtype):
|
def get_latents(vae, images, weight_dtype):
|
||||||
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
||||||
img_tensors = torch.stack(img_tensors)
|
img_tensors = torch.stack(img_tensors)
|
||||||
@ -35,9 +41,18 @@ def get_latents(vae, images, weight_dtype):
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
||||||
|
if is_full_path:
|
||||||
|
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||||
|
else:
|
||||||
|
base_name = image_key
|
||||||
|
if flip:
|
||||||
|
base_name += '_flip'
|
||||||
|
return os.path.join(data_dir, base_name)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
image_paths = train_util.glob_images(args.train_data_dir)
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if os.path.exists(args.in_json):
|
if os.path.exists(args.in_json):
|
||||||
@ -70,15 +85,56 @@ def main(args):
|
|||||||
buckets_imgs = [[] for _ in range(len(bucket_resos))]
|
buckets_imgs = [[] for _ in range(len(bucket_resos))]
|
||||||
bucket_counts = [0 for _ in range(len(bucket_resos))]
|
bucket_counts = [0 for _ in range(len(bucket_resos))]
|
||||||
img_ar_errors = []
|
img_ar_errors = []
|
||||||
for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
|
|
||||||
|
def process_batch(is_last):
|
||||||
|
for j in range(len(buckets_imgs)):
|
||||||
|
bucket = buckets_imgs[j]
|
||||||
|
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||||
|
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
|
||||||
|
|
||||||
|
for (image_key, _, _), latent in zip(bucket, latents):
|
||||||
|
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
||||||
|
np.savez(npz_file_name, latent)
|
||||||
|
|
||||||
|
# flip
|
||||||
|
if args.flip_aug:
|
||||||
|
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||||
|
|
||||||
|
for (image_key, _, _), latent in zip(bucket, latents):
|
||||||
|
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
||||||
|
np.savez(npz_file_name, latent)
|
||||||
|
|
||||||
|
bucket.clear()
|
||||||
|
|
||||||
|
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||||
|
if args.max_data_loader_n_workers is not None:
|
||||||
|
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||||
|
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
|
||||||
|
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||||
|
else:
|
||||||
|
data = [[(None, ip)] for ip in image_paths]
|
||||||
|
|
||||||
|
for data_entry in tqdm(data, smoothing=0.0):
|
||||||
|
if data_entry[0] is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_tensor, image_path = data_entry[0]
|
||||||
|
if img_tensor is not None:
|
||||||
|
image = transforms.functional.to_pil_image(img_tensor)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
if image.mode != 'RGB':
|
||||||
|
image = image.convert("RGB")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||||
if image_key not in metadata:
|
if image_key not in metadata:
|
||||||
metadata[image_key] = {}
|
metadata[image_key] = {}
|
||||||
|
|
||||||
image = Image.open(image_path)
|
# 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||||
if image.mode != 'RGB':
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
aspect_ratio = image.width / image.height
|
aspect_ratio = image.width / image.height
|
||||||
ar_errors = bucket_aspect_ratios - aspect_ratio
|
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||||
bucket_id = np.abs(ar_errors).argmin()
|
bucket_id = np.abs(ar_errors).argmin()
|
||||||
@ -102,6 +158,25 @@ def main(args):
|
|||||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||||
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||||
|
|
||||||
|
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||||
|
if args.skip_existing:
|
||||||
|
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
||||||
|
if args.flip_aug:
|
||||||
|
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
|
||||||
|
|
||||||
|
found = True
|
||||||
|
for npz_file in npz_files:
|
||||||
|
if not os.path.exists(npz_file):
|
||||||
|
found = False
|
||||||
|
break
|
||||||
|
|
||||||
|
dat = np.load(npz_file)['arr_0']
|
||||||
|
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
|
||||||
|
found = False
|
||||||
|
break
|
||||||
|
if found:
|
||||||
|
continue
|
||||||
|
|
||||||
# 画像をリサイズしてトリミングする
|
# 画像をリサイズしてトリミングする
|
||||||
# PILにinter_areaがないのでcv2で……
|
# PILにinter_areaがないのでcv2で……
|
||||||
image = np.array(image)
|
image = np.array(image)
|
||||||
@ -123,25 +198,10 @@ def main(args):
|
|||||||
metadata[image_key]['train_resolution'] = reso
|
metadata[image_key]['train_resolution'] = reso
|
||||||
|
|
||||||
# バッチを推論するか判定して推論する
|
# バッチを推論するか判定して推論する
|
||||||
is_last = i == len(image_paths) - 1
|
process_batch(False)
|
||||||
for j in range(len(buckets_imgs)):
|
|
||||||
bucket = buckets_imgs[j]
|
|
||||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
|
||||||
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
|
|
||||||
|
|
||||||
for (image_key, reso, _), latent in zip(bucket, latents):
|
# 残りを処理する
|
||||||
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
|
process_batch(True)
|
||||||
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
|
|
||||||
|
|
||||||
# flip
|
|
||||||
if args.flip_aug:
|
|
||||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
|
||||||
|
|
||||||
for (image_key, reso, _), latent in zip(bucket, latents):
|
|
||||||
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
|
|
||||||
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
|
|
||||||
|
|
||||||
bucket.clear()
|
|
||||||
|
|
||||||
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
|
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
|
||||||
print(f"bucket {i} {reso}: {count}")
|
print(f"bucket {i} {reso}: {count}")
|
||||||
@ -162,8 +222,10 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||||
parser.add_argument("--v2", action='store_true',
|
parser.add_argument("--v2", action='store_true',
|
||||||
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||||
|
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||||
|
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||||
parser.add_argument("--max_resolution", type=str, default="512,512",
|
parser.add_argument("--max_resolution", type=str, default="512,512",
|
||||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
||||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||||
@ -174,6 +236,8 @@ if __name__ == '__main__':
|
|||||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
||||||
parser.add_argument("--flip_aug", action="store_true",
|
parser.add_argument("--flip_aug", action="store_true",
|
||||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
||||||
|
parser.add_argument("--skip_existing", action="store_true",
|
||||||
|
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -1,180 +0,0 @@
|
|||||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
|
||||||
# (c) 2022 Kohya S. @kohya_ss
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
from diffusers import AutoencoderKL
|
|
||||||
from PIL import Image
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
import library.model_util as model_util
|
|
||||||
|
|
||||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
|
|
||||||
IMAGE_TRANSFORMS = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.5], [0.5]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_latents(vae, images, weight_dtype):
|
|
||||||
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
|
||||||
img_tensors = torch.stack(img_tensors)
|
|
||||||
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
|
||||||
with torch.no_grad():
|
|
||||||
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
|
|
||||||
return latents
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
# image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
|
||||||
# glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
|
||||||
# print(f"found {len(image_paths)} images.")
|
|
||||||
|
|
||||||
|
|
||||||
if os.path.exists(args.in_json):
|
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
|
||||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
|
||||||
metadata = json.load(f)
|
|
||||||
else:
|
|
||||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
|
||||||
return
|
|
||||||
|
|
||||||
weight_dtype = torch.float32
|
|
||||||
if args.mixed_precision == "fp16":
|
|
||||||
weight_dtype = torch.float16
|
|
||||||
elif args.mixed_precision == "bf16":
|
|
||||||
weight_dtype = torch.bfloat16
|
|
||||||
|
|
||||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
|
||||||
vae.eval()
|
|
||||||
vae.to(DEVICE, dtype=weight_dtype)
|
|
||||||
|
|
||||||
# bucketのサイズを計算する
|
|
||||||
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
|
||||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
|
||||||
|
|
||||||
bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
|
||||||
max_reso, args.min_bucket_reso, args.max_bucket_reso)
|
|
||||||
|
|
||||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
|
||||||
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
|
|
||||||
buckets_imgs = [[] for _ in range(len(bucket_resos))]
|
|
||||||
bucket_counts = [0 for _ in range(len(bucket_resos))]
|
|
||||||
img_ar_errors = []
|
|
||||||
for i, image_path in enumerate(tqdm(metadata, smoothing=0.0)):
|
|
||||||
image_key = image_path
|
|
||||||
if image_key not in metadata:
|
|
||||||
metadata[image_key] = {}
|
|
||||||
|
|
||||||
image = Image.open(image_path)
|
|
||||||
if image.mode != 'RGB':
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
aspect_ratio = image.width / image.height
|
|
||||||
ar_errors = bucket_aspect_ratios - aspect_ratio
|
|
||||||
bucket_id = np.abs(ar_errors).argmin()
|
|
||||||
reso = bucket_resos[bucket_id]
|
|
||||||
ar_error = ar_errors[bucket_id]
|
|
||||||
img_ar_errors.append(abs(ar_error))
|
|
||||||
|
|
||||||
# どのサイズにリサイズするか→トリミングする方向で
|
|
||||||
if ar_error <= 0: # 横が長い→縦を合わせる
|
|
||||||
scale = reso[1] / image.height
|
|
||||||
else:
|
|
||||||
scale = reso[0] / image.width
|
|
||||||
|
|
||||||
resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
|
|
||||||
|
|
||||||
# print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
|
|
||||||
# bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
|
|
||||||
|
|
||||||
assert resized_size[0] == reso[0] or resized_size[1] == reso[
|
|
||||||
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
|
||||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
|
||||||
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
|
||||||
|
|
||||||
# 画像をリサイズしてトリミングする
|
|
||||||
# PILにinter_areaがないのでcv2で……
|
|
||||||
image = np.array(image)
|
|
||||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
|
||||||
if resized_size[0] > reso[0]:
|
|
||||||
trim_size = resized_size[0] - reso[0]
|
|
||||||
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
|
||||||
elif resized_size[1] > reso[1]:
|
|
||||||
trim_size = resized_size[1] - reso[1]
|
|
||||||
image = image[trim_size//2:trim_size//2 + reso[1]]
|
|
||||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
|
||||||
|
|
||||||
# # debug
|
|
||||||
# cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
|
|
||||||
|
|
||||||
# バッチへ追加
|
|
||||||
buckets_imgs[bucket_id].append((image_key, reso, image))
|
|
||||||
bucket_counts[bucket_id] += 1
|
|
||||||
metadata[image_key]['train_resolution'] = reso
|
|
||||||
|
|
||||||
# バッチを推論するか判定して推論する
|
|
||||||
is_last = i == len(metadata) - 1
|
|
||||||
for j in range(len(buckets_imgs)):
|
|
||||||
bucket = buckets_imgs[j]
|
|
||||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
|
||||||
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
|
|
||||||
|
|
||||||
for (image_key, reso, _), latent in zip(bucket, latents):
|
|
||||||
npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
|
|
||||||
np.savez(os.path.join(os.path.dirname(image_key), npz_file_name), latent)
|
|
||||||
|
|
||||||
# flip
|
|
||||||
if args.flip_aug:
|
|
||||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
|
||||||
|
|
||||||
for (image_key, reso, _), latent in zip(bucket, latents):
|
|
||||||
npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
|
|
||||||
np.savez(os.path.join(os.path.dirname(image_key), npz_file_name + '_flip'), latent)
|
|
||||||
|
|
||||||
bucket.clear()
|
|
||||||
|
|
||||||
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
|
|
||||||
print(f"bucket {i} {reso}: {count}")
|
|
||||||
img_ar_errors = np.array(img_ar_errors)
|
|
||||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
|
||||||
print(f"writing metadata: {args.out_json}")
|
|
||||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
|
||||||
json.dump(metadata, f, indent=2)
|
|
||||||
print("done!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
|
||||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
|
||||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
|
||||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
|
||||||
parser.add_argument("--v2", action='store_true',
|
|
||||||
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
|
||||||
parser.add_argument("--max_resolution", type=str, default="512,512",
|
|
||||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
|
||||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
|
||||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
|
||||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
|
||||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
|
||||||
parser.add_argument("--full_path", action="store_true",
|
|
||||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
|
||||||
parser.add_argument("--flip_aug", action="store_true",
|
|
||||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
|
@ -1,6 +1,3 @@
|
|||||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
|
||||||
# (c) 2022 Kohya S. @kohya_ss
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
import glob
|
import glob
|
||||||
@ -12,32 +9,87 @@ from tqdm import tqdm
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import library.train_util as train_util
|
||||||
|
|
||||||
# from wd14 tagger
|
# from wd14 tagger
|
||||||
IMAGE_SIZE = 448
|
IMAGE_SIZE = 448
|
||||||
|
|
||||||
WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger'
|
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
||||||
|
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
|
||||||
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
||||||
SUB_DIR = "variables"
|
SUB_DIR = "variables"
|
||||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||||
CSV_FILE = FILES[-1]
|
CSV_FILE = FILES[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(image):
|
||||||
|
image = np.array(image)
|
||||||
|
image = image[:, :, ::-1] # RGB->BGR
|
||||||
|
|
||||||
|
# pad to square
|
||||||
|
size = max(image.shape[0:2])
|
||||||
|
pad_x = size - image.shape[1]
|
||||||
|
pad_y = size - image.shape[0]
|
||||||
|
pad_l = pad_x // 2
|
||||||
|
pad_t = pad_y // 2
|
||||||
|
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
|
||||||
|
|
||||||
|
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
||||||
|
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
||||||
|
|
||||||
|
image = image.astype(np.float32)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, image_paths):
|
||||||
|
self.images = image_paths
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.images)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
img_path = self.images[idx]
|
||||||
|
|
||||||
|
try:
|
||||||
|
image = Image.open(img_path).convert("RGB")
|
||||||
|
image = preprocess_image(image)
|
||||||
|
tensor = torch.tensor(image)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (tensor, img_path)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn_remove_corrupted(batch):
|
||||||
|
"""Collate function that allows to remove corrupted examples in the
|
||||||
|
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||||
|
The 'None's in the batch are removed.
|
||||||
|
"""
|
||||||
|
# Filter out all the Nones (corrupted examples)
|
||||||
|
batch = list(filter(lambda x: x is not None, batch))
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||||
# depreacatedの警告が出るけどなくなったらその時
|
# depreacatedの警告が出るけどなくなったらその時
|
||||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||||
if not os.path.exists(args.model_dir) or args.force_download:
|
if not os.path.exists(args.model_dir) or args.force_download:
|
||||||
print("downloading wd14 tagger model from hf_hub")
|
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||||
for file in FILES:
|
for file in FILES:
|
||||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||||
for file in SUB_DIR_FILES:
|
for file in SUB_DIR_FILES:
|
||||||
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
|
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
|
||||||
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
|
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
|
||||||
|
else:
|
||||||
|
print("using existing wd14 tagger model")
|
||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
image_paths = train_util.glob_images(args.train_data_dir)
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
print("loading model and labels")
|
print("loading model and labels")
|
||||||
@ -72,7 +124,7 @@ def main(args):
|
|||||||
# Everything else is tags: pick any where prediction confidence > threshold
|
# Everything else is tags: pick any where prediction confidence > threshold
|
||||||
tag_text = ""
|
tag_text = ""
|
||||||
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
||||||
if p >= args.thresh:
|
if p >= args.thresh and i < len(tags):
|
||||||
tag_text += ", " + tags[i]
|
tag_text += ", " + tags[i]
|
||||||
|
|
||||||
if len(tag_text) > 0:
|
if len(tag_text) > 0:
|
||||||
@ -83,30 +135,33 @@ def main(args):
|
|||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_path, tag_text)
|
print(image_path, tag_text)
|
||||||
|
|
||||||
|
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||||
|
if args.max_data_loader_n_workers is not None:
|
||||||
|
dataset = ImageLoadingPrepDataset(image_paths)
|
||||||
|
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||||
|
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||||
|
else:
|
||||||
|
data = [[(None, ip)] for ip in image_paths]
|
||||||
|
|
||||||
b_imgs = []
|
b_imgs = []
|
||||||
for image_path in tqdm(image_paths, smoothing=0.0):
|
for data_entry in tqdm(data, smoothing=0.0):
|
||||||
img = Image.open(image_path) # cv2は日本語ファイル名で死ぬのとモード変換したいのでpillowで開く
|
for data in data_entry:
|
||||||
if img.mode != 'RGB':
|
if data is None:
|
||||||
img = img.convert("RGB")
|
continue
|
||||||
img = np.array(img)
|
|
||||||
img = img[:, :, ::-1] # RGB->BGR
|
|
||||||
|
|
||||||
# pad to square
|
image, image_path = data
|
||||||
size = max(img.shape[0:2])
|
if image is not None:
|
||||||
pad_x = size - img.shape[1]
|
image = image.detach().numpy()
|
||||||
pad_y = size - img.shape[0]
|
else:
|
||||||
pad_l = pad_x // 2
|
try:
|
||||||
pad_t = pad_y // 2
|
image = Image.open(image_path)
|
||||||
img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
|
if image.mode != 'RGB':
|
||||||
|
image = image.convert("RGB")
|
||||||
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
image = preprocess_image(image)
|
||||||
img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
except Exception as e:
|
||||||
# cv2.imshow("img", img)
|
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||||
# cv2.waitKey()
|
continue
|
||||||
# cv2.destroyAllWindows()
|
b_imgs.append((image_path, image))
|
||||||
|
|
||||||
img = img.astype(np.float32)
|
|
||||||
b_imgs.append((image_path, img))
|
|
||||||
|
|
||||||
if len(b_imgs) >= args.batch_size:
|
if len(b_imgs) >= args.batch_size:
|
||||||
run_batch(b_imgs)
|
run_batch(b_imgs)
|
||||||
@ -121,7 +176,7 @@ def main(args):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO,
|
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
|
||||||
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
|
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
|
||||||
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
|
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
|
||||||
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
|
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
|
||||||
@ -129,6 +184,8 @@ if __name__ == '__main__':
|
|||||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
|
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
|
||||||
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||||
|
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||||
|
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||||
parser.add_argument("--caption_extention", type=str, default=None,
|
parser.add_argument("--caption_extention", type=str, default=None,
|
||||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||||
|
@ -292,7 +292,7 @@ def train_model(
|
|||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
image_num = len(
|
image_num = len(
|
||||||
[f for f in os.listdir(image_folder) if f.endswith('.npz')]
|
[f for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')]
|
||||||
)
|
)
|
||||||
print(f'image_num = {image_num}')
|
print(f'image_num = {image_num}')
|
||||||
|
|
||||||
|
@ -470,6 +470,9 @@ class PipelineLike():
|
|||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
|
|
||||||
|
# Textual Inversion
|
||||||
|
self.token_replacements = {}
|
||||||
|
|
||||||
# CLIP guidance
|
# CLIP guidance
|
||||||
self.clip_guidance_scale = clip_guidance_scale
|
self.clip_guidance_scale = clip_guidance_scale
|
||||||
self.clip_image_guidance_scale = clip_image_guidance_scale
|
self.clip_image_guidance_scale = clip_image_guidance_scale
|
||||||
@ -484,6 +487,19 @@ class PipelineLike():
|
|||||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||||
|
|
||||||
|
# Textual Inversion
|
||||||
|
def add_token_replacement(self, target_token_id, rep_token_ids):
    """Register an expansion for one token id.

    After this call, ``replace_token`` substitutes every occurrence of
    ``target_token_id`` with the ids in ``rep_token_ids``. Registering the
    same ``target_token_id`` again overwrites the previous expansion.

    Args:
        target_token_id: token id to look for in a tokenized prompt.
        rep_token_ids: iterable of token ids emitted in its place.
    """
    self.token_replacements[target_token_id] = rep_token_ids
|
||||||
|
|
||||||
|
def replace_token(self, tokens):
    """Expand registered token ids in a tokenized prompt.

    Each id that has an entry in ``self.token_replacements`` is replaced by
    the ids stored for it; every other id is passed through unchanged.
    Order is preserved and the input sequence is not modified.

    Args:
        tokens: iterable of token ids.

    Returns:
        A new list of token ids with all replacements applied.
    """
    expanded = []
    for token_id in tokens:
        # An unregistered id "expands" to itself.
        expanded.extend(self.token_replacements.get(token_id, (token_id,)))
    return expanded
|
||||||
|
|
||||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||||
def enable_xformers_memory_efficient_attention(self):
|
def enable_xformers_memory_efficient_attention(self):
|
||||||
r"""
|
r"""
|
||||||
@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
|||||||
for word, weight in texts_and_weights:
|
for word, weight in texts_and_weights:
|
||||||
# tokenize and discard the starting and the ending token
|
# tokenize and discard the starting and the ending token
|
||||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||||
|
|
||||||
|
token = pipe.replace_token(token)
|
||||||
|
|
||||||
text_token += token
|
text_token += token
|
||||||
# copy the weight by length of token
|
# copy the weight by length of token
|
||||||
text_weight += [weight] * len(token)
|
text_weight += [weight] * len(token)
|
||||||
@ -1826,12 +1845,12 @@ def main(args):
|
|||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
||||||
else:
|
else:
|
||||||
print("load Diffusers pretrained models")
|
print("load Diffusers pretrained models")
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
||||||
text_encoder = pipe.text_encoder
|
text_encoder = loading_pipe.text_encoder
|
||||||
vae = pipe.vae
|
vae = loading_pipe.vae
|
||||||
unet = pipe.unet
|
unet = loading_pipe.unet
|
||||||
tokenizer = pipe.tokenizer
|
tokenizer = loading_pipe.tokenizer
|
||||||
del pipe
|
del loading_pipe
|
||||||
|
|
||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
@ -2039,6 +2058,44 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
# Textual Inversionを処理する
|
||||||
|
if args.textual_inversion_embeddings:
|
||||||
|
token_ids_embeds = []
|
||||||
|
for embeds_file in args.textual_inversion_embeddings:
|
||||||
|
if model_util.is_safetensors(embeds_file):
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
data = load_file(embeds_file)
|
||||||
|
else:
|
||||||
|
data = torch.load(embeds_file, map_location="cpu")
|
||||||
|
|
||||||
|
embeds = next(iter(data.values()))
|
||||||
|
if type(embeds) != torch.Tensor:
|
||||||
|
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
|
||||||
|
|
||||||
|
num_vectors_per_token = embeds.size()[0]
|
||||||
|
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||||
|
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||||
|
|
||||||
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
|
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||||
|
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||||
|
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
|
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||||
|
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||||
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
|
if num_vectors_per_token > 1:
|
||||||
|
pipe.add_token_replacement(token_ids[0], token_ids)
|
||||||
|
|
||||||
|
token_ids_embeds.append((token_ids, embeds))
|
||||||
|
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
|
for token_ids, embeds in token_ids_embeds:
|
||||||
|
for token_id, embed in zip(token_ids, embeds):
|
||||||
|
token_embeds[token_id] = embed
|
||||||
|
|
||||||
# promptを取得する
|
# promptを取得する
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
print(f"reading prompts from {args.from_file}")
|
||||||
@ -2157,8 +2214,8 @@ def main(args):
|
|||||||
os.makedirs(args.outdir, exist_ok=True)
|
os.makedirs(args.outdir, exist_ok=True)
|
||||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||||
|
|
||||||
for iter in range(args.n_iter):
|
for gen_iter in range(args.n_iter):
|
||||||
print(f"iteration {iter+1}/{args.n_iter}")
|
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||||
iter_seed = random.randint(0, 0x7fffffff)
|
iter_seed = random.randint(0, 0x7fffffff)
|
||||||
|
|
||||||
# バッチ処理の関数
|
# バッチ処理の関数
|
||||||
@ -2527,6 +2584,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
|
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
||||||
|
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||||
|
@ -45,6 +45,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
|
|||||||
# region dataset
|
# region dataset
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
||||||
|
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
||||||
|
|
||||||
|
|
||||||
class ImageInfo():
|
class ImageInfo():
|
||||||
@ -87,6 +88,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.enable_bucket = False
|
self.enable_bucket = False
|
||||||
self.min_bucket_reso = None
|
self.min_bucket_reso = None
|
||||||
self.max_bucket_reso = None
|
self.max_bucket_reso = None
|
||||||
|
self.tag_frequency = {}
|
||||||
self.bucket_info = None
|
self.bucket_info = None
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
@ -115,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.replacements = {}
|
self.replacements = {}
|
||||||
|
|
||||||
|
def set_tag_frequency(self, dir_name, captions):
    """Accumulate per-directory tag counts into ``self.tag_frequency``.

    Every caption is split on commas; each piece is stripped of surrounding
    whitespace and lower-cased before it is counted, so ``"1girl, Solo"``
    and ``"solo,1girl"`` contribute to the same two keys. Counts for a
    ``dir_name`` that was already recorded are added to, not replaced.

    Args:
        dir_name: key under which the counts are stored.
        captions: iterable of comma-separated tag strings. ``None`` and
            empty entries are ignored (an image may have no tags at all).
    """
    # setdefault keeps whatever an earlier call stored for this dir_name.
    frequency_for_dir = self.tag_frequency.setdefault(dir_name, {})
    for caption in captions:
        if not caption:
            continue
        for tag in caption.split(","):
            # Strip first: the space after each comma must not become
            # part of the key, otherwise "a" and " a" are counted apart.
            tag = tag.strip().lower()
            if tag:
                frequency_for_dir[tag] = frequency_for_dir.get(tag, 0) + 1
|
||||||
|
|
||||||
def disable_token_padding(self):
|
def disable_token_padding(self):
|
||||||
self.token_padding_disabled = True
|
self.token_padding_disabled = True
|
||||||
|
|
||||||
@ -240,6 +252,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||||
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
||||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
|
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
|
||||||
|
# only show bucket info if there is an actual image in it
|
||||||
|
if len(img_keys) > 0:
|
||||||
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
||||||
|
|
||||||
img_ar_errors = np.array(img_ar_errors)
|
img_ar_errors = np.array(img_ar_errors)
|
||||||
@ -247,7 +261,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||||
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||||
|
|
||||||
|
|
||||||
# 参照用indexを作る
|
# 参照用indexを作る
|
||||||
self.buckets_indices: list(BucketBatchIndex) = []
|
self.buckets_indices: list(BucketBatchIndex) = []
|
||||||
for bucket_index, bucket in enumerate(self.buckets):
|
for bucket_index, bucket in enumerate(self.buckets):
|
||||||
@ -545,6 +558,8 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
cap_for_img = read_caption(img_path)
|
cap_for_img = read_caption(img_path)
|
||||||
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
||||||
|
|
||||||
|
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
||||||
|
|
||||||
return n_repeats, img_paths, captions
|
return n_repeats, img_paths, captions
|
||||||
|
|
||||||
print("prepare train images.")
|
print("prepare train images.")
|
||||||
@ -553,10 +568,13 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for dir in train_dirs:
|
for dir in train_dirs:
|
||||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
||||||
num_train_images += n_repeats * len(img_paths)
|
num_train_images += n_repeats * len(img_paths)
|
||||||
|
|
||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||||
self.register_image(info)
|
self.register_image(info)
|
||||||
|
|
||||||
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
|
|
||||||
print(f"{num_train_images} train images with repeating.")
|
print(f"{num_train_images} train images with repeating.")
|
||||||
self.num_train_images = num_train_images
|
self.num_train_images = num_train_images
|
||||||
|
|
||||||
@ -570,9 +588,11 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for dir in reg_dirs:
|
for dir in reg_dirs:
|
||||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
||||||
num_reg_images += n_repeats * len(img_paths)
|
num_reg_images += n_repeats * len(img_paths)
|
||||||
|
|
||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||||
reg_infos.append(info)
|
reg_infos.append(info)
|
||||||
|
|
||||||
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
|
|
||||||
print(f"{num_reg_images} reg images.")
|
print(f"{num_reg_images} reg images.")
|
||||||
@ -617,6 +637,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.train_data_dir = train_data_dir
|
self.train_data_dir = train_data_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
tags_list = []
|
||||||
for image_key, img_md in metadata.items():
|
for image_key, img_md in metadata.items():
|
||||||
# path情報を作る
|
# path情報を作る
|
||||||
if os.path.exists(image_key):
|
if os.path.exists(image_key):
|
||||||
@ -633,6 +654,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
caption = tags
|
caption = tags
|
||||||
elif tags is not None and len(tags) > 0:
|
elif tags is not None and len(tags) > 0:
|
||||||
caption = caption + ', ' + tags
|
caption = caption + ', ' + tags
|
||||||
|
tags_list.append(tags)
|
||||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
||||||
|
|
||||||
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
||||||
@ -646,7 +668,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.num_train_images = len(metadata) * dataset_repeats
|
self.num_train_images = len(metadata) * dataset_repeats
|
||||||
self.num_reg_images = 0
|
self.num_reg_images = 0
|
||||||
|
|
||||||
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
||||||
|
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||||
|
|
||||||
# check existence of all npz files
|
# check existence of all npz files
|
||||||
if not self.color_aug:
|
if not self.color_aug:
|
||||||
@ -667,6 +690,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
|
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
|
||||||
elif not npz_all:
|
elif not npz_all:
|
||||||
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
||||||
|
if self.flip_aug:
|
||||||
|
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||||
for image_info in self.image_data.values():
|
for image_info in self.image_data.values():
|
||||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||||
|
|
||||||
@ -756,15 +781,30 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def glob_images(dir, base):
|
def glob_images(directory, base="*"):
|
||||||
img_paths = []
|
img_paths = []
|
||||||
for ext in IMAGE_EXTENSIONS:
|
for ext in IMAGE_EXTENSIONS:
|
||||||
if base == '*':
|
if base == '*':
|
||||||
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext)))
|
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
||||||
else:
|
else:
|
||||||
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext))))
|
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
||||||
|
# img_paths = list(set(img_paths)) # 重複を排除
|
||||||
|
# img_paths.sort()
|
||||||
return img_paths
|
return img_paths
|
||||||
|
|
||||||
|
|
||||||
|
def glob_images_pathlib(dir_path, recursive):
    """Collect image files under ``dir_path`` using pathlib globbing.

    Args:
        dir_path: object exposing ``glob`` / ``rglob`` (a ``pathlib.Path``).
        recursive: when true, search sub-directories as well (``rglob``);
            otherwise only ``dir_path`` itself (``glob``).

    Returns:
        A list of matches, grouped in ``IMAGE_EXTENSIONS`` order. The list
        is neither de-duplicated nor sorted.
    """
    search = dir_path.rglob if recursive else dir_path.glob
    found = []
    for suffix in IMAGE_EXTENSIONS:
        found.extend(search('*' + suffix))
    return found
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
@ -1495,5 +1535,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|||||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region 前処理用
|
||||||
|
|
||||||
|
|
||||||
|
class ImageLoadingDataset(torch.utils.data.Dataset):
    """Dataset that loads images from a list of paths for preprocessing.

    Each item is ``(tensor, path)``, where ``tensor`` is the image converted
    to RGB and then to a tensor. An item is ``None`` when the image cannot
    be loaded, so consumers must be prepared to skip ``None`` entries.
    """

    def __init__(self, image_paths):
        # Sequence of image paths; indexed directly by __getitem__.
        self.images = image_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = self.images[idx]

        try:
            rgb_image = Image.open(path).convert("RGB")
            # convert to tensor temporarily so dataloader will accept it
            as_tensor = transforms.functional.pil_to_tensor(rgb_image)
        except Exception as err:
            # Best effort: report the failure and let the caller skip it.
            print(f"Could not load image path / 画像を読み込めません: {path}, error: {err}")
            return None
        else:
            return (as_tensor, path)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
accelerate==0.15.0
|
accelerate==0.15.0
|
||||||
transformers==4.25.1
|
transformers==4.26.0
|
||||||
ftfy
|
ftfy
|
||||||
albumentations
|
albumentations
|
||||||
opencv-python
|
opencv-python
|
||||||
@ -9,14 +9,13 @@ pytorch_lightning
|
|||||||
bitsandbytes==0.35.0
|
bitsandbytes==0.35.0
|
||||||
tensorboard
|
tensorboard
|
||||||
safetensors==0.2.6
|
safetensors==0.2.6
|
||||||
gradio==3.15.0
|
gradio
|
||||||
altair
|
altair
|
||||||
easygui
|
easygui
|
||||||
tk
|
|
||||||
# for BLIP captioning
|
# for BLIP captioning
|
||||||
requests
|
requests
|
||||||
timm
|
timm==0.4.12
|
||||||
fairscale
|
fairscale==0.4.4
|
||||||
# for WD14 captioning
|
# for WD14 captioning
|
||||||
tensorflow<2.11
|
tensorflow<2.11
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
|
66
tools/resize_images_to_resolution.py
Normal file
66
tools/resize_images_to_resolution.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
import math
|
||||||
|
|
||||||
|
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
    """Copy a folder, shrinking images that exceed a pixel budget.

    Files whose name ends in ``.png``, ``.jpg`` or ``.webp`` are loaded with
    OpenCV. An image with more pixels than the budget is scaled down,
    keeping its aspect ratio, until it fits. Every image is then
    centre-cropped so that both sides are multiples of ``divisible_by`` and
    written to ``dst_img_folder`` under the same file name. All other files
    are copied unchanged. Sub-directories are not processed.

    Args:
        src_img_folder: directory to read from.
        dst_img_folder: directory to write to; created if missing.
        max_resolution: ``"WxH"`` string. Only the product W*H is used, as
            the maximum number of pixels an output image may have.
        divisible_by: output width and height are rounded down to a
            multiple of this value.
    """
    # The limit is an area, not a bounding box.
    resolution_parts = max_resolution.split("x")
    max_pixels = int(resolution_parts[0]) * int(resolution_parts[1])

    os.makedirs(dst_img_folder, exist_ok=True)

    for filename in os.listdir(src_img_folder):
        src_path = os.path.join(src_img_folder, filename)
        dst_path = os.path.join(dst_img_folder, filename)

        # shutil.copy cannot copy a directory; nested folders are skipped.
        if os.path.isdir(src_path):
            continue

        # Copy the file to the destination folder if not png, jpg or webp
        if not filename.endswith(('.png', '.jpg', '.webp')):
            shutil.copy(src_path, dst_path)
            continue

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(src_path)
        if img is None:
            print(f"Could not load image, skipped: {filename}")
            continue

        height, width = img.shape[0], img.shape[1]
        current_pixels = height * width

        if current_pixels > max_pixels:
            # Scale both sides by the same factor so the area hits the budget.
            side_scale = math.sqrt(max_pixels / current_pixels)
            new_height = int(height * side_scale)
            new_width = int(width * side_scale)
            img = cv2.resize(img, (new_width, new_height))
        else:
            # Already within budget: keep the size, but still apply the
            # divisibility crop below and write the file.
            new_height, new_width = height, width

        # Round each side down to a multiple of divisible_by
        new_height -= new_height % divisible_by
        new_width -= new_width % divisible_by
        if new_height == 0 or new_width == 0:
            print(f"Image too small for divisible_by={divisible_by}, skipped: {filename}")
            continue

        # Center crop the image to the calculated dimensions
        y = (img.shape[0] - new_height) // 2
        x = (img.shape[1] - new_width) // 2
        img = img[y:y + new_height, x:x + new_width]

        # Save resized image in dst_img_folder
        cv2.imwrite(dst_path, img, [cv2.IMWRITE_JPEG_QUALITY, 100])

        print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse arguments and run ``resize_images``."""
    parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution')
    parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
    parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
    parser.add_argument('--max_resolution', type=str, help='Maximum resolution in the format "512x512"', default="512x512")
    parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=2)
    args = parser.parse_args()
    # --divisible_by must be forwarded; otherwise the option is parsed but
    # silently ignored and the function default (2) is always used.
    resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, args.divisible_by)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -1,3 +1,6 @@
|
|||||||
|
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from typing import Optional, Union
|
||||||
import importlib
|
import importlib
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
|
|||||||
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
||||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||||
|
|
||||||
from typing import Optional, Union
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
|
||||||
|
|
||||||
def get_scheduler_fix(
|
def get_scheduler_fix(
|
||||||
name: Union[str, SchedulerType],
|
name: Union[str, SchedulerType],
|
||||||
@ -135,7 +135,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset)
|
train_util.debug_dataset(train_dataset)
|
||||||
return
|
return
|
||||||
if len(train_dataset) == 0:
|
if len(train_dataset) == 0:
|
||||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
||||||
return
|
return
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
@ -335,6 +335,7 @@ def train(args):
|
|||||||
"ss_keep_tokens": args.keep_tokens,
|
"ss_keep_tokens": args.keep_tokens,
|
||||||
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||||
|
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
||||||
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
||||||
"ss_training_comment": args.training_comment # will not be updated after training
|
"ss_training_comment": args.training_comment # will not be updated after training
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user