diff --git a/README.md b/README.md index 7ffce07..4b9a8d0 100644 --- a/README.md +++ b/README.md @@ -135,3 +135,8 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n - Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision) - Added support for saving learning state (--save_state, --resume) - Added support for logging (--logging_dir) +* 11/21 (v10): + - Added minimum/maximum resolution specification when using Aspect Ratio Bucketing (min_bucket_reso/max_bucket_reso option). + - Added extension specification for caption files (caption_extention). + - Added support for images with .webp extension. + - Added a function that allows captions to learning images and regularized images. \ No newline at end of file diff --git a/examples/caption.ps1 b/examples/caption.ps1 index 4673f87..e61f9b3 100644 --- a/examples/caption.ps1 +++ b/examples/caption.ps1 @@ -6,8 +6,9 @@ $folder = "D:\some\folder\location\" $file_pattern="*.*" $caption_text="some caption text" -$files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File -foreach ($file in $files) -{ - New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text +$files = Get-ChildItem $folder$file_pattern -Include *.png, *.jpg, *.webp -File +foreach ($file in $files) { + if (-not(Test-Path -Path $folder\"$($file.BaseName).txt" -PathType Leaf)) { + New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text + } } \ No newline at end of file diff --git a/train_db_fixed.py b/train_db_fixed.py index 3cf3f5d..2037e78 100644 --- a/train_db_fixed.py +++ b/train_db_fixed.py @@ -5,6 +5,7 @@ # enable reg images in fine-tuning, add dataset_repeats option # v8: supports Diffusers 0.7.2 # v9: add bucketing option +# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth import time from torch.autograd.function import Function @@ -143,7 +144,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): ) # bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - def make_buckets_with_caching(self, enable_bucket, vae): + def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size): self.enable_bucket = enable_bucket cache_latents = vae is not None @@ -160,7 +161,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): # bucketingを用意する if enable_bucket: - bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height)) + bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size) else: # bucketはひとつだけ、すべての画像は同じ解像度 bucket_resos = [(self.width, self.height)] @@ -387,7 +388,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): latents_list.append(latents) # captionを処理する - if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする + if self.shuffle_caption: # captionのshuffleをする tokens = caption.strip().split(",") random.shuffle(tokens) caption = ",".join(tokens).strip() @@ -400,7 +401,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): else: # paddingする input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids - + example = {} example['loss_weights'] = torch.FloatTensor(loss_weights) example['input_ids'] = input_ids @@ -1133,44 +1134,55 @@ def train(args): set_seed(args.seed) # 学習データを用意する + def read_caption(img_path): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + args.caption_extention, base_name_face_det + args.caption_extention] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding='utf-8') as f: + caption = f.readlines()[0].strip() + break + return caption + def load_dreambooth_dir(dir): tokens = os.path.basename(dir).split('_') try: n_repeats = int(tokens[0]) except ValueError as e: - # print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}") - # raise e return 0, [] caption = '_'.join(tokens[1:]) print(f"found directory {n_repeats}_{caption}") - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp")) - return n_repeats, [(ip, caption) for ip in img_paths] + img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ + glob.glob(os.path.join(dir, "*.webp")) + + # 画像ファイルごとにプロンプトを読み込み、もしあれば連結する + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path) + captions.append(caption + ("" if cap_for_img is None else cap_for_img)) + + return n_repeats, list(zip(img_paths, captions)) print("prepare train images.") train_img_path_captions = [] if fine_tuning: - img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) + img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \ + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) for img_path in tqdm(img_paths): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + '.txt', base_name + '.caption', base_name_face_det+'.txt', base_name_face_det+'.caption'] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - caption = f.readlines()[0].strip() - break - - assert caption is not None and len(caption) > 0, f"no caption / キャプションファイルが見つからないか、captionが空です: {cap_paths}" + caption = read_caption(img_path) + assert caption is not None and len( + caption) > 0, f"no caption for image. check caption_extention option / キャプションファイルが見つからないかcaptionが空です。caption_extentionオプションを確認してください: {img_path}" train_img_path_captions.append((img_path, caption)) @@ -1201,8 +1213,12 @@ def train(args): resolution = tuple([int(r) for r in args.resolution.split(',')]) if len(resolution) == 1: resolution = (resolution[0], resolution[0]) - assert len( - resolution) == 2, f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + assert len(resolution) == 2, \ + f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + + if args.enable_bucket: + assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or greater than resolution / min_bucket_resoは解像度の数値以上で指定してください" + assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or less than resolution / max_bucket_resoは解像度の数値以下で指定してください" if args.face_crop_aug_range is not None: face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) @@ -1221,7 +1237,8 @@ def train(args): args.shuffle_caption, args.no_token_padding, args.debug_dataset) if args.debug_dataset: - train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る + train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, + args.max_bucket_reso) # デバッグ用にcacheなしで作る print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. / Escキーで中断、終了します") for example in train_dataset: @@ -1282,14 +1299,20 @@ def train(args): # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - train_dataset.make_buckets_with_caching(args.enable_bucket, vae) + train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso) del vae if torch.cuda.is_available(): torch.cuda.empty_cache() else: - train_dataset.make_buckets_with_caching(args.enable_bucket, None) + train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso) vae.requires_grad_(False) + vae.eval() + + unet.requires_grad_(True) # 念のため追加 + text_encoder.requires_grad_(True) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -1343,7 +1366,7 @@ def train(args): print("running training / 学習開始") print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") - print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}") + print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -1457,7 +1480,7 @@ def train(args): text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() - + if args.save_state: print("saving last state.") accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) @@ -1720,11 +1743,13 @@ def replace_unet_cross_attn_to_xformers(): k_in = self.to_k(context) v_in = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format + # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる out = rearrange(out, 'b n h d -> b n (h d)', h=h) + # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # diffusers 0.6.0 if type(self.to_out) is torch.nn.Sequential: @@ -1747,7 +1772,8 @@ if __name__ == '__main__': parser.add_argument("--fine_tuning", action="store_true", help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする") parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする") + help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") + parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption files / 読み込むcaptionファイルの拡張子") parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") parser.add_argument("--dataset_repeats", type=int, default=None, @@ -1758,8 +1784,7 @@ if __name__ == '__main__': help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, - help="saved state to resume training / 学習再開するモデルのstate") + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") @@ -1785,6 +1810,8 @@ if __name__ == '__main__': help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") parser.add_argument("--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")