Publish v15
This commit is contained in:
parent
30b4be5680
commit
e8db30b9d1
65
README.md
65
README.md
@ -21,29 +21,20 @@ Give unrestricted script access to powershell so venv can work:
|
||||
Open a regular Powershell terminal and type the following inside:
|
||||
|
||||
```powershell
|
||||
# Clone the Kohya_ss repository
|
||||
git clone https://github.com/bmaltais/kohya_ss.git
|
||||
|
||||
# Navigate to the newly cloned directory
|
||||
cd kohya_ss
|
||||
|
||||
# Create a virtual environment using the system-site-packages option
|
||||
python -m venv --system-site-packages venv
|
||||
|
||||
# Activate the virtual environment
|
||||
.\venv\Scripts\activate
|
||||
|
||||
# Install the required packages
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
|
||||
# Copy the necessary files to the virtual environment's site-packages directory
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
|
||||
# Configure the accelerate utility
|
||||
accelerate config
|
||||
|
||||
```
|
||||
@ -285,20 +276,22 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
|
||||
## Options list
|
||||
|
||||
```txt
|
||||
usage: train_db_fixed.py [-h] [--v2] [--v_parameterization] [--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH]
|
||||
[--fine_tuning] [--shuffle_caption] [--caption_extention CAPTION_EXTENTION]
|
||||
usage: train_db_fixed.py [-h] [--v2] [--v_parameterization]
|
||||
[--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH] [--fine_tuning]
|
||||
[--shuffle_caption] [--caption_extention CAPTION_EXTENTION]
|
||||
[--caption_extension CAPTION_EXTENSION] [--train_data_dir TRAIN_DATA_DIR]
|
||||
[--reg_data_dir REG_DATA_DIR] [--dataset_repeats DATASET_REPEATS] [--output_dir OUTPUT_DIR]
|
||||
[--save_every_n_epochs SAVE_EVERY_N_EPOCHS] [--save_state] [--resume RESUME]
|
||||
[--use_safetensors] [--save_every_n_epochs SAVE_EVERY_N_EPOCHS] [--save_state] [--resume RESUME]
|
||||
[--prior_loss_weight PRIOR_LOSS_WEIGHT] [--no_token_padding]
|
||||
[--stop_text_encoder_training STOP_TEXT_ENCODER_TRAINING] [--color_aug] [--flip_aug]
|
||||
[--face_crop_aug_range FACE_CROP_AUG_RANGE] [--random_crop] [--debug_dataset]
|
||||
[--resolution RESOLUTION] [--train_batch_size TRAIN_BATCH_SIZE] [--use_8bit_adam] [--mem_eff_attn]
|
||||
[--xformers] [--cache_latents] [--enable_bucket] [--min_bucket_reso MIN_BUCKET_RESO]
|
||||
[--max_bucket_reso MAX_BUCKET_RESO] [--learning_rate LEARNING_RATE]
|
||||
[--max_train_steps MAX_TRAIN_STEPS] [--seed SEED] [--gradient_checkpointing]
|
||||
[--mixed_precision {no,fp16,bf16}] [--save_precision {None,float,fp16,bf16}] [--clip_skip CLIP_SKIP]
|
||||
[--logging_dir LOGGING_DIR] [--lr_scheduler LR_SCHEDULER] [--lr_warmup_steps LR_WARMUP_STEPS]
|
||||
[--resolution RESOLUTION] [--train_batch_size TRAIN_BATCH_SIZE] [--use_8bit_adam]
|
||||
[--mem_eff_attn] [--xformers] [--vae VAE] [--cache_latents] [--enable_bucket]
|
||||
[--min_bucket_reso MIN_BUCKET_RESO] [--max_bucket_reso MAX_BUCKET_RESO]
|
||||
[--learning_rate LEARNING_RATE] [--max_train_steps MAX_TRAIN_STEPS] [--seed SEED]
|
||||
[--gradient_checkpointing] [--mixed_precision {no,fp16,bf16}]
|
||||
[--save_precision {None,float,fp16,bf16}] [--clip_skip CLIP_SKIP] [--logging_dir LOGGING_DIR]
|
||||
[--log_prefix LOG_PREFIX] [--lr_scheduler LR_SCHEDULER] [--lr_warmup_steps LR_WARMUP_STEPS]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
@ -310,7 +303,7 @@ options:
|
||||
--fine_tuning fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする
|
||||
--shuffle_caption shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする
|
||||
--caption_extention CAPTION_EXTENTION
|
||||
extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを残し てあります)
|
||||
extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを 残してあります)
|
||||
--caption_extension CAPTION_EXTENSION
|
||||
extension of caption files / 読み込むcaptionファイルの拡張子
|
||||
--train_data_dir TRAIN_DATA_DIR
|
||||
@ -320,15 +313,18 @@ options:
|
||||
--dataset_repeats DATASET_REPEATS
|
||||
repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数
|
||||
--output_dir OUTPUT_DIR
|
||||
directory to output trained model (default format is same to input) /
|
||||
学習後のモデル出力先ディレクトリ(デフォルトの保存形式は読み込んだ形式と同じ)
|
||||
directory to output trained model / 学習後のモデル出力先ディレクトリ
|
||||
--use_safetensors use safetensors format for StableDiffusion checkpoint /
|
||||
StableDiffusionのcheckpointをsafetensors形式で保存する
|
||||
--save_every_n_epochs SAVE_EVERY_N_EPOCHS
|
||||
save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します
|
||||
--save_state save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する
|
||||
--save_state save training state additionally (including optimizer states etc.) /
|
||||
optimizerなど学習状態も含めたstateを追加で保存する
|
||||
--resume RESUME saved state to resume training / 学習再開するモデルのstate
|
||||
--prior_loss_weight PRIOR_LOSS_WEIGHT
|
||||
loss weight for regularization images / 正則化画像のlossの重み
|
||||
--no_token_padding disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)
|
||||
--no_token_padding disable token padding (same as Diffuser's DreamBooth) /
|
||||
トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)
|
||||
--stop_text_encoder_training STOP_TEXT_ENCODER_TRAINING
|
||||
steps to stop text encoder training / Text Encoderの学習を止めるステップ数
|
||||
--color_aug enable weak color augmentation / 学習時に色合いのaugmentationを有効にする
|
||||
@ -340,13 +336,14 @@ options:
|
||||
ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)
|
||||
--debug_dataset show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)
|
||||
--resolution RESOLUTION
|
||||
resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)
|
||||
resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高 さ'指定)
|
||||
--train_batch_size TRAIN_BATCH_SIZE
|
||||
batch size for training (1 means one train or reg data, not train/reg pair) /
|
||||
学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)
|
||||
--use_8bit_adam use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)
|
||||
--mem_eff_attn use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う
|
||||
--xformers use xformers for CrossAttention / CrossAttentionにxformersを使う
|
||||
--vae VAE path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ
|
||||
--cache_latents cache latents to reduce memory (augmentations must be disabled) /
|
||||
メモリ削減のためにlatentをcacheする(augmentationは使用不可)
|
||||
--enable_bucket enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする
|
||||
@ -365,17 +362,29 @@ options:
|
||||
use mixed precision / 混合精度を使う場合、その精度
|
||||
--save_precision {None,float,fp16,bf16}
|
||||
precision in saving (available in StableDiffusion checkpoint) /
|
||||
保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)
|
||||
--clip_skip CLIP_SKIP
|
||||
use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を 用いる(nは1以上)
|
||||
--logging_dir LOGGING_DIR
|
||||
enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する
|
||||
enable logging and output TensorBoard log to this directory /
|
||||
ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する
|
||||
--log_prefix LOG_PREFIX
|
||||
add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列
|
||||
--lr_scheduler LR_SCHEDULER
|
||||
scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial,
|
||||
constant (default), constant_with_warmup
|
||||
scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts,
|
||||
polynomial, constant (default), constant_with_warmup
|
||||
--lr_warmup_steps LR_WARMUP_STEPS
|
||||
Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)
|
||||
Number of steps for the warmup in the lr scheduler (default is 0) /
|
||||
学習率のスケジューラをウォームアップするステップ数(デフォルト0)
|
||||
```
|
||||
|
||||
## Change history
|
||||
|
||||
* 12/05 (v15) update:
|
||||
- The script has been divided into two parts
|
||||
- Support for SafeTensors format has been added. Install SafeTensors with `pip install safetensors`. The script will automatically detect the format based on the file extension when loading. Use the `--use_safetensors` option if you want to save the model as safetensor.
|
||||
- The vae option has been added to load a VAE model separately.
|
||||
- The log_prefix option has been added to allow adding a custom string to the log directory name before the date and time.
|
||||
* 11/30 (v13) update:
|
||||
- fix training text encoder at specified step (`--stop_text_encoder_training=<step #>`) that was causing both Unet and text encoder training to stop completely at the specified step rather than continue without text encoding training.
|
||||
* 11/29 (v12) update:
|
||||
|
@ -1,10 +1,3 @@
|
||||
# Diffusers Fine Tuning
|
||||
|
||||
This subfolder provide all the required tools to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
|
||||
|
||||
## Releases
|
||||
|
||||
11/23 (v3):
|
||||
- Added WD14Tagger tagging script.
|
||||
- A log output function has been added to the fine_tune.py. Also, fixed the double shuffling of data.
|
||||
- Fixed misspelling of options for each script (caption_extention→caption_extension will work for the time being, even if it remains outdated).
|
||||
Code has been moved to dedicated repo at: https://github.com/bmaltais/kohya_diffusers_fine_tuning
|
@ -1,125 +0,0 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def clean_tags(image_key, tags):
|
||||
# replace '_' to ' '
|
||||
tags = tags.replace('_', ' ')
|
||||
|
||||
# remove rating: deepdanbooruのみ
|
||||
tokens = tags.split(", rating")
|
||||
if len(tokens) == 1:
|
||||
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
||||
# print("no rating:")
|
||||
# print(f"{image_key} {tags}")
|
||||
pass
|
||||
else:
|
||||
if len(tokens) > 2:
|
||||
print("multiple ratings:")
|
||||
print(f"{image_key} {tags}")
|
||||
tags = tokens[0]
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
# 上から順に検索、置換される
|
||||
# ('置換元文字列', '置換後文字列')
|
||||
CAPTION_REPLACEMENTS = [
|
||||
('anime anime', 'anime'),
|
||||
('young ', ''),
|
||||
('anime girl', 'girl'),
|
||||
('cartoon female', 'girl'),
|
||||
('cartoon lady', 'girl'),
|
||||
('cartoon character', 'girl'), # a or ~s
|
||||
('cartoon woman', 'girl'),
|
||||
('cartoon women', 'girls'),
|
||||
('cartoon girl', 'girl'),
|
||||
('anime female', 'girl'),
|
||||
('anime lady', 'girl'),
|
||||
('anime character', 'girl'), # a or ~s
|
||||
('anime woman', 'girl'),
|
||||
('anime women', 'girls'),
|
||||
('lady', 'girl'),
|
||||
('female', 'girl'),
|
||||
('woman', 'girl'),
|
||||
('women', 'girls'),
|
||||
('people', 'girls'),
|
||||
('person', 'girl'),
|
||||
('a cartoon figure', 'a figure'),
|
||||
('a cartoon image', 'an image'),
|
||||
('a cartoon picture', 'a picture'),
|
||||
('an anime cartoon image', 'an image'),
|
||||
('a cartoon anime drawing', 'a drawing'),
|
||||
('a cartoon drawing', 'a drawing'),
|
||||
('girl girl', 'girl'),
|
||||
]
|
||||
|
||||
|
||||
def clean_caption(caption):
|
||||
for rf, rt in CAPTION_REPLACEMENTS:
|
||||
replaced = True
|
||||
while replaced:
|
||||
bef = caption
|
||||
caption = caption.replace(rf, rt)
|
||||
replaced = bef != caption
|
||||
return caption
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if os.path.exists(args.in_json):
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
print("no metadata / メタデータファイルがありません")
|
||||
return
|
||||
|
||||
print("cleaning captions and tags.")
|
||||
for image_path in tqdm(image_paths):
|
||||
tags_path = os.path.splitext(image_path)[0] + '.txt'
|
||||
with open(tags_path, "rt", encoding='utf-8') as f:
|
||||
tags = f.readlines()[0].strip()
|
||||
|
||||
image_key = os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
print(f"image not in metadata / メタデータに画像がありません: {image_path}")
|
||||
return
|
||||
|
||||
tags = metadata[image_key].get('tags')
|
||||
if tags is None:
|
||||
print(f"image does not have tags / メタデータにタグがありません: {image_path}")
|
||||
else:
|
||||
metadata[image_key]['tags'] = clean_tags(image_key, tags)
|
||||
|
||||
caption = metadata[image_key].get('caption')
|
||||
if caption is None:
|
||||
print(f"image does not have caption / メタデータにキャプションがありません: {image_path}")
|
||||
else:
|
||||
metadata[image_key]['caption'] = clean_caption(caption)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
# parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@ -1,968 +0,0 @@
|
||||
# v2: select precision for saved checkpoint
|
||||
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
|
||||
# v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model
|
||||
|
||||
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
||||
# License:
|
||||
# Copyright 2022 Kohya S. @kohya_ss
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# License of included scripts:
|
||||
# Diffusers: ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
|
||||
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
import importlib
|
||||
import time
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
|
||||
import fine_tuning_utils
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||
|
||||
# checkpointファイル名
|
||||
LAST_CHECKPOINT_NAME = "last.ckpt"
|
||||
LAST_STATE_NAME = "last-state"
|
||||
LAST_DIFFUSERS_DIR_NAME = "last"
|
||||
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
|
||||
EPOCH_STATE_NAME = "epoch-{:06d}-state"
|
||||
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
|
||||
|
||||
class FineTuningDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, dataset_repeats, debug) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.metadata = metadata
|
||||
self.train_data_dir = train_data_dir
|
||||
self.batch_size = batch_size
|
||||
self.tokenizer: CLIPTokenizer = tokenizer
|
||||
self.max_token_length = max_token_length
|
||||
self.shuffle_caption = shuffle_caption
|
||||
self.debug = debug
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
print("make buckets")
|
||||
|
||||
# 最初に数を数える
|
||||
self.bucket_resos = set()
|
||||
for img_md in metadata.values():
|
||||
if 'train_resolution' in img_md:
|
||||
self.bucket_resos.add(tuple(img_md['train_resolution']))
|
||||
self.bucket_resos = list(self.bucket_resos)
|
||||
self.bucket_resos.sort()
|
||||
print(f"number of buckets: {len(self.bucket_resos)}")
|
||||
|
||||
reso_to_index = {}
|
||||
for i, reso in enumerate(self.bucket_resos):
|
||||
reso_to_index[reso] = i
|
||||
|
||||
# bucketに割り当てていく
|
||||
self.buckets = [[] for _ in range(len(self.bucket_resos))]
|
||||
n = 1 if dataset_repeats is None else dataset_repeats
|
||||
images_count = 0
|
||||
for image_key, img_md in metadata.items():
|
||||
if 'train_resolution' not in img_md:
|
||||
continue
|
||||
if not os.path.exists(os.path.join(self.train_data_dir, image_key + '.npz')):
|
||||
continue
|
||||
|
||||
reso = tuple(img_md['train_resolution'])
|
||||
for _ in range(n):
|
||||
self.buckets[reso_to_index[reso]].append(image_key)
|
||||
images_count += n
|
||||
|
||||
# 参照用indexを作る
|
||||
self.buckets_indices = []
|
||||
for bucket_index, bucket in enumerate(self.buckets):
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append((bucket_index, batch_index))
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
self.images_count = images_count
|
||||
|
||||
def show_buckets(self):
|
||||
for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)):
|
||||
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
def shuffle_buckets(self):
|
||||
random.shuffle(self.buckets_indices)
|
||||
for bucket in self.buckets:
|
||||
random.shuffle(bucket)
|
||||
|
||||
def load_latent(self, image_key):
|
||||
return np.load(os.path.join(self.train_data_dir, image_key + '.npz'))['arr_0']
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index == 0:
|
||||
self.shuffle_buckets()
|
||||
|
||||
bucket = self.buckets[self.buckets_indices[index][0]]
|
||||
image_index = self.buckets_indices[index][1] * self.batch_size
|
||||
|
||||
input_ids_list = []
|
||||
latents_list = []
|
||||
captions = []
|
||||
for image_key in bucket[image_index:image_index + self.batch_size]:
|
||||
img_md = self.metadata[image_key]
|
||||
caption = img_md.get('caption')
|
||||
tags = img_md.get('tags')
|
||||
|
||||
if caption is None:
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ', ' + tags
|
||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}"
|
||||
|
||||
latents = self.load_latent(image_key)
|
||||
|
||||
if self.shuffle_caption:
|
||||
tokens = caption.strip().split(",")
|
||||
random.shuffle(tokens)
|
||||
caption = ",".join(tokens).strip()
|
||||
|
||||
captions.append(caption)
|
||||
|
||||
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
||||
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
||||
|
||||
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
iids_list = []
|
||||
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
||||
# v1
|
||||
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
|
||||
ids_chunk = (input_ids[0].unsqueeze(0),
|
||||
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0))
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
iids_list.append(ids_chunk)
|
||||
else:
|
||||
# v2
|
||||
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
||||
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
|
||||
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0)) # PAD or EOS
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
|
||||
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
||||
ids_chunk[-1] = self.tokenizer.eos_token_id
|
||||
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
||||
ids_chunk[1] = self.tokenizer.eos_token_id
|
||||
|
||||
iids_list.append(ids_chunk)
|
||||
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
latents_list.append(torch.FloatTensor(latents))
|
||||
|
||||
example = {}
|
||||
example['input_ids'] = torch.stack(input_ids_list)
|
||||
example['latents'] = torch.stack(latents_list)
|
||||
if self.debug:
|
||||
example['image_keys'] = bucket[image_index:image_index + self.batch_size]
|
||||
example['captions'] = captions
|
||||
return example
|
||||
|
||||
|
||||
def save_hypernetwork(output_file, hypernetwork):
|
||||
state_dict = hypernetwork.get_state_dict()
|
||||
torch.save(state_dict, output_file)
|
||||
|
||||
|
||||
def train(args):
|
||||
fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training
|
||||
|
||||
# その他のオプション設定を確認する
|
||||
if args.v_parameterization and not args.v2:
|
||||
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
# モデル形式のオプション設定を確認する
|
||||
# v11からDiffUsersから直接落としてくるのもOK(ただし認証がいるやつは未対応)、またv11からDiffUsersも途中保存に対応した
|
||||
use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
|
||||
|
||||
# 乱数系列を初期化する
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(args.in_json):
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
# tokenizerを読み込む
|
||||
print("prepare tokenizer")
|
||||
if args.v2:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
||||
|
||||
if args.max_token_length is not None:
|
||||
print(f"update token length: {args.max_token_length}")
|
||||
|
||||
# datasetを用意する
|
||||
print("prepare dataset")
|
||||
train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size,
|
||||
tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset)
|
||||
|
||||
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
|
||||
print(f"Total images / 画像数: {train_dataset.images_count}")
|
||||
if args.debug_dataset:
|
||||
train_dataset.show_buckets()
|
||||
i = 0
|
||||
for example in train_dataset:
|
||||
print(f"image: {example['image_keys']}")
|
||||
print(f"captions: {example['captions']}")
|
||||
print(f"latents: {example['latents'].shape}")
|
||||
print(f"input_ids: {example['input_ids'].shape}")
|
||||
print(example['input_ids'])
|
||||
i += 1
|
||||
if i >= 8:
|
||||
break
|
||||
return
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
if args.logging_dir is None:
|
||||
log_with = None
|
||||
logging_dir = None
|
||||
else:
|
||||
log_with = "tensorboard"
|
||||
logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
||||
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
save_dtype = None
|
||||
if args.save_precision == "fp16":
|
||||
save_dtype = torch.float16
|
||||
elif args.save_precision == "bf16":
|
||||
save_dtype = torch.bfloat16
|
||||
elif args.save_precision == "float":
|
||||
save_dtype = torch.float32
|
||||
|
||||
# モデルを読み込む
|
||||
if use_stable_diffusion_format:
|
||||
print("load StableDiffusion checkpoint")
|
||||
text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(
|
||||
args.v2, args.pretrained_model_name_or_path)
|
||||
else:
|
||||
print("load Diffusers pretrained models")
|
||||
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
||||
# , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる
|
||||
text_encoder = pipe.text_encoder
|
||||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
if not fine_tuning:
|
||||
# Hypernetwork
|
||||
print("import hypernetwork module:", args.hypernetwork_module)
|
||||
hyp_module = importlib.import_module(args.hypernetwork_module)
|
||||
|
||||
hypernetwork = hyp_module.Hypernetwork()
|
||||
|
||||
if args.hypernetwork_weights is not None:
|
||||
print("load hypernetwork weights from:", args.hypernetwork_weights)
|
||||
hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu')
|
||||
success = hypernetwork.load_from_state_dict(hyp_sd)
|
||||
assert success, "hypernetwork weights loading failed."
|
||||
|
||||
print("apply hypernetwork")
|
||||
hypernetwork.apply_to_diffusers(None, text_encoder, unet)
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
training_models = []
|
||||
if fine_tuning:
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
training_models.append(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
print("enable text encoder training")
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
training_models.append(text_encoder)
|
||||
else:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False) # text encoderは学習しない
|
||||
text_encoder.eval()
|
||||
else:
|
||||
unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない
|
||||
unet.requires_grad_(False)
|
||||
unet.eval()
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
training_models.append(hypernetwork)
|
||||
|
||||
for m in training_models:
|
||||
m.requires_grad_(True)
|
||||
params = []
|
||||
for m in training_models:
|
||||
params.extend(m.parameters())
|
||||
params_to_optimize = params
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
|
||||
# 8-bit Adamを使う
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
||||
print("use 8-bit Adam optimizer")
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
||||
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if fine_tuning:
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
else:
|
||||
unet, hypernetwork, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, hypernetwork, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
print("running training / 学習開始")
|
||||
print(f" num examples / サンプル数: {train_dataset.images_count}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
# v4で更新:clip_sample=Falseに
|
||||
# Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる
|
||||
# 既存の1.4/1.5/2.0はすべてschdulerのconfigは(クラス名を除いて)同じ
|
||||
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000, clip_sample=False)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork")
|
||||
|
||||
# 以下 train_dreambooth.py からほぼコピペ
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# with torch.no_grad():
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
||||
|
||||
if args.clip_skip is None:
|
||||
encoder_hidden_states = text_encoder(input_ids)[0]
|
||||
else:
|
||||
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||
|
||||
# bs*3, 77, 768 or 1024
|
||||
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
||||
|
||||
if args.max_token_length is not None:
|
||||
if args.v2:
|
||||
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
||||
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||
if i > 0:
|
||||
for j in range(len(chunk)):
|
||||
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
else:
|
||||
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
||||
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
|
||||
# 11/29現在v predictionのコードがDiffusersにcommitされたがリリースされていないので独自コードを使う
|
||||
# 実装の中身は同じ模様
|
||||
|
||||
# こうしたい:
|
||||
# target = noise_scheduler.get_v(latents, noise, timesteps)
|
||||
|
||||
# StabilityAiのddpm.pyのコード:
|
||||
# elif self.parameterization == "v":
|
||||
# target = self.get_v(x_start, noise, t)
|
||||
# ...
|
||||
# def get_v(self, x, noise, t):
|
||||
# return (
|
||||
# extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
|
||||
# extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
||||
# )
|
||||
|
||||
# scheduling_ddim.pyのコード:
|
||||
# elif self.config.prediction_type == "v_prediction":
|
||||
# pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
# # predict V
|
||||
# model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||
|
||||
# これでいいかな?:
|
||||
alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # broadcastされないらしいのでreshape
|
||||
beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
|
||||
target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step+1)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
||||
accelerator.log(logs, step=epoch+1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
|
||||
print("saving checkpoint.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
|
||||
|
||||
if fine_tuning:
|
||||
if use_stable_diffusion_format:
|
||||
fine_tuning_utils.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
||||
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
|
||||
else:
|
||||
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype)
|
||||
else:
|
||||
save_hypernetwork(ckpt_file, accelerator.unwrap_model(hypernetwork))
|
||||
|
||||
if args.save_state:
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
if fine_tuning:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
else:
|
||||
hypernetwork = accelerator.unwrap_model(hypernetwork)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
print("saving last state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if fine_tuning:
|
||||
if use_stable_diffusion_format:
|
||||
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||
fine_tuning_utils.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
|
||||
else:
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
||||
args.pretrained_model_name_or_path, save_dtype)
|
||||
else:
|
||||
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
save_hypernetwork(ckpt_file, hypernetwork)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
# region モジュール入れ替え部
|
||||
"""
|
||||
高速化のためのモジュール入れ替え
|
||||
"""
|
||||
|
||||
# FlashAttentionを使うCrossAttention
|
||||
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
||||
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
||||
|
||||
# constants
|
||||
|
||||
EPSILON = 1e-6
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# flash attention forwards and backwards
|
||||
|
||||
# https://arxiv.org/abs/2205.14135
|
||||
|
||||
|
||||
class FlashAttentionFunction(torch.autograd.function.Function):
|
||||
@ staticmethod
|
||||
@ torch.no_grad()
|
||||
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
||||
""" Algorithm 2 in the paper """
|
||||
|
||||
device = q.device
|
||||
dtype = q.dtype
|
||||
max_neg_value = -torch.finfo(q.dtype).max
|
||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
||||
|
||||
o = torch.zeros_like(q)
|
||||
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
||||
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
||||
|
||||
scale = (q.shape[-1] ** -0.5)
|
||||
|
||||
if not exists(mask):
|
||||
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
||||
else:
|
||||
mask = rearrange(mask, 'b n -> b 1 1 n')
|
||||
mask = mask.split(q_bucket_size, dim=-1)
|
||||
|
||||
row_splits = zip(
|
||||
q.split(q_bucket_size, dim=-2),
|
||||
o.split(q_bucket_size, dim=-2),
|
||||
mask,
|
||||
all_row_sums.split(q_bucket_size, dim=-2),
|
||||
all_row_maxes.split(q_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
||||
|
||||
col_splits = zip(
|
||||
k.split(k_bucket_size, dim=-2),
|
||||
v.split(k_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for k_ind, (kc, vc) in enumerate(col_splits):
|
||||
k_start_index = k_ind * k_bucket_size
|
||||
|
||||
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
||||
|
||||
if exists(row_mask):
|
||||
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
||||
|
||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
||||
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
||||
device=device).triu(q_start_index - k_start_index + 1)
|
||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
||||
|
||||
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
||||
attn_weights -= block_row_maxes
|
||||
exp_weights = torch.exp(attn_weights)
|
||||
|
||||
if exists(row_mask):
|
||||
exp_weights.masked_fill_(~row_mask, 0.)
|
||||
|
||||
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
||||
|
||||
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
||||
|
||||
exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
|
||||
|
||||
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
||||
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
||||
|
||||
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
||||
|
||||
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
||||
|
||||
row_maxes.copy_(new_row_maxes)
|
||||
row_sums.copy_(new_row_sums)
|
||||
|
||||
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
||||
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
||||
|
||||
return o
|
||||
|
||||
@ staticmethod
|
||||
@ torch.no_grad()
|
||||
def backward(ctx, do):
|
||||
""" Algorithm 4 in the paper """
|
||||
|
||||
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
|
||||
device = q.device
|
||||
|
||||
max_neg_value = -torch.finfo(q.dtype).max
|
||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
||||
|
||||
dq = torch.zeros_like(q)
|
||||
dk = torch.zeros_like(k)
|
||||
dv = torch.zeros_like(v)
|
||||
|
||||
row_splits = zip(
|
||||
q.split(q_bucket_size, dim=-2),
|
||||
o.split(q_bucket_size, dim=-2),
|
||||
do.split(q_bucket_size, dim=-2),
|
||||
mask,
|
||||
l.split(q_bucket_size, dim=-2),
|
||||
m.split(q_bucket_size, dim=-2),
|
||||
dq.split(q_bucket_size, dim=-2)
|
||||
)
|
||||
|
||||
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
||||
|
||||
col_splits = zip(
|
||||
k.split(k_bucket_size, dim=-2),
|
||||
v.split(k_bucket_size, dim=-2),
|
||||
dk.split(k_bucket_size, dim=-2),
|
||||
dv.split(k_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
||||
k_start_index = k_ind * k_bucket_size
|
||||
|
||||
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
||||
|
||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
||||
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
||||
device=device).triu(q_start_index - k_start_index + 1)
|
||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
||||
|
||||
exp_attn_weights = torch.exp(attn_weights - mc)
|
||||
|
||||
if exists(row_mask):
|
||||
exp_attn_weights.masked_fill_(~row_mask, 0.)
|
||||
|
||||
p = exp_attn_weights / lc
|
||||
|
||||
dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
|
||||
dp = einsum('... i d, ... j d -> ... i j', doc, vc)
|
||||
|
||||
D = (doc * oc).sum(dim=-1, keepdims=True)
|
||||
ds = p * scale * (dp - D)
|
||||
|
||||
dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
|
||||
dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
|
||||
|
||||
dqc.add_(dq_chunk)
|
||||
dkc.add_(dk_chunk)
|
||||
dvc.add_(dv_chunk)
|
||||
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
if mem_eff_attn:
|
||||
replace_unet_cross_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
replace_unet_cross_attn_to_xformers()
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
print("Replace CrossAttention.forward to use FlashAttention")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, x, context=None, mask=None):
|
||||
q_bucket_size = 512
|
||||
k_bucket_size = 1024
|
||||
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
|
||||
context = context if context is not None else x
|
||||
context = context.to(x.dtype)
|
||||
|
||||
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
||||
context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
context_k = context_k.to(x.dtype)
|
||||
context_v = context_v.to(x.dtype)
|
||||
else:
|
||||
context_k = context
|
||||
context_v = context
|
||||
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
||||
|
||||
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
# diffusers 0.6.0
|
||||
if type(self.to_out) is torch.nn.Sequential:
|
||||
return self.to_out(out)
|
||||
|
||||
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
|
||||
out = self.to_out[0](out)
|
||||
out = self.to_out[1](out)
|
||||
return out
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("Replace CrossAttention.forward to use xformers")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
def forward_xformers(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
|
||||
context = default(context, x)
|
||||
context = context.to(x.dtype)
|
||||
|
||||
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
||||
context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
context_k = context_k.to(x.dtype)
|
||||
context_v = context_v.to(x.dtype)
|
||||
else:
|
||||
context_k = context
|
||||
context_v = context
|
||||
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
|
||||
# diffusers 0.6.0
|
||||
if type(self.to_out) is torch.nn.Sequential:
|
||||
return self.to_out(out)
|
||||
|
||||
# diffusers 0.7.0~
|
||||
out = self.to_out[0](out)
|
||||
out = self.to_out[1](out)
|
||||
return out
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
# endregion
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
||||
parser.add_argument("--v_parameterization", action='store_true',
|
||||
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
||||
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
||||
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
||||
parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル")
|
||||
parser.add_argument("--shuffle_caption", action="store_true",
|
||||
help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする")
|
||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--dataset_repeats", type=int, default=None, help="num times to repeat dataset / 学習にデータセットを繰り返す回数")
|
||||
parser.add_argument("--output_dir", type=str, default=None,
|
||||
help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument("--hypernetwork_module", type=str, default=None,
|
||||
help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール')
|
||||
parser.add_argument("--hypernetwork_weights", type=str, default=None,
|
||||
help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)')
|
||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||
parser.add_argument("--save_state", action="store_true",
|
||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||
parser.add_argument("--resume", type=str, default=None,
|
||||
help="saved state to resume training / 学習再開するモデルのstate")
|
||||
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
||||
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
||||
parser.add_argument("--train_batch_size", type=int, default=1,
|
||||
help="batch size for training / 学習時のバッチサイズ")
|
||||
parser.add_argument("--use_8bit_adam", action="store_true",
|
||||
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
||||
parser.add_argument("--mem_eff_attn", action="store_true",
|
||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
||||
parser.add_argument("--xformers", action="store_true",
|
||||
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
|
||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
|
||||
parser.add_argument("--clip_skip", type=int, default=None,
|
||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||
parser.add_argument("--debug_dataset", action="store_true",
|
||||
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
||||
parser.add_argument("--logging_dir", type=str, default=None,
|
||||
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
||||
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
||||
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
||||
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
@ -1,96 +0,0 @@
|
||||
# NAI compatible
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
def __init__(self, dim, multiplier=1.0):
|
||||
super().__init__()
|
||||
|
||||
linear1 = torch.nn.Linear(dim, dim * 2)
|
||||
linear2 = torch.nn.Linear(dim * 2, dim)
|
||||
linear1.weight.data.normal_(mean=0.0, std=0.01)
|
||||
linear1.bias.data.zero_()
|
||||
linear2.weight.data.normal_(mean=0.0, std=0.01)
|
||||
linear2.bias.data.zero_()
|
||||
linears = [linear1, linear2]
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
self.multiplier = multiplier
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * self.multiplier
|
||||
|
||||
|
||||
class Hypernetwork(torch.nn.Module):
|
||||
enable_sizes = [320, 640, 768, 1280]
|
||||
# return self.modules[Hypernetwork.enable_sizes.index(size)]
|
||||
|
||||
def __init__(self, multiplier=1.0) -> None:
|
||||
super().__init__()
|
||||
self.modules = []
|
||||
for size in Hypernetwork.enable_sizes:
|
||||
self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
|
||||
self.register_module(f"{size}_0", self.modules[-1][0])
|
||||
self.register_module(f"{size}_1", self.modules[-1][1])
|
||||
|
||||
def apply_to_stable_diffusion(self, text_encoder, vae, unet):
|
||||
blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
|
||||
for block in blocks:
|
||||
for subblk in block:
|
||||
if 'SpatialTransformer' in str(type(subblk)):
|
||||
for tf_block in subblk.transformer_blocks:
|
||||
for attn in [tf_block.attn1, tf_block.attn2]:
|
||||
size = attn.context_dim
|
||||
if size in Hypernetwork.enable_sizes:
|
||||
attn.hypernetwork = self
|
||||
else:
|
||||
attn.hypernetwork = None
|
||||
|
||||
def apply_to_diffusers(self, text_encoder, vae, unet):
|
||||
blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
|
||||
for block in blocks:
|
||||
if hasattr(block, 'attentions'):
|
||||
for subblk in block.attentions:
|
||||
if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
|
||||
for tf_block in subblk.transformer_blocks:
|
||||
for attn in [tf_block.attn1, tf_block.attn2]:
|
||||
size = attn.to_k.in_features
|
||||
if size in Hypernetwork.enable_sizes:
|
||||
attn.hypernetwork = self
|
||||
else:
|
||||
attn.hypernetwork = None
|
||||
return True # TODO error checking
|
||||
|
||||
def forward(self, x, context):
|
||||
size = context.shape[-1]
|
||||
assert size in Hypernetwork.enable_sizes
|
||||
module = self.modules[Hypernetwork.enable_sizes.index(size)]
|
||||
return module[0].forward(context), module[1].forward(context)
|
||||
|
||||
def load_from_state_dict(self, state_dict):
|
||||
# old ver to new ver
|
||||
changes = {
|
||||
'linear1.bias': 'linear.0.bias',
|
||||
'linear1.weight': 'linear.0.weight',
|
||||
'linear2.bias': 'linear.1.bias',
|
||||
'linear2.weight': 'linear.1.weight',
|
||||
}
|
||||
for key_from, key_to in changes.items():
|
||||
if key_from in state_dict:
|
||||
state_dict[key_to] = state_dict[key_from]
|
||||
del state_dict[key_from]
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
|
||||
self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
|
||||
return True
|
||||
|
||||
def get_state_dict(self):
|
||||
state_dict = {}
|
||||
for i, size in enumerate(Hypernetwork.enable_sizes):
|
||||
sd0 = self.modules[i][0].state_dict()
|
||||
sd1 = self.modules[i][1].state_dict()
|
||||
state_dict[size] = [sd0, sd1]
|
||||
return state_dict
|
@ -1,97 +0,0 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from models.blip import blip_decoder
|
||||
# from Salesforce_BLIP.models.blip import blip_decoder
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
print(f"loading BLIP caption: {args.caption_weights}")
|
||||
image_size = 384
|
||||
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large')
|
||||
model.eval()
|
||||
model = model.to(DEVICE)
|
||||
print("BLIP loaded")
|
||||
|
||||
# 正方形でいいのか? という気がするがソースがそうなので
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])
|
||||
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
if args.beam_search:
|
||||
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
|
||||
max_length=args.max_length, min_length=args.min_length)
|
||||
else:
|
||||
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
||||
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
|
||||
b_imgs = []
|
||||
for image_path in tqdm(image_paths):
|
||||
raw_image = Image.open(image_path)
|
||||
if raw_image.mode != "RGB":
|
||||
print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
|
||||
raw_image = raw_image.convert("RGB")
|
||||
|
||||
image = transform(raw_image)
|
||||
b_imgs.append((image_path, image))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("caption_weights", type=str,
|
||||
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--beam_search", action="store_true",
|
||||
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
@ -1,68 +0,0 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if args.in_json is not None:
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
||||
else:
|
||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||
metadata = {}
|
||||
|
||||
print("merge caption texts to metadata json.")
|
||||
for image_path in tqdm(image_paths):
|
||||
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
with open(caption_path, "rt", encoding='utf-8') as f:
|
||||
caption = f.readlines()[0].strip()
|
||||
|
||||
image_key = os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
# if args.verify_caption:
|
||||
# print(f"image not in metadata / メタデータに画像がありません: {image_path}")
|
||||
# return
|
||||
metadata[image_key] = {}
|
||||
# elif args.verify_caption and 'caption' not in metadata[image_key]:
|
||||
# print(f"no caption in metadata / メタデータにcaptionがありません: {image_path}")
|
||||
# return
|
||||
|
||||
metadata[image_key]['caption'] = caption
|
||||
if args.debug:
|
||||
print(image_key, caption)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
@ -1,61 +0,0 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if args.in_json is not None:
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
||||
else:
|
||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||
metadata = {}
|
||||
|
||||
print("merge tags to metadata json.")
|
||||
for image_path in tqdm(image_paths):
|
||||
tags_path = os.path.splitext(image_path)[0] + '.txt'
|
||||
with open(tags_path, "rt", encoding='utf-8') as f:
|
||||
tags = f.readlines()[0].strip()
|
||||
|
||||
image_key = os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
# if args.verify_caption:
|
||||
# print(f"image not in metadata / メタデータに画像がありません: {image_path}")
|
||||
# return
|
||||
metadata[image_key] = {}
|
||||
# elif args.verify_caption and 'caption' not in metadata[image_key]:
|
||||
# print(f"no caption in metadata / メタデータにcaptionがありません: {image_path}")
|
||||
# return
|
||||
|
||||
metadata[image_key]['tags'] = tags
|
||||
if args.debug:
|
||||
print(image_key, tags)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
# parser.add_argument("--verify_caption", action="store_true", help="verify caption exists / メタデータにすでにcaptionが存在することを確認する")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@ -1,175 +0,0 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKL
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
import fine_tuning_utils
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_latents(vae, images, weight_dtype):
|
||||
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
||||
img_tensors = torch.stack(img_tensors)
|
||||
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
|
||||
return latents
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if os.path.exists(args.in_json):
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
# モデル形式のオプション設定を確認する
|
||||
use_stable_diffusion_format = os.path.isfile(args.model_name_or_path)
|
||||
|
||||
# モデルを読み込む
|
||||
if use_stable_diffusion_format:
|
||||
print("load StableDiffusion checkpoint")
|
||||
_, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_name_or_path)
|
||||
else:
|
||||
print("load Diffusers pretrained models")
|
||||
vae = AutoencoderKL.from_pretrained(args.model_name_or_path, subfolder="vae")
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
vae.eval()
|
||||
vae.to(DEVICE, dtype=weight_dtype)
|
||||
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(
|
||||
max_reso, args.min_bucket_reso, args.max_bucket_reso)
|
||||
|
||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
||||
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
|
||||
buckets_imgs = [[] for _ in range(len(bucket_resos))]
|
||||
bucket_counts = [0 for _ in range(len(bucket_resos))]
|
||||
img_ar_errors = []
|
||||
for i, image_path in enumerate(tqdm(image_paths)):
|
||||
image_key = os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
|
||||
aspect_ratio = image.width / image.height
|
||||
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||
bucket_id = np.abs(ar_errors).argmin()
|
||||
reso = bucket_resos[bucket_id]
|
||||
ar_error = ar_errors[bucket_id]
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
|
||||
# どのサイズにリサイズするか→トリミングする方向で
|
||||
if ar_error <= 0: # 横が長い→縦を合わせる
|
||||
scale = reso[1] / image.height
|
||||
else:
|
||||
scale = reso[0] / image.width
|
||||
|
||||
resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
|
||||
|
||||
# print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
|
||||
# bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
|
||||
|
||||
assert resized_size[0] == reso[0] or resized_size[1] == reso[
|
||||
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
|
||||
# 画像をリサイズしてトリミングする
|
||||
# PILにinter_areaがないのでcv2で……
|
||||
image = np.array(image)
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
||||
if resized_size[0] > reso[0]:
|
||||
trim_size = resized_size[0] - reso[0]
|
||||
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
||||
elif resized_size[1] > reso[1]:
|
||||
trim_size = resized_size[1] - reso[1]
|
||||
image = image[trim_size//2:trim_size//2 + reso[1]]
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
|
||||
# # debug
|
||||
# cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
|
||||
|
||||
# バッチへ追加
|
||||
buckets_imgs[bucket_id].append((image_key, reso, image))
|
||||
bucket_counts[bucket_id] += 1
|
||||
metadata[image_key]['train_resolution'] = reso
|
||||
|
||||
# バッチを推論するか判定して推論する
|
||||
is_last = i == len(image_paths) - 1
|
||||
for j in range(len(buckets_imgs)):
|
||||
bucket = buckets_imgs[j]
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
|
||||
|
||||
for (image_key, reso, _), latent in zip(bucket, latents):
|
||||
np.savez(os.path.join(args.train_data_dir, image_key), latent)
|
||||
|
||||
bucket.clear()
|
||||
|
||||
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_resolution", type=str, default="512,512",
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@ -1,6 +0,0 @@
|
||||
transformers>=4.21.0
|
||||
ftfy
|
||||
albumentations
|
||||
opencv-python
|
||||
einops
|
||||
pytorch_lightning
|
@ -1,107 +0,0 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from tensorflow.keras.models import load_model
|
||||
from Utils import dbimutils
|
||||
|
||||
|
||||
# from wd14 tagger
|
||||
IMAGE_SIZE = 448
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
print("loading model and labels")
|
||||
model = load_model(args.model)
|
||||
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
with open(args.tag_csv, "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
l = [row for row in reader]
|
||||
header = l[0] # tag_id,name,category,count
|
||||
rows = l[1:]
|
||||
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
|
||||
|
||||
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
|
||||
|
||||
# 推論する
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
# 最初の4つはratingなので無視する
|
||||
# # First 4 labels are actually ratings: pick one with argmax
|
||||
# ratings_names = label_names[:4]
|
||||
# rating_index = ratings_names["probs"].argmax()
|
||||
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
||||
|
||||
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
||||
# Everything else is tags: pick any where prediction confidence > threshold
|
||||
tag_text = ""
|
||||
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
||||
if p >= args.thresh:
|
||||
tag_text += ", " + tags[i]
|
||||
|
||||
if len(tag_text) > 0:
|
||||
tag_text = tag_text[2:] # 最初の ", " を消す
|
||||
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(tag_text + '\n')
|
||||
if args.debug:
|
||||
print(image_path, tag_text)
|
||||
|
||||
b_imgs = []
|
||||
for image_path in tqdm(image_paths):
|
||||
img = dbimutils.smart_imread(image_path)
|
||||
img = dbimutils.smart_24bit(img)
|
||||
img = dbimutils.make_square(img, IMAGE_SIZE)
|
||||
img = dbimutils.smart_resize(img, IMAGE_SIZE)
|
||||
img = img.astype(np.float32)
|
||||
b_imgs.append((image_path, img))
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--model", type=str, default="networks/ViTB16_11_03_2022_07h05m53s",
|
||||
help="model path to load / 読み込むモデルファイル")
|
||||
parser.add_argument("--tag_csv", type=str, default="2022_0000_0899_6549/selected_tags.csv",
|
||||
help="csv file for tags / タグ一覧のCSVファイル")
|
||||
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
@ -1,12 +1,12 @@
|
||||
# v1: split from train_db_fixed.py.
|
||||
# v2: support safetensors
|
||||
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
from transformers import CLIPTextModel
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
|
||||
|
||||
# region checkpoint変換、読み込み、書き込み ###############################
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
NUM_TRAIN_TIMESTEPS = 1000
|
||||
@ -37,7 +37,7 @@ V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||
|
||||
|
||||
# region StableDiffusion->Diffusersの変換コード
|
||||
# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)
|
||||
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
@ -243,21 +243,21 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
||||
for layer_id in range(num_output_blocks)
|
||||
}
|
||||
|
||||
@ -332,14 +332,22 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
# オリジナル:
|
||||
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
|
||||
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
||||
for l in output_block_list.values():
|
||||
l.sort()
|
||||
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.bias"
|
||||
]
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
|
||||
# Clear attentions as they have been attributed above.
|
||||
if len(attentions) == 2:
|
||||
@ -377,6 +385,9 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
# if len(vae_state_dict) == 0:
|
||||
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
||||
# vae_state_dict = checkpoint
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
@ -617,7 +628,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
|
||||
|
||||
# region Diffusers->StableDiffusion の変換コード
|
||||
# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)
|
||||
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
||||
|
||||
def conv_transformer_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
@ -723,8 +734,90 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
vae_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("nin_shortcut", "conv_shortcut"),
|
||||
("norm_out", "conv_norm_out"),
|
||||
("mid.attn_1.", "mid_block.attentions.0."),
|
||||
]
|
||||
|
||||
for i in range(4):
|
||||
# down_blocks have two resnets
|
||||
for j in range(2):
|
||||
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
||||
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
||||
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
||||
sd_downsample_prefix = f"down.{i}.downsample."
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "query."),
|
||||
("k.", "key."),
|
||||
("v.", "value."),
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in vae_conversion_map:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
if "attentions" in k:
|
||||
for sd_part, hf_part in vae_conversion_map_attn:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
||||
weights_to_convert = ["q", "k", "v", "proj_out"]
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
# print(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region 自作のモデル読み書き
|
||||
|
||||
def is_safetensors(path):
|
||||
return os.path.splitext(path)[1].lower() == '.safetensors'
|
||||
|
||||
|
||||
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
||||
@ -734,8 +827,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
||||
]
|
||||
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
state_dict = checkpoint["state_dict"]
|
||||
if is_safetensors(ckpt_path):
|
||||
checkpoint = None
|
||||
state_dict = load_file(ckpt_path, "cpu")
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
if "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
checkpoint = None
|
||||
|
||||
key_reps = []
|
||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||
@ -748,12 +849,12 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
state_dict[new_key] = state_dict[key]
|
||||
del state_dict[key]
|
||||
|
||||
return checkpoint
|
||||
return checkpoint, state_dict
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
state_dict = checkpoint["state_dict"]
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
if dtype is not None:
|
||||
for k, v in state_dict.items():
|
||||
if type(v) is torch.Tensor:
|
||||
@ -810,7 +911,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
return text_model, vae, unet
|
||||
|
||||
|
||||
def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
|
||||
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
||||
def convert_key(key):
|
||||
# position_idsの除去
|
||||
if ".position_ids" in key:
|
||||
@ -866,35 +967,66 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
|
||||
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
||||
new_sd[new_key] = value
|
||||
|
||||
# 最後の層などを捏造するか
|
||||
if make_dummy_weights:
|
||||
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
||||
keys = list(new_sd.keys())
|
||||
for key in keys:
|
||||
if key.startswith("transformer.resblocks.22."):
|
||||
new_sd[key.replace(".22.", ".23.")] = new_sd[key]
|
||||
|
||||
# Diffusersに含まれない重みを作っておく
|
||||
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
||||
new_sd['logit_scale'] = torch.tensor(1)
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
|
||||
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
|
||||
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
state_dict = checkpoint["state_dict"]
|
||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
||||
if ckpt_path is not None:
|
||||
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
||||
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
if checkpoint is None: # safetensors または state_dictのckpt
|
||||
checkpoint = {}
|
||||
strict = False
|
||||
else:
|
||||
strict = True
|
||||
if "state_dict" in state_dict:
|
||||
del state_dict["state_dict"]
|
||||
else:
|
||||
# 新しく作る
|
||||
checkpoint = {}
|
||||
state_dict = {}
|
||||
strict = False
|
||||
|
||||
def assign_new_sd(prefix, sd):
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
key = prefix + k
|
||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
||||
if save_dtype is not None:
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
||||
assign_new_sd("model.diffusion_model.", unet_state_dict)
|
||||
update_sd("model.diffusion_model.", unet_state_dict)
|
||||
|
||||
# Convert the text encoder model
|
||||
if v2:
|
||||
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict())
|
||||
assign_new_sd("cond_stage_model.model.", text_enc_dict)
|
||||
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
||||
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
||||
update_sd("cond_stage_model.model.", text_enc_dict)
|
||||
else:
|
||||
text_enc_dict = text_encoder.state_dict()
|
||||
assign_new_sd("cond_stage_model.transformer.", text_enc_dict)
|
||||
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
||||
|
||||
# Convert the VAE
|
||||
if vae is not None:
|
||||
vae_dict = convert_vae_state_dict(vae.state_dict())
|
||||
update_sd("first_stage_model.", vae_dict)
|
||||
|
||||
# Put together new checkpoint
|
||||
key_count = len(state_dict.keys())
|
||||
new_ckpt = {'state_dict': state_dict}
|
||||
|
||||
if 'epoch' in checkpoint:
|
||||
@ -905,14 +1037,22 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
||||
new_ckpt['epoch'] = epochs
|
||||
new_ckpt['global_step'] = steps
|
||||
|
||||
torch.save(new_ckpt, output_file)
|
||||
if is_safetensors(output_file):
|
||||
# TODO Tensor以外のdictの値を削除したほうがいいか
|
||||
save_file(state_dict, output_file)
|
||||
else:
|
||||
torch.save(new_ckpt, output_file)
|
||||
|
||||
return key_count
|
||||
|
||||
|
||||
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype):
|
||||
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None):
|
||||
if vae is None:
|
||||
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
||||
pipeline = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"),
|
||||
vae=vae,
|
||||
scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
|
||||
tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
|
||||
safety_checker=None,
|
||||
@ -921,6 +1061,62 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
|
||||
|
||||
VAE_PREFIX = "first_stage_model."
|
||||
|
||||
|
||||
def load_vae(vae_id, dtype):
|
||||
print(f"load VAE: {vae_id}")
|
||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||
# Diffusers local/remote
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
||||
except EnvironmentError as e:
|
||||
print(f"exception occurs in loading vae: {e}")
|
||||
print("retry with subfolder='vae'")
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
||||
return vae
|
||||
|
||||
# local
|
||||
vae_config = create_vae_diffusers_config()
|
||||
|
||||
if vae_id.endswith(".bin"):
|
||||
# SD 1.5 VAE on Huggingface
|
||||
vae_sd = torch.load(vae_id, map_location="cpu")
|
||||
converted_vae_checkpoint = vae_sd
|
||||
else:
|
||||
# StableDiffusion
|
||||
vae_model = torch.load(vae_id, map_location="cpu")
|
||||
vae_sd = vae_model['state_dict']
|
||||
|
||||
# vae only or full model
|
||||
full_model = False
|
||||
for vae_key in vae_sd:
|
||||
if vae_key.startswith(VAE_PREFIX):
|
||||
full_model = True
|
||||
break
|
||||
if not full_model:
|
||||
sd = {}
|
||||
for key, value in vae_sd.items():
|
||||
sd[VAE_PREFIX + key] = value
|
||||
vae_sd = sd
|
||||
del sd
|
||||
|
||||
# Convert the VAE model.
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae
|
||||
|
||||
|
||||
def get_epoch_ckpt_name(use_safetensors, epoch):
|
||||
return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
|
||||
|
||||
|
||||
def get_last_ckpt_name(use_safetensors):
|
||||
return f"last" + (".safetensors" if use_safetensors else ".ckpt")
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
accelerate==0.14.0
|
||||
transformers>=4.21.0
|
||||
transformers==4.25.1
|
||||
ftfy
|
||||
albumentations
|
||||
opencv-python
|
||||
@ -8,3 +8,4 @@ diffusers[torch]==0.9.0
|
||||
pytorch_lightning
|
||||
bitsandbytes==0.35.0
|
||||
tensorboard
|
||||
safetensors==0.2.5
|
69
tools/caption.py
Normal file
69
tools/caption.py
Normal file
@ -0,0 +1,69 @@
|
||||
# This script will create the caption text files in the specified folder using the specified file pattern and caption text.
|
||||
#
|
||||
# eg: python caption.py D:\some\folder\location "*.png, *.jpg, *.webp" "some caption text"
|
||||
|
||||
import argparse
|
||||
# import glob
|
||||
# import os
|
||||
from pathlib import Path
|
||||
|
||||
def create_caption_files(image_folder: str, file_pattern: str, caption_text: str, caption_file_ext: str, overwrite: bool):
|
||||
# Split the file patterns string and strip whitespace from each pattern
|
||||
patterns = [pattern.strip() for pattern in file_pattern.split(",")]
|
||||
|
||||
# Create a Path object for the image folder
|
||||
folder = Path(image_folder)
|
||||
|
||||
# Iterate over the file patterns
|
||||
for pattern in patterns:
|
||||
# Use the glob method to match the file patterns
|
||||
files = folder.glob(pattern)
|
||||
|
||||
# Iterate over the matched files
|
||||
for file in files:
|
||||
# Check if a text file with the same name as the current file exists in the folder
|
||||
txt_file = file.with_suffix(caption_file_ext)
|
||||
if not txt_file.exists() or overwrite:
|
||||
# Create a text file with the caption text in the folder, if it does not already exist
|
||||
# or if the overwrite argument is True
|
||||
with open(txt_file, "w") as f:
|
||||
f.write(caption_text)
|
||||
|
||||
def main():
|
||||
# Define command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("image_folder", type=str, help="the folder where the image files are located")
|
||||
parser.add_argument("--file_pattern", type=str, default="*.png, *.jpg, *.jpeg, *.webp", help="the pattern to match the image file names")
|
||||
parser.add_argument("--caption_file_ext", type=str, default=".caption", help="the caption file extension.")
|
||||
parser.add_argument("--overwrite", action="store_true", default=False, help="whether to overwrite existing caption files")
|
||||
|
||||
# Create a mutually exclusive group for the caption_text and caption_file arguments
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--caption_text", type=str, help="the text to include in the caption files")
|
||||
group.add_argument("--caption_file", type=argparse.FileType("r"), help="the file containing the text to include in the caption files")
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
image_folder = args.image_folder
|
||||
file_pattern = args.file_pattern
|
||||
caption_file_ext = args.caption_file_ext
|
||||
overwrite = args.overwrite
|
||||
|
||||
# Get the caption text from either the caption_text or caption_file argument
|
||||
if args.caption_text:
|
||||
caption_text = args.caption_text
|
||||
elif args.caption_file:
|
||||
caption_text = args.caption_file.read()
|
||||
|
||||
# Create a Path object for the image folder
|
||||
folder = Path(image_folder)
|
||||
|
||||
# Check if the image folder exists and is a directory
|
||||
if not folder.is_dir():
|
||||
raise ValueError(f"{image_folder} is not a valid directory.")
|
||||
|
||||
# Create the caption files
|
||||
create_caption_files(image_folder, file_pattern, caption_text, caption_file_ext, overwrite)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
57
tools/convert_images_to_hq_jpg.py
Normal file
57
tools/convert_images_to_hq_jpg.py
Normal file
@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def main():
|
||||
# Define the command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("directory", type=str,
|
||||
help="the directory containing the images to be converted")
|
||||
parser.add_argument("--in_ext", type=str, default="webp",
|
||||
help="the input file extension")
|
||||
parser.add_argument("--quality", type=int, default=95,
|
||||
help="the JPEG quality (0-100)")
|
||||
parser.add_argument("--delete_originals", action="store_true",
|
||||
help="whether to delete the original files after conversion")
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
directory = args.directory
|
||||
in_ext = args.in_ext
|
||||
out_ext = "jpg"
|
||||
quality = args.quality
|
||||
delete_originals = args.delete_originals
|
||||
|
||||
# Create the file pattern string using the input file extension
|
||||
file_pattern = f"*.{in_ext}"
|
||||
|
||||
# Get the list of files in the directory that match the file pattern
|
||||
files = glob.glob(os.path.join(directory, file_pattern))
|
||||
|
||||
# Iterate over the list of files
|
||||
for file in files:
|
||||
# Open the image file
|
||||
img = Image.open(file)
|
||||
|
||||
# Create a new file path with the output file extension
|
||||
new_path = Path(file).with_suffix(f".{out_ext}")
|
||||
|
||||
# Check if the output file already exists
|
||||
if new_path.exists():
|
||||
# Skip the conversion if the output file already exists
|
||||
print(f"Skipping {file} because {new_path} already exists")
|
||||
continue
|
||||
|
||||
# Save the image to the new file as high-quality JPEG
|
||||
img.save(new_path, quality=quality, optimize=True)
|
||||
|
||||
# Optionally, delete the original file
|
||||
if delete_originals:
|
||||
os.remove(file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
57
tools/convert_images_to_webp.py
Normal file
57
tools/convert_images_to_webp.py
Normal file
@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def main():
|
||||
# Define the command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("directory", type=str,
|
||||
help="the directory containing the images to be converted")
|
||||
parser.add_argument("--in_ext", type=str, default="webp",
|
||||
help="the input file extension")
|
||||
parser.add_argument("--delete_originals", action="store_true",
|
||||
help="whether to delete the original files after conversion")
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
directory = args.directory
|
||||
in_ext = args.in_ext
|
||||
delete_originals = args.delete_originals
|
||||
|
||||
# Set the output file extension to .webp
|
||||
out_ext = "webp"
|
||||
|
||||
# Create the file pattern string using the input file extension
|
||||
file_pattern = f"*.{in_ext}"
|
||||
|
||||
# Get the list of files in the directory that match the file pattern
|
||||
files = glob.glob(os.path.join(directory, file_pattern))
|
||||
|
||||
# Iterate over the list of files
|
||||
for file in files:
|
||||
# Open the image file
|
||||
img = Image.open(file)
|
||||
|
||||
# Create a new file path with the output file extension
|
||||
new_path = Path(file).with_suffix(f".{out_ext}")
|
||||
print(new_path)
|
||||
|
||||
# Check if the output file already exists
|
||||
if new_path.exists():
|
||||
# Skip the conversion if the output file already exists
|
||||
print(f"Skipping {file} because {new_path} already exists")
|
||||
continue
|
||||
|
||||
# Save the image to the new file as lossless
|
||||
img.save(new_path, lossless=True)
|
||||
|
||||
# Optionally, delete the original file
|
||||
if delete_originals:
|
||||
os.remove(file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1019
train_db_fixed.py
1019
train_db_fixed.py
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user