diff --git a/README.md b/README.md index 3425f0f..a85f8fe 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,14 @@ Then redo the installation instruction within the kohya_ss venv. ## Change history +* 2023/02/09 (v20.7.1) + - Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource! + - ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout). + - ``--caption_dropout_every_n_epochs`` option specifies how many epochs to drop captions. If ``3`` is specified, in epoch 3, 6, 9 ..., images are trained with all captions empty. Default is None (no dropout). + - ``--caption_tag_dropout_rate`` option specified the dropout rate for tags (comma separated tokens) (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the tag is removed from the caption. If ``--keep_tokens`` option is set, these tokens (tags) are not dropped. Default is 0 (no droupout). + - The bulk image downsampling script is added. Documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) (in Jpanaese). Thanks to bmaltais! + - Typo check is added. Thanks to shirayu! + - Add option to autolaunch the GUI in a browser and set the server_port. USe either `gui.ps1 --inbrowser --server_port 3456`or `gui.cmd -inbrowser -server_port 3456` * 2023/02/06 (v20.7.0) - ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``. - ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before. diff --git a/dreambooth_gui.py b/dreambooth_gui.py index cdcb85b..2e9cfdb 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -88,6 +88,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -177,6 +178,7 @@ def open_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -250,6 +252,7 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -416,6 +419,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) print(run_cmd) @@ -627,6 +632,7 @@ def dreambooth_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -695,6 +701,7 @@ def dreambooth_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_open_config.click( diff --git a/fine_tune.py b/fine_tune.py index 6a95886..5292153 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -36,6 +36,10 @@ def train(args): args.bucket_reso_steps, args.bucket_no_upscale, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.dataset_repeats, args.debug_dataset) + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) + train_dataset.make_buckets() if args.debug_dataset: @@ -226,6 +230,8 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) + for m in training_models: m.train() @@ -332,7 +338,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True) + train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_sd_saving_arguments(parser) diff --git a/finetune_gui.py b/finetune_gui.py index 80c887f..f5aad9d 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -84,6 +84,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -179,6 +180,7 @@ def open_config_file( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -259,6 +261,7 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # create caption json file if generate_caption_database: @@ -405,6 +408,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) print(run_cmd) @@ -614,6 +619,7 @@ def finetune_tab(): bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -678,6 +684,7 @@ def finetune_tab(): bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_run.click(train_model, inputs=settings_list) diff --git a/gui.bat b/gui.bat index fbf5101..978abf2 100644 --- a/gui.bat +++ b/gui.bat @@ -1,10 +1,23 @@ @echo off -set VENV_DIR=.\venv -set PYTHON=python +REM Use this batch file with the following options: +REM -inbrowser - To launch the program in the browser +REM -server_port [port number] - To specify the server port -call %VENV_DIR%\Scripts\activate.bat +set inbrowserOption= +set serverPortOption= -%PYTHON% kohya_gui.py +if "%1" == "-server_port" ( + set serverPortOption=--server_port %2 + if "%3" == "-inbrowser" ( + set inbrowserOption=--inbrowser + ) +) else if "%1" == "-inbrowser" ( + set inbrowserOption=--inbrowser + if "%2" == "-server_port" ( + set serverPortOption=--server_port %3 + ) +) -pause \ No newline at end of file +call .\venv\Scripts\activate.bat +python.exe kohya_gui.py %inbrowserOption% %serverPortOption% diff --git a/gui.ps1 b/gui.ps1 index 4f799a1..e09df2e 100644 --- a/gui.ps1 +++ b/gui.ps1 @@ -1,2 +1,11 @@ +# Example command: .\gui.ps1 -server_port 8000 -inbrowser + +param([string]$username="", [string]$password="", [switch]$inbrowser, [int]$server_port) .\venv\Scripts\activate -python.exe kohya_gui.py \ No newline at end of file + +if ($server_port -le 0 -and $inbrowser -eq $false) { + Write-Host "Error: You must provide either the --server_port or --inbrowser argument." + exit 1 +} + +python.exe kohya_gui.py --username $username --password $password --server_port $server_port --inbrowser \ No newline at end of file diff --git a/kohya_gui.py b/kohya_gui.py index fa51fd6..b44c652 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -10,7 +10,7 @@ from library.merge_lora_gui import gradio_merge_lora_tab from lora_gui import lora_tab -def UI(username, password): +def UI(username, password, inbrowser, server_port): css = '' @@ -47,11 +47,13 @@ def UI(username, password): gradio_merge_lora_tab() # Show the interface - if not username == '': - interface.launch(auth=(username, password)) - else: - interface.launch() - + kwargs = {} + if username: + kwargs["auth"] = (username, password) + if server_port > 0: + kwargs["server_port"] = server_port + kwargs["inbrowser"] = inbrowser + interface.launch(**kwargs) if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) @@ -62,7 +64,11 @@ if __name__ == '__main__': parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) + parser.add_argument( + '--server_port', type=int, default=0, help='Port to run the server listener on' + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") args = parser.parse_args() - UI(username=args.username, password=args.password) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) diff --git a/library/common_gui.py b/library/common_gui.py index c93b04e..a78532d 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -563,6 +563,15 @@ def gradio_advanced_training(): random_crop = gr.Checkbox( label='Random crop instead of center crop', value=False ) + with gr.Row(): + caption_dropout_every_n_epochs = gr.Number( + label="Dropout caption every n epochs", + value=0 + ) + caption_dropout_rate = gr.Number( + label="Rate of caption dropout", + value=0 + ) with gr.Row(): save_state = gr.Checkbox(label='Save training state', value=False) resume = gr.Textbox( @@ -599,6 +608,7 @@ def gradio_advanced_training(): bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) @@ -622,6 +632,12 @@ def run_cmd_advanced_training(**kwargs): f' --keep_tokens="{kwargs.get("keep_tokens", "")}"' if int(kwargs.get('keep_tokens', 0)) > 0 else '', + f' --caption_dropout_every_n_epochs="{kwargs.get("caption_dropout_every_n_epochs", "")}"' + if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 + else '', + f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"' + if float(kwargs.get('caption_dropout_rate', 0)) > 0 + else '', f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' if int(kwargs.get('bucket_reso_steps', 64)) >= 1 diff --git a/library/train_util.py b/library/train_util.py index 379b0b8..df6e24e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -113,7 +113,7 @@ class BucketManager(): # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく self.predefined_resos = resos.copy() self.predefined_resos_set = set(resos) - self.predifined_aspect_ratios = np.array([w / h for w, h in resos]) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) def add_if_new_reso(self, reso): if reso not in self.reso_to_id: @@ -135,7 +135,7 @@ class BucketManager(): if reso in self.predefined_resos_set: pass else: - ar_errors = self.predifined_aspect_ratios - aspect_ratio + ar_errors = self.predefined_aspect_ratios - aspect_ratio predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの reso = self.predefined_resos[predefined_bucket_id] @@ -223,6 +223,10 @@ class BaseDataset(torch.utils.data.Dataset): self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + self.dropout_rate: float = 0 + self.dropout_every_n_epochs: int = None + # augmentation flip_p = 0.5 if flip_aug else 0.0 if color_aug: @@ -247,6 +251,15 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_current_epoch(self, epoch): + self.current_epoch = epoch + + def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate): + # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく) + self.dropout_rate = dropout_rate + self.dropout_every_n_epochs = dropout_every_n_epochs + self.tag_dropout_rate = tag_dropout_rate + def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) self.tag_frequency[dir_name] = frequency_for_dir @@ -264,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements[str_from] = str_to def process_caption(self, caption): - if self.shuffle_caption: - tokens = caption.strip().split(",") - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate + is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0 - for str_from, str_to in self.replacements.items(): - if str_from == "": - # replace all - if type(str_to) == list: - caption = random.choice(str_to) + if is_drop_out: + caption = "" + else: + if self.shuffle_caption: + def dropout_tags(tokens): + if self.tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= self.tag_dropout_rate: + l.append(token) + return l + + tokens = [t.strip() for t in caption.strip().split(",")] + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + tokens = dropout_tags(tokens) else: - caption = str_to - else: - caption = caption.replace(str_from, str_to) + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = dropout_tags(tokens) + + tokens = keep_tokens + tokens + caption = ", ".join(tokens) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) return caption @@ -907,6 +940,8 @@ class FineTuningDataset(BaseDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. / Escキーで中断、終了します") + + train_dataset.set_current_epoch(1) k = 0 for i, example in enumerate(train_dataset): if example['latents'] is not None: @@ -1377,7 +1412,7 @@ def verify_training_args(args: argparse.Namespace): print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") -def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool): +def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool): # dataset common parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--shuffle_caption", action="store_true", @@ -1408,6 +1443,16 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b parser.add_argument("--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument("--caption_dropout_rate", type=float, default=0, + help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") + parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") + parser.add_argument("--caption_tag_dropout_rate", type=float, default=0, + help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合") + if support_dreambooth: # DreamBooth dataset parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") @@ -1718,4 +1763,4 @@ class ImageLoadingDataset(torch.utils.data.Dataset): return (tensor_pil, img_path) -# endregion \ No newline at end of file +# endregion diff --git a/lora_gui.py b/lora_gui.py index 04c0cc1..d48fb5a 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -99,6 +99,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -195,6 +196,7 @@ def open_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -275,7 +277,8 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, -): + caption_dropout_every_n_epochs, caption_dropout_rate, +): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') return @@ -380,7 +383,7 @@ def train_model( run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"' - run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop' + # run_cmd += f' --caption_dropout_rate="0.1" --caption_dropout_every_n_epochs=1' # --random_crop' if v2: run_cmd += ' --v2' @@ -440,7 +443,7 @@ def train_model( else: run_cmd += f' --lr_scheduler_num_cycles="{epoch}"' if not lr_scheduler_power == '': - run_cmd += f' --output_name="{lr_scheduler_power}"' + run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"' run_cmd += run_cmd_training( learning_rate=learning_rate, @@ -476,6 +479,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) print(run_cmd) @@ -725,6 +730,7 @@ def lora_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -805,6 +811,7 @@ def lora_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_open_config.click( diff --git a/networks/lora.py b/networks/lora.py index 174feda..a1f38c1 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,6 +5,7 @@ import math import os +from typing import List import torch from library import train_util @@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module): self.alpha = alpha # create module instances - def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: + def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: loras = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: diff --git a/requirements.txt b/requirements.txt index eeb0bdc..a8bcefb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,24 +1,26 @@ accelerate==0.15.0 transformers==4.26.0 -ftfy -albumentations -opencv-python -einops +ftfy==6.1.1 +albumentations==1.3.0 +opencv-python==4.7.0.68 +einops==0.6.0 diffusers[torch]==0.10.2 -pytorch_lightning +pytorch-lightning==1.9.0 bitsandbytes==0.35.0 -tensorboard +tensorboard==2.10.1 safetensors==0.2.6 gradio==3.16.2 -altair -easygui -tk +altair==4.2.2 +easygui==0.98.3 +tk==0.1.0 # for BLIP captioning -requests -timm -fairscale +requests==2.28.2 +timm==0.6.12 +fairscale==0.4.13 # for WD14 captioning -tensorflow<2.11 -huggingface-hub +# tensorflow<2.11 +tensorflow==2.10.1 +huggingface-hub==0.12.0 +xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl # for kohya_ss library . \ No newline at end of file diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index b34ca6d..d7b86ef 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -94,6 +94,7 @@ def save_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -193,6 +194,7 @@ def open_configuration( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -272,6 +274,7 @@ def train_model( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -453,6 +456,8 @@ def train_model( bucket_no_upscale=bucket_no_upscale, random_crop=random_crop, bucket_reso_steps=bucket_reso_steps, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, ) run_cmd += f' --token_string="{token_string}"' run_cmd += f' --init_word="{init_word}"' @@ -709,6 +714,7 @@ def ti_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -783,6 +789,7 @@ def ti_tab( bucket_no_upscale, random_crop, bucket_reso_steps, + caption_dropout_every_n_epochs, caption_dropout_rate, ] button_open_config.click( diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py new file mode 100644 index 0000000..0876a4d --- /dev/null +++ b/tools/resize_images_to_resolution.py @@ -0,0 +1,113 @@ +import glob +import os +import cv2 +import argparse +import shutil +import math + + +def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): + # Split the max_resolution string by "," and strip any whitespaces + max_resolutions = [res.strip() for res in max_resolution.split(',')] + + # # Calculate max_pixels from max_resolution string + # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) + + # Create destination folder if it does not exist + if not os.path.exists(dst_img_folder): + os.makedirs(dst_img_folder) + + # Select interpolation method + if interpolation == 'lanczos4': + cv2_interpolation = cv2.INTER_LANCZOS4 + elif interpolation == 'cubic': + cv2_interpolation = cv2.INTER_CUBIC + else: + cv2_interpolation = cv2.INTER_AREA + + # Iterate through all files in src_img_folder + img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py + for filename in os.listdir(src_img_folder): + # Check if the image is png, jpg or webp etc... + if not filename.endswith(img_exts): + # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) + shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) + continue + + # Load image + img = cv2.imread(os.path.join(src_img_folder, filename)) + + base, _ = os.path.splitext(filename) + for max_resolution in max_resolutions: + # Calculate max_pixels from max_resolution string + max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) + + # Calculate current number of pixels + current_pixels = img.shape[0] * img.shape[1] + + # Check if the image needs resizing + if current_pixels > max_pixels: + # Calculate scaling factor + scale_factor = max_pixels / current_pixels + + # Calculate new dimensions + new_height = int(img.shape[0] * math.sqrt(scale_factor)) + new_width = int(img.shape[1] * math.sqrt(scale_factor)) + + # Resize image + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + new_height, new_width = img.shape[0:2] + + # Calculate the new height and width that are divisible by divisible_by (with/without resizing) + new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by + new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by + + # Center crop the image to the calculated dimensions + y = int((img.shape[0] - new_height) / 2) + x = int((img.shape[1] - new_width) / 2) + img = img[y:y + new_height, x:x + new_width] + + # Split filename into base and extension + new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') + + # Save resized image in dst_img_folder + cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) + proc = "Resized" if current_pixels > max_pixels else "Saved" + print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + + # If other files with same basename, copy them with resolution suffix + if copy_associated_files: + asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) + for asoc_file in asoc_files: + ext = os.path.splitext(asoc_file)[1] + if ext in img_exts: + continue + for max_resolution in max_resolutions: + new_asoc_file = base + '+' + max_resolution + ext + print(f"Copy {asoc_file} as {new_asoc_file}") + shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) + + +def main(): + parser = argparse.ArgumentParser( + description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') + parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') + parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ') + parser.add_argument('--max_resolution', type=str, + help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") + parser.add_argument('--divisible_by', type=int, + help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) + parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], + default='area', help='Interpolation method for resizing / リサイズ時の補完方法') + parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') + parser.add_argument('--copy_associated_files', action='store_true', + help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') + + args = parser.parse_args() + resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, + args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) + + +if __name__ == '__main__': + main() diff --git a/tools/resize_images_to_resolutions.py b/tools/resize_images_to_resolutions.py index 5492f1c..3e6f87d 100644 --- a/tools/resize_images_to_resolutions.py +++ b/tools/resize_images_to_resolutions.py @@ -4,13 +4,10 @@ import argparse import shutil import math -def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2): +def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=1, caption_extension=''): # Split the max_resolution string by "," and strip any whitespaces max_resolutions = [res.strip() for res in max_resolution.split(',')] - # # Calculate max_pixels from max_resolution string - # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) - # Create destination folder if it does not exist if not os.path.exists(dst_img_folder): os.makedirs(dst_img_folder) @@ -20,7 +17,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Check if the image is png, jpg or webp if not filename.endswith(('.png', '.jpg', '.webp')): # Copy the file to the destination folder if not png, jpg or webp - shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) + # shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) continue # Load image @@ -42,8 +39,8 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_height = int(img.shape[0] * math.sqrt(scale_factor)) new_width = int(img.shape[1] * math.sqrt(scale_factor)) - # Resize image - img = cv2.resize(img, (new_width, new_height)) + # Resize image using area interpolation (best when downsampling) + img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA) # Calculate the new height and width that are divisible by divisible_by new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by @@ -57,7 +54,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Split filename into base and extension base, ext = os.path.splitext(filename) new_filename = base + '+' + max_resolution + '.jpg' - + + # copy caption file with right name if one exist + if os.path.exists(os.path.join(src_img_folder, base + caption_extension)): + shutil.copy(os.path.join(src_img_folder, base + caption_extension), os.path.join(dst_img_folder, new_filename + caption_extension)) + # Save resized image in dst_img_folder cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") @@ -69,8 +70,9 @@ def main(): parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images') parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,448x448,384x384, etc, etc"', default="512x512,448x448,384x384") parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1) + parser.add_argument('--caption_extension', type=str, help='Extension of caption files to copy with resized images"', default=".txt") args = parser.parse_args() - resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution) + resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, args.divisible_by, args.caption_extension) if __name__ == '__main__': main() \ No newline at end of file diff --git a/train_db.py b/train_db.py index d1bbc07..c210767 100644 --- a/train_db.py +++ b/train_db.py @@ -38,8 +38,13 @@ def train(args): args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps, args.bucket_no_upscale, args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + if args.no_token_padding: train_dataset.disable_token_padding() + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) + train_dataset.make_buckets() if args.debug_dataset: @@ -203,6 +208,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -327,7 +333,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, False) + train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) train_util.add_sd_saving_arguments(parser) diff --git a/train_network.py b/train_network.py index 3e8f4e7..bb3159f 100644 --- a/train_network.py +++ b/train_network.py @@ -120,18 +120,22 @@ def train(args): print("Use DreamBooth method.") train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.bucket_reso_steps, args.bucket_no_upscale, + args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) else: print("Train with captions.") train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, + args.bucket_reso_steps, args.bucket_no_upscale, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.dataset_repeats, args.debug_dataset) + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) + train_dataset.make_buckets() if args.debug_dataset: @@ -376,6 +380,8 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) + metadata["ss_epoch"] = str(epoch+1) network.on_epoch_start(text_encoder, unet) @@ -509,7 +515,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True) + train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 7a8370c..ba2e714 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -235,7 +235,7 @@ def train(args): text_encoder, optimizer, train_dataloader, lr_scheduler) index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - print(len(index_no_updates), torch.sum(index_no_updates)) + # print(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -296,6 +296,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) text_encoder.train() @@ -383,8 +384,8 @@ def train(args): accelerator.wait_for_everyone() updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() - d = updated_embs - bef_epo_embs - print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) + # d = updated_embs - bef_epo_embs + # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) if args.save_every_n_epochs is not None: model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name @@ -478,7 +479,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True) + train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],