Update to latest dev code of kohya_s. WIP
This commit is contained in:
parent
2626214f8a
commit
2486af9903
@ -144,9 +144,10 @@ Then redo the installation instruction within the kohya_ss venv.
|
||||
## Change history
|
||||
|
||||
* 2023/02/04 (v20.6.1)
|
||||
- Add new LoRA resize GUI
|
||||
- ``--persistent_data_loader_workers`` option is added to ``fine_tune.py``, ``train_db.py`` and ``train_network.py``. This option may significantly reduce the waiting time between epochs. Thanks to hitomi!
|
||||
- ``--debug_dataset`` option is now working on non-Windows environment. Thanks to tsukimiya!
|
||||
- ``networks/resize_lora.py`` script is added. This can approximate the higher-rank (dim) LoRA model by a lower-rank LoRA model, e.g. 128 by 4. Thanks to mgz-dev!
|
||||
- ``networks/resize_lora.py`` script is added. This can approximate the higher-rank (dim) LoRA model by a lower-rank LoRA model, e.g. 128 to 4. Thanks to mgz-dev!
|
||||
- ``--help`` option shows usage.
|
||||
- Currently the metadata is not copied. This will be fixed in the near future.
|
||||
* 2023/02/03 (v20.6.0)
|
||||
|
@ -33,6 +33,7 @@ def train(args):
|
||||
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||
args.dataset_repeats, args.debug_dataset)
|
||||
train_dataset.make_buckets()
|
||||
@ -163,7 +164,7 @@ def train(args):
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
@ -200,6 +201,8 @@ def train(args):
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
@ -52,6 +52,10 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
||||
|
||||
|
||||
def main(args):
|
||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
@ -77,32 +81,41 @@ def main(args):
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||
max_reso, args.min_bucket_reso, args.max_bucket_reso)
|
||||
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
|
||||
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
|
||||
if not args.bucket_no_upscale:
|
||||
bucket_manager.make_buckets()
|
||||
else:
|
||||
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
||||
|
||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
||||
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
|
||||
buckets_imgs = [[] for _ in range(len(bucket_resos))]
|
||||
bucket_counts = [0 for _ in range(len(bucket_resos))]
|
||||
img_ar_errors = []
|
||||
|
||||
def process_batch(is_last):
|
||||
for j in range(len(buckets_imgs)):
|
||||
bucket = buckets_imgs[j]
|
||||
for bucket in bucket_manager.buckets:
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
|
||||
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
|
||||
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
|
||||
f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
||||
|
||||
for (image_key, _, _), latent in zip(bucket, latents):
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
||||
np.savez(npz_file_name, latent)
|
||||
|
||||
# flip
|
||||
if args.flip_aug:
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
|
||||
for (image_key, _, _), latent in zip(bucket, latents):
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
||||
np.savez(npz_file_name, latent)
|
||||
else:
|
||||
# remove existing flipped npz
|
||||
for image_key, _ in bucket:
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
||||
if os.path.isfile(npz_file_name):
|
||||
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
||||
os.remove(npz_file_name)
|
||||
|
||||
bucket.clear()
|
||||
|
||||
@ -114,6 +127,7 @@ def main(args):
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
bucket_counts = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
if data_entry[0] is None:
|
||||
continue
|
||||
@ -134,30 +148,25 @@ def main(args):
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
# 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||
aspect_ratio = image.width / image.height
|
||||
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||
bucket_id = np.abs(ar_errors).argmin()
|
||||
reso = bucket_resos[bucket_id]
|
||||
ar_error = ar_errors[bucket_id]
|
||||
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||
|
||||
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
|
||||
|
||||
# どのサイズにリサイズするか→トリミングする方向で
|
||||
if ar_error <= 0: # 横が長い→縦を合わせる
|
||||
scale = reso[1] / image.height
|
||||
else:
|
||||
scale = reso[0] / image.width
|
||||
|
||||
resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
|
||||
|
||||
# print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
|
||||
# bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
|
||||
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
||||
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
||||
|
||||
if not args.bucket_no_upscale:
|
||||
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
||||
assert resized_size[0] == reso[0] or resized_size[1] == reso[
|
||||
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error resized size is small: {resized_size}, {reso}"
|
||||
|
||||
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||
if args.skip_existing:
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
||||
@ -180,22 +189,24 @@ def main(args):
|
||||
# 画像をリサイズしてトリミングする
|
||||
# PILにinter_areaがないのでcv2で……
|
||||
image = np.array(image)
|
||||
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
if resized_size[0] > reso[0]:
|
||||
trim_size = resized_size[0] - reso[0]
|
||||
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
||||
elif resized_size[1] > reso[1]:
|
||||
|
||||
if resized_size[1] > reso[1]:
|
||||
trim_size = resized_size[1] - reso[1]
|
||||
image = image[trim_size//2:trim_size//2 + reso[1]]
|
||||
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
|
||||
# # debug
|
||||
# cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
|
||||
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
|
||||
|
||||
# バッチへ追加
|
||||
buckets_imgs[bucket_id].append((image_key, reso, image))
|
||||
bucket_counts[bucket_id] += 1
|
||||
metadata[image_key]['train_resolution'] = reso
|
||||
bucket_manager.add_image(reso, (image_key, image))
|
||||
|
||||
# バッチを推論するか判定して推論する
|
||||
process_batch(False)
|
||||
@ -203,7 +214,10 @@ def main(args):
|
||||
# 残りを処理する
|
||||
process_batch(True)
|
||||
|
||||
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
|
||||
bucket_manager.sort()
|
||||
for i, reso in enumerate(bucket_manager.resos):
|
||||
count = bucket_counts.get(reso, 0)
|
||||
if count > 0:
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
@ -230,6 +244,10 @@ if __name__ == '__main__':
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||
parser.add_argument("--full_path", action="store_true",
|
||||
|
@ -1163,15 +1163,14 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
|
||||
|
||||
resos = list(resos)
|
||||
resos.sort()
|
||||
|
||||
aspect_ratios = [w / h for w, h in resos]
|
||||
return resos, aspect_ratios
|
||||
return resos
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
resos, aspect_ratios = make_bucket_resolutions((512, 768))
|
||||
resos = make_bucket_resolutions((512, 768))
|
||||
print(len(resos))
|
||||
print(resos)
|
||||
aspect_ratios = [w / h for w, h in resos]
|
||||
print(aspect_ratios)
|
||||
|
||||
ars = set()
|
||||
|
@ -4,7 +4,7 @@ import argparse
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from typing import NamedTuple
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
from accelerate import Accelerator
|
||||
from torch.autograd.function import Function
|
||||
import glob
|
||||
@ -55,16 +55,121 @@ class ImageInfo():
|
||||
self.caption: str = caption
|
||||
self.is_reg: bool = is_reg
|
||||
self.absolute_path: str = absolute_path
|
||||
self.image_size: tuple[int, int] = None
|
||||
self.bucket_reso: tuple[int, int] = None
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
self.latents: torch.Tensor = None
|
||||
self.latents_flipped: torch.Tensor = None
|
||||
self.latents_npz: str = None
|
||||
self.latents_npz_flipped: str = None
|
||||
|
||||
|
||||
class BucketManager():
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
self.no_upscale = no_upscale
|
||||
if max_reso is None:
|
||||
self.max_reso = None
|
||||
self.max_area = None
|
||||
else:
|
||||
self.max_reso = max_reso
|
||||
self.max_area = max_reso[0] * max_reso[1]
|
||||
self.min_size = min_size
|
||||
self.max_size = max_size
|
||||
self.reso_steps = reso_steps
|
||||
|
||||
self.resos = []
|
||||
self.reso_to_id = {}
|
||||
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key
|
||||
|
||||
def add_image(self, reso, image):
|
||||
bucket_id = self.reso_to_id[reso]
|
||||
self.buckets[bucket_id].append(image)
|
||||
|
||||
def shuffle(self):
|
||||
for bucket in self.buckets:
|
||||
random.shuffle(bucket)
|
||||
|
||||
def sort(self):
|
||||
# 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す
|
||||
sorted_resos = self.resos.copy()
|
||||
sorted_resos.sort()
|
||||
|
||||
sorted_buckets = []
|
||||
sorted_reso_to_id = {}
|
||||
for i, reso in enumerate(sorted_resos):
|
||||
bucket_id = self.reso_to_id[reso]
|
||||
sorted_buckets.append(self.buckets[bucket_id])
|
||||
sorted_reso_to_id[reso] = i
|
||||
|
||||
self.resos = sorted_resos
|
||||
self.buckets = sorted_buckets
|
||||
self.reso_to_id = sorted_reso_to_id
|
||||
|
||||
def make_buckets(self):
|
||||
resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
|
||||
self.set_predefined_resos(resos)
|
||||
|
||||
def set_predefined_resos(self, resos):
|
||||
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
||||
self.predefined_resos = resos.copy()
|
||||
self.predefined_resos_set = set(resos)
|
||||
self.predifined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||
|
||||
def add_if_new_reso(self, reso):
|
||||
if reso not in self.reso_to_id:
|
||||
bucket_id = len(self.resos)
|
||||
self.reso_to_id[reso] = bucket_id
|
||||
self.resos.append(reso)
|
||||
self.buckets.append([])
|
||||
# print(reso, bucket_id, len(self.buckets))
|
||||
|
||||
def select_bucket(self, image_width, image_height):
|
||||
aspect_ratio = image_width / image_height
|
||||
if not self.no_upscale:
|
||||
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
||||
reso = (image_width, image_height)
|
||||
if reso in self.predefined_resos_set:
|
||||
pass
|
||||
else:
|
||||
ar_errors = self.predifined_aspect_ratios - aspect_ratio
|
||||
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
||||
reso = self.predefined_resos[predefined_bucket_id]
|
||||
|
||||
ar_reso = reso[0] / reso[1]
|
||||
if aspect_ratio > ar_reso: # 横が長い→縦を合わせる
|
||||
scale = reso[1] / image_height
|
||||
else:
|
||||
scale = reso[0] / image_width
|
||||
|
||||
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
||||
# print("use predef", image_width, image_height, reso, resized_size)
|
||||
else:
|
||||
if image_width * image_height > self.max_area:
|
||||
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
||||
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
||||
resized_height = self.max_area / resized_width
|
||||
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
|
||||
|
||||
resized_size = (int(resized_width + .5), int(resized_height + .5))
|
||||
else:
|
||||
resized_size = (image_width, image_height) # リサイズは不要
|
||||
|
||||
# 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする)
|
||||
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
|
||||
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
|
||||
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
|
||||
|
||||
reso = (bucket_width, bucket_height)
|
||||
|
||||
self.add_if_new_reso(reso)
|
||||
|
||||
ar_error = (reso[0] / reso[1]) - aspect_ratio
|
||||
return reso, resized_size, ar_error
|
||||
|
||||
|
||||
class BucketBatchIndex(NamedTuple):
|
||||
bucket_index: int
|
||||
bucket_batch_size: int
|
||||
batch_index: int
|
||||
|
||||
|
||||
@ -85,11 +190,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.token_padding_disabled = False
|
||||
self.dataset_dirs_info = {}
|
||||
self.reg_dataset_dirs_info = {}
|
||||
self.tag_frequency = {}
|
||||
|
||||
self.enable_bucket = False
|
||||
self.bucket_manager: BucketManager = None # not initialized
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.tag_frequency = {}
|
||||
self.bucket_info = None
|
||||
self.bucket_reso_steps = None
|
||||
self.bucket_no_upscale = None
|
||||
self.bucket_info = None # for metadata
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
@ -113,7 +222,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
||||
|
||||
self.image_data: dict[str, ImageInfo] = {}
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
@ -215,66 +324,72 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
print("prepare dataset")
|
||||
|
||||
bucket_resos = self.bucket_resos
|
||||
bucket_aspect_ratios = np.array(self.bucket_aspect_ratios)
|
||||
|
||||
# bucketを作成する
|
||||
# bucketを作成し、画像をbucketに振り分ける
|
||||
if self.enable_bucket:
|
||||
if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み
|
||||
self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
|
||||
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
|
||||
if not self.bucket_no_upscale:
|
||||
self.bucket_manager.make_buckets()
|
||||
else:
|
||||
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
||||
|
||||
img_ar_errors = []
|
||||
for image_info in self.image_data.values():
|
||||
# bucketを決める
|
||||
image_width, image_height = image_info.image_size
|
||||
aspect_ratio = image_width / image_height
|
||||
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
|
||||
|
||||
bucket_id = np.abs(ar_errors).argmin()
|
||||
image_info.bucket_reso = bucket_resos[bucket_id]
|
||||
# print(image_info.image_key, image_info.bucket_reso)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
|
||||
ar_error = ar_errors[bucket_id]
|
||||
img_ar_errors.append(ar_error)
|
||||
self.bucket_manager.sort()
|
||||
else:
|
||||
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
|
||||
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
|
||||
for image_info in self.image_data.values():
|
||||
image_info.bucket_reso = bucket_resos[0] # bucket_resos contains (width, height) only
|
||||
|
||||
# 画像をbucketに分割する
|
||||
self.buckets: list[str] = [[] for _ in range(len(bucket_resos))]
|
||||
reso_to_index = {}
|
||||
for i, reso in enumerate(bucket_resos):
|
||||
reso_to_index[reso] = i
|
||||
image_width, image_height = image_info.image_size
|
||||
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
||||
|
||||
for image_info in self.image_data.values():
|
||||
bucket_index = reso_to_index[image_info.bucket_reso]
|
||||
for _ in range(image_info.num_repeats):
|
||||
self.buckets[bucket_index].append(image_info.image_key)
|
||||
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
|
||||
|
||||
# bucket情報を表示、格納する
|
||||
if self.enable_bucket:
|
||||
self.bucket_info = {"buckets": {}}
|
||||
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
|
||||
# only show bucket info if there is an actual image in it
|
||||
if len(img_keys) > 0:
|
||||
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
||||
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
||||
count = len(bucket)
|
||||
if count > 0:
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
# 参照用indexを作る
|
||||
self.buckets_indices: list(BucketBatchIndex) = []
|
||||
for bucket_index, bucket in enumerate(self.buckets):
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List(BucketBatchIndex) = []
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# そのためバッチサイズを画像種類までに制限する
|
||||
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# TODO 正則化画像をepochまたがりで利用する仕組み
|
||||
num_of_image_types = len(set(bucket))
|
||||
bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index))
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
def shuffle_buckets(self):
|
||||
random.shuffle(self.buckets_indices)
|
||||
for bucket in self.buckets:
|
||||
random.shuffle(bucket)
|
||||
self.bucket_manager.shuffle()
|
||||
|
||||
def load_image(self, image_path):
|
||||
image = Image.open(image_path)
|
||||
@ -283,28 +398,30 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
def resize_and_trim(self, image, reso):
|
||||
def trim_and_resize_if_required(self, image, reso, resized_size):
|
||||
image_height, image_width = image.shape[0:2]
|
||||
ar_img = image_width / image_height
|
||||
ar_reso = reso[0] / reso[1]
|
||||
if ar_img > ar_reso: # 横が長い→縦を合わせる
|
||||
scale = reso[1] / image_height
|
||||
else:
|
||||
scale = reso[0] / image_width
|
||||
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
if resized_size[0] > reso[0]:
|
||||
trim_size = resized_size[0] - reso[0]
|
||||
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
||||
elif resized_size[1] > reso[1]:
|
||||
trim_size = resized_size[1] - reso[1]
|
||||
image = image[trim_size//2:trim_size//2 + reso[1]]
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
|
||||
f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
if image_width > reso[0]:
|
||||
trim_size = image_width - reso[0]
|
||||
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
||||
# print("w", trim_size, p)
|
||||
image = image[:, p:p + reso[0]]
|
||||
if image_height > reso[1]:
|
||||
trim_size = image_height - reso[1]
|
||||
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
||||
# print("h", trim_size, p)
|
||||
image = image[p:p + reso[1]]
|
||||
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
return image
|
||||
|
||||
def cache_latents(self, vae):
|
||||
# TODO ここを高速化したい
|
||||
print("caching latents.")
|
||||
for info in tqdm(self.image_data.values()):
|
||||
if info.latents_npz is not None:
|
||||
@ -316,7 +433,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
|
||||
image = self.load_image(info.absolute_path)
|
||||
image = self.resize_and_trim(image, info.bucket_reso)
|
||||
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
||||
|
||||
img_tensor = self.image_transforms(image)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||
@ -406,8 +523,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if index == 0:
|
||||
self.shuffle_buckets()
|
||||
|
||||
bucket = self.buckets[self.buckets_indices[index].bucket_index]
|
||||
image_index = self.buckets_indices[index].batch_index * self.batch_size
|
||||
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
||||
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
||||
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
||||
|
||||
loss_weights = []
|
||||
captions = []
|
||||
@ -415,7 +533,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
latents_list = []
|
||||
images = []
|
||||
|
||||
for image_key in bucket[image_index:image_index + self.batch_size]:
|
||||
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
@ -433,7 +551,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img = self.resize_and_trim(img, image_info.bucket_reso)
|
||||
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
||||
@ -490,7 +608,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
class DreamBoothDataset(BaseDataset):
|
||||
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
||||
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
||||
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
||||
|
||||
@ -505,13 +623,15 @@ class DreamBoothDataset(BaseDataset):
|
||||
if self.enable_bucket:
|
||||
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.bucket_resos = [(self.width, self.height)]
|
||||
self.bucket_aspect_ratios = [self.width / self.height]
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path):
|
||||
# captionの候補ファイル名を作る
|
||||
@ -582,7 +702,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_reg_images = 0
|
||||
if reg_data_dir:
|
||||
print("prepare reg images.")
|
||||
reg_infos: list[ImageInfo] = []
|
||||
reg_infos: List[ImageInfo] = []
|
||||
|
||||
reg_dirs = os.listdir(reg_data_dir)
|
||||
for dir in reg_dirs:
|
||||
@ -621,7 +741,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
|
||||
class FineTuningDataset(BaseDataset):
|
||||
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
||||
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
||||
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
||||
|
||||
@ -660,7 +780,7 @@ class FineTuningDataset(BaseDataset):
|
||||
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get('train_resolution')
|
||||
|
||||
if not self.color_aug:
|
||||
if not self.color_aug and not self.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
||||
|
||||
@ -672,7 +792,8 @@ class FineTuningDataset(BaseDataset):
|
||||
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
|
||||
# check existence of all npz files
|
||||
if not self.color_aug:
|
||||
use_npz_latents = not (self.color_aug or self.random_crop)
|
||||
if use_npz_latents:
|
||||
npz_any = False
|
||||
npz_all = True
|
||||
for image_info in self.image_data.values():
|
||||
@ -687,13 +808,15 @@ class FineTuningDataset(BaseDataset):
|
||||
break
|
||||
|
||||
if not npz_any:
|
||||
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
|
||||
use_npz_latents = False
|
||||
print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
||||
elif not npz_all:
|
||||
use_npz_latents = False
|
||||
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
||||
if self.flip_aug:
|
||||
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
# else:
|
||||
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
||||
|
||||
# check min/max bucket size
|
||||
sizes = set()
|
||||
@ -707,30 +830,34 @@ class FineTuningDataset(BaseDataset):
|
||||
resos.add(tuple(image_info.image_size))
|
||||
|
||||
if sizes is None:
|
||||
if use_npz_latents:
|
||||
use_npz_latents = False
|
||||
print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
|
||||
|
||||
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
else:
|
||||
self.bucket_resos = [(self.width, self.height)]
|
||||
self.bucket_aspect_ratios = [self.width / self.height]
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
if not enable_bucket:
|
||||
print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
||||
print("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
||||
self.enable_bucket = True
|
||||
self.bucket_resos = list(resos)
|
||||
self.bucket_resos.sort()
|
||||
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
|
||||
|
||||
self.min_bucket_reso = min([min(reso) for reso in resos])
|
||||
self.max_bucket_reso = max([max(reso) for reso in resos])
|
||||
assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
||||
|
||||
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
||||
self.bucket_manager = BucketManager(False, None, None, None, None)
|
||||
self.bucket_manager.set_predefined_resos(resos)
|
||||
|
||||
# npz情報をきれいにしておく
|
||||
if not use_npz_latents:
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
|
||||
def image_key_to_npz_file(self, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
@ -760,7 +887,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
k = 0
|
||||
for example in train_dataset:
|
||||
for i, example in enumerate(train_dataset):
|
||||
if example['latents'] is not None:
|
||||
print("sample has latents from npz file")
|
||||
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
||||
@ -778,7 +905,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
cv2.destroyAllWindows()
|
||||
if k == 27:
|
||||
break
|
||||
if k == 27 or example['images'] is None:
|
||||
if k == 27 or (example['images'] is None and i >= 8):
|
||||
break
|
||||
|
||||
|
||||
@ -1254,6 +1381,10 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
||||
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth dataset
|
||||
@ -1285,6 +1416,7 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
|
||||
if args.cache_latents:
|
||||
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
||||
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
||||
|
||||
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
||||
if args.resolution is not None:
|
||||
@ -1296,14 +1428,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
|
||||
if args.face_crop_aug_range is not None:
|
||||
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
||||
assert len(args.face_crop_aug_range) == 2, \
|
||||
assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
|
||||
f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
|
||||
else:
|
||||
args.face_crop_aug_range = None
|
||||
|
||||
if support_metadata:
|
||||
if args.in_json is not None and args.color_aug:
|
||||
print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます")
|
||||
if args.in_json is not None and (args.color_aug or args.random_crop):
|
||||
print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
|
||||
|
||||
|
||||
def load_tokenizer(args: argparse.Namespace):
|
||||
|
@ -358,6 +358,9 @@ def train_model(
|
||||
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
||||
|
||||
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
|
||||
|
||||
run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop'
|
||||
|
||||
if v2:
|
||||
run_cmd += ' --v2'
|
||||
if v_parameterization:
|
||||
|
@ -35,8 +35,9 @@ def train(args):
|
||||
|
||||
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
||||
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
if args.no_token_padding:
|
||||
train_dataset.disable_token_padding()
|
||||
train_dataset.make_buckets()
|
||||
|
@ -120,13 +120,16 @@ def train(args):
|
||||
print("Use DreamBooth method.")
|
||||
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
||||
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
||||
args.random_crop, args.debug_dataset)
|
||||
else:
|
||||
print("Train with captions.")
|
||||
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||
args.dataset_repeats, args.debug_dataset)
|
||||
train_dataset.make_buckets()
|
||||
|
@ -143,13 +143,15 @@ def train(args):
|
||||
print("Use DreamBooth method.")
|
||||
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
||||
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
else:
|
||||
print("Train with captions.")
|
||||
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||
args.dataset_repeats, args.debug_dataset)
|
||||
|
||||
@ -217,7 +219,7 @@ def train(args):
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
@ -312,7 +314,8 @@ def train(args):
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) # weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
|
Loading…
Reference in New Issue
Block a user