From 09d3a72cd8041f9de313e439cac52727c9e4dc99 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 7 Feb 2023 20:58:35 -0500 Subject: [PATCH] Adding support for caption dropout --- dreambooth_gui.py | 7 +++++ fine_tune.py | 9 +++++- finetune_gui.py | 7 +++++ library/common_gui.py | 16 ++++++++++ library/train_util.py | 44 ++++++++++++++++++++++----- lora_gui.py | 13 ++++++-- networks/lora.py | 3 +- textual_inversion_gui.py | 7 +++++ tools/resize_images_to_resolutions.py | 8 +++-- train_db.py | 9 +++++- train_network.py | 17 ++++++++--- train_textual_inversion.py | 2 +- 12 files changed, 121 insertions(+), 21 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index cdcb85b..2e9cfdb 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -88,6 +88,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -177,6 +178,7 @@ def open_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -250,6 +252,7 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -416,6 +419,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) print(run_cmd) @@ -627,6 +632,7 @@ def dreambooth_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -695,6 +701,7 @@ def dreambooth_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_open_config.click( diff --git a/fine_tune.py b/fine_tune.py index 6a95886..e743a34 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -36,6 +36,10 @@ def train(args): args.bucket_reso_steps, args.bucket_no_upscale, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.dataset_repeats, args.debug_dataset) + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs) + train_dataset.make_buckets() if args.debug_dataset: @@ -226,6 +230,9 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + + train_dataset.epoch_current = epoch + 1 + for m in training_models: m.train() @@ -332,7 +339,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True) + train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_sd_saving_arguments(parser) diff --git a/finetune_gui.py b/finetune_gui.py index 80c887f..f5aad9d 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -84,6 +84,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -179,6 +180,7 @@ def open_config_file( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -259,6 +261,7 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # create caption json file if generate_caption_database: @@ -405,6 +408,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) print(run_cmd) @@ -614,6 +619,7 @@ def finetune_tab(): bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -678,6 +684,7 @@ def finetune_tab(): bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_run.click(train_model, inputs=settings_list) diff --git a/library/common_gui.py b/library/common_gui.py index c93b04e..a78532d 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -563,6 +563,15 @@ def gradio_advanced_training(): random_crop = gr.Checkbox( label='Random crop instead of center crop', value=False ) + with gr.Row(): + caption_dropout_every_n_epochs = gr.Number( + label="Dropout caption every n epochs", + value=0 + ) + caption_dropout_rate = gr.Number( + label="Rate of caption dropout", + value=0 + ) with gr.Row(): save_state = gr.Checkbox(label='Save training state', value=False) resume = gr.Textbox( @@ -599,6 +608,7 @@ def gradio_advanced_training(): bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) @@ -622,6 +632,12 @@ def run_cmd_advanced_training(**kwargs): f' --keep_tokens="{kwargs.get("keep_tokens", "")}"' if int(kwargs.get('keep_tokens', 0)) > 0 else '', + f' --caption_dropout_every_n_epochs="{kwargs.get("caption_dropout_every_n_epochs", "")}"' + if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 + else '', + f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"' + if float(kwargs.get('caption_dropout_rate', 0)) > 0 + else '', f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' if int(kwargs.get('bucket_reso_steps', 64)) >= 1 diff --git a/library/train_util.py b/library/train_util.py index 379b0b8..eb1ec12 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -113,7 +113,7 @@ class BucketManager(): # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく self.predefined_resos = resos.copy() self.predefined_resos_set = set(resos) - self.predifined_aspect_ratios = np.array([w / h for w, h in resos]) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) def add_if_new_reso(self, reso): if reso not in self.reso_to_id: @@ -135,7 +135,7 @@ class BucketManager(): if reso in self.predefined_resos_set: pass else: - ar_errors = self.predifined_aspect_ratios - aspect_ratio + ar_errors = self.predefined_aspect_ratios - aspect_ratio predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの reso = self.predefined_resos[predefined_bucket_id] @@ -223,6 +223,11 @@ class BaseDataset(torch.utils.data.Dataset): self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + # TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう + self.epoch_current: int = int(0) + self.dropout_rate: float = 0 + self.dropout_every_n_epochs: int = None + # augmentation flip_p = 0.5 if flip_aug else 0.0 if color_aug: @@ -247,6 +252,12 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs): + # 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく + # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく) + self.dropout_rate = dropout_rate + self.dropout_every_n_epochs = dropout_every_n_epochs + def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) self.tag_frequency[dir_name] = frequency_for_dir @@ -265,7 +276,7 @@ class BaseDataset(torch.utils.data.Dataset): def process_caption(self, caption): if self.shuffle_caption: - tokens = caption.strip().split(",") + tokens = [t.strip() for t in caption.strip().split(",")] if self.shuffle_keep_tokens is None: random.shuffle(tokens) else: @@ -274,7 +285,7 @@ class BaseDataset(torch.utils.data.Dataset): tokens = tokens[self.shuffle_keep_tokens:] random.shuffle(tokens) tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() + caption = ", ".join(tokens) for str_from, str_to in self.replacements.items(): if str_from == "": @@ -598,7 +609,18 @@ class BaseDataset(torch.utils.data.Dataset): images.append(image) latents_list.append(latents) - caption = self.process_caption(image_info.caption) + # dropoutの決定 + is_drop_out = False + if self.dropout_rate > 0 and random.random() < self.dropout_rate: + is_drop_out = True + if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0: + is_drop_out = True + + if is_drop_out: + caption = "" + print(f"Drop caption out: {self.process_caption(image_info.caption)}") + else: + caption = self.process_caption(image_info.caption) captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future input_ids_list.append(self.get_input_ids(caption)) @@ -1377,7 +1399,7 @@ def verify_training_args(args: argparse.Namespace): print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") -def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool): +def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool): # dataset common parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--shuffle_caption", action="store_true", @@ -1408,6 +1430,14 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b parser.add_argument("--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument("--caption_dropout_rate", type=float, default=0, + help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") + parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") + if support_dreambooth: # DreamBooth dataset parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") @@ -1718,4 +1748,4 @@ class ImageLoadingDataset(torch.utils.data.Dataset): return (tensor_pil, img_path) -# endregion \ No newline at end of file +# endregion diff --git a/lora_gui.py b/lora_gui.py index 04c0cc1..d48fb5a 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -99,6 +99,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -195,6 +196,7 @@ def open_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -275,7 +277,8 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, -): + caption_dropout_every_n_epochs, caption_dropout_rate, +): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') return @@ -380,7 +383,7 @@ def train_model( run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"' - run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop' + # run_cmd += f' --caption_dropout_rate="0.1" --caption_dropout_every_n_epochs=1' # --random_crop' if v2: run_cmd += ' --v2' @@ -440,7 +443,7 @@ def train_model( else: run_cmd += f' --lr_scheduler_num_cycles="{epoch}"' if not lr_scheduler_power == '': - run_cmd += f' --output_name="{lr_scheduler_power}"' + run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"' run_cmd += run_cmd_training( learning_rate=learning_rate, @@ -476,6 +479,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) print(run_cmd) @@ -725,6 +730,7 @@ def lora_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -805,6 +811,7 @@ def lora_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_open_config.click( diff --git a/networks/lora.py b/networks/lora.py index 174feda..a1f38c1 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,6 +5,7 @@ import math import os +from typing import List import torch from library import train_util @@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module): self.alpha = alpha # create module instances - def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: + def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: loras = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index b34ca6d..d7b86ef 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -94,6 +94,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -193,6 +194,7 @@ def open_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -272,6 +274,7 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -453,6 +456,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) run_cmd += f' --token_string="{token_string}"' run_cmd += f' --init_word="{init_word}"' @@ -709,6 +714,7 @@ def ti_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -783,6 +789,7 @@ def ti_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_open_config.click( diff --git a/tools/resize_images_to_resolutions.py b/tools/resize_images_to_resolutions.py index 5492f1c..e55b285 100644 --- a/tools/resize_images_to_resolutions.py +++ b/tools/resize_images_to_resolutions.py @@ -4,7 +4,7 @@ import argparse import shutil import math -def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2): +def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=1): # Split the max_resolution string by "," and strip any whitespaces max_resolutions = [res.strip() for res in max_resolution.split(',')] @@ -57,7 +57,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Split filename into base and extension base, ext = os.path.splitext(filename) new_filename = base + '+' + max_resolution + '.jpg' - + + # copy caption file with right name if one exist + if os.path.exists(os.path.join(src_img_folder, base + '.txt')): + shutil.copy(os.path.join(src_img_folder, base + '.txt'), os.path.join(dst_img_folder, new_filename + '.txt')) + # Save resized image in dst_img_folder cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") diff --git a/train_db.py b/train_db.py index d1bbc07..51f5038 100644 --- a/train_db.py +++ b/train_db.py @@ -38,8 +38,13 @@ def train(args): args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps, args.bucket_no_upscale, args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + if args.no_token_padding: train_dataset.disable_token_padding() + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs) + train_dataset.make_buckets() if args.debug_dataset: @@ -204,6 +209,8 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.epoch_current = epoch + 1 + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() # train==True is required to enable gradient_checkpointing @@ -327,7 +334,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, False) + train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) train_util.add_sd_saving_arguments(parser) diff --git a/train_network.py b/train_network.py index 3e8f4e7..f3ca417 100644 --- a/train_network.py +++ b/train_network.py @@ -120,18 +120,22 @@ def train(args): print("Use DreamBooth method.") train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.bucket_reso_steps, args.bucket_no_upscale, + args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) else: print("Train with captions.") train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, + args.bucket_reso_steps, args.bucket_no_upscale, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.dataset_repeats, args.debug_dataset) + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs) + train_dataset.make_buckets() if args.debug_dataset: @@ -376,6 +380,9 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + + train_dataset.epoch_current = epoch + 1 + metadata["ss_epoch"] = str(epoch+1) network.on_epoch_start(text_encoder, unet) @@ -509,7 +516,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True) + train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 7a8370c..d3e558a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -478,7 +478,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True) + train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],