From 7249b0baa835c133bf83fdae8319c6fc9daf0603 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 6 Mar 2023 19:15:02 -0500 Subject: [PATCH] Update to latest sd-script release add gui support for sample config --- README-ja.md | 12 - README.md | 16 + config_README-ja.md | 279 +++ dreambooth_gui.py | 39 +- fine_tune.py | 60 +- fine_tune_README_ja.md | 4 +- finetune_gui.py | 38 +- gen_img_diffusers.py | 92 +- kohya_gui.py | 7 +- library/config_util.py | 527 ++++ library/sampler_gui.py | 96 + library/train_util.py | 4639 +++++++++++++++--------------------- lora_gui.py | 49 +- requirements.txt | 2 + textual_inversion_gui.py | 39 +- train_README-ja.md | 619 +++++ train_db.py | 54 +- train_network.py | 233 +- train_textual_inversion.py | 90 +- 19 files changed, 4053 insertions(+), 2842 deletions(-) create mode 100644 config_README-ja.md create mode 100644 library/config_util.py create mode 100644 library/sampler_gui.py create mode 100644 train_README-ja.md diff --git a/README-ja.md b/README-ja.md index 319efe9..064464c 100644 --- a/README-ja.md +++ b/README-ja.md @@ -38,18 +38,6 @@ PowerShellを使う場合、venvを使えるようにするためには以下の - 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。 - 管理者のPowerShellを閉じます。 -## Ubuntu環境でのインストール - -``` -git clone https://github.com/kohya-ss/sd-scripts.git -cd sd-scripts -bash ubuntu_setup.sh -``` - -をコマンドプロンプトで実行し、tkをインストールし、accelerateの質問をWindowsと同じように答えます。 - -`./gui.sh`でGUIを実行します。 - ## Windows環境でのインストール 以下の例ではPyTorchは1.12.1/CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください。 diff --git a/README.md b/README.md index 93e08e6..588e66e 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,22 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/05 (v21.2.0): + - There may be problems due to major changes. If you cannot revert back to a previous version when problems occur (`git checkout `). + - Dependencies are updated, Please [upgrade](#upgrade) the repo. + - Add detail dataset config feature by extra config file. Thanks to fur0ut0 for this great contribution! + - Documentation is [here](https://github-com.translate.goog/kohya-ss/sd-scripts/blob/main/config_README-ja.md) (only in Japanese currently.) + - Specify `.toml` file with `--dataset_config` option. + - The options supported under the previous release can be used as is instead of the `.toml` config file. + - There might be bugs due to the large scale of update, please report any problems if you find at https://github.com/kohya-ss/sd-scripts/issues. + - Add feature to generate sample images in the middle of training for each training scripts. + - `--sample_every_n_steps` and `--sample_every_n_epochs` options: frequency to generate. + - `--sample_prompts` option: the file contains prompts (each line generates one image.) + - The prompt is subset of `gen_img_diffusers.py`. The prompt options `w, h, d, l, s, n` are supported. + - `--sample_sampler` option: sampler (scheduler) for generating, such as ddim or k_euler. See help for useable samplers. + - Add `--tokenizer_cache_dir` to each training and generation scripts to cache Tokenizer locally from Diffusers. + - Scripts will support offline training/generation after caching. + - Support letents upscaling for highres. fix, and VAE batch size in `gen_img_diffusers.py` (no documentation yet.) * 2023/03/05 (v21.1.5): - Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount! - Improve how custom preset is set and handles. diff --git a/config_README-ja.md b/config_README-ja.md new file mode 100644 index 0000000..7f2b6c4 --- /dev/null +++ b/config_README-ja.md @@ -0,0 +1,279 @@ +For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. + +`--dataset_config` で渡すことができる設定ファイルに関する説明です。 + +## 概要 + +設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。 + +* 複数のデータセットが設定可能になります + * 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。 + * DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。 +* サブセットごとに設定を変更することが可能になります + * データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。 + * `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。 + +設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。 + +TOML で記述した設定ファイルの例です。 + +```toml +[general] +shuffle_caption = true +caption_extension = '.txt' +keep_tokens = 1 + +# これは DreamBooth 方式のデータセット +[[datasets]] +resolution = 512 +batch_size = 4 +keep_tokens = 2 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + class_tokens = 'hoge girl' + # このサブセットは keep_tokens = 2 (所属する datasets の値が使われる) + + [[datasets.subsets]] + image_dir = 'C:\fuga' + class_tokens = 'fuga boy' + keep_tokens = 3 + + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' + class_tokens = 'human' + keep_tokens = 1 + +# これは fine tuning 方式のデータセット +[[datasets]] +resolution = [768, 768] +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'C:\piyo' + metadata_file = 'C:\piyo\piyo_md.json' + # このサブセットは keep_tokens = 1 (general の値が使われる) +``` + +この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。 + +## データセット・サブセットに関する設定 + +データセット・サブセットに関する設定は、登録可能な箇所がいくつかに分かれています。 + +* `[general]` + * 全データセットまたは全サブセットに適用されるオプションを指定する箇所です。 + * データセットごとの設定及びサブセットごとの設定に同名のオプションが存在していた場合には、データセット・サブセットごとの設定が優先されます。 +* `[[datasets]]` + * `datasets` はデータセットに関する設定の登録箇所になります。各データセットに個別に適用されるオプションを指定する箇所です。 + * サブセットごとの設定が存在していた場合には、サブセットごとの設定が優先されます。 +* `[[datasets.subsets]]` + * `datasets.subsets` はサブセットに関する設定の登録箇所になります。各サブセットに個別に適用されるオプションを指定する箇所です。 + +先程の例における、画像ディレクトリと登録箇所の対応に関するイメージ図です。 + +``` +C:\ +├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐ +├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general] +├─ reg -> [[datasets.subsets]] No.3 ┘ | +└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘ +``` + +画像ディレクトリがそれぞれ1つの `[[datasets.subsets]]` に対応しています。そして `[[datasets.subsets]]` が1つ以上組み合わさって1つの `[[datasets]]` を構成します。`[general]` には全ての `[[datasets]]`, `[[datasets.subsets]]` が属します。 + +登録箇所ごとに指定可能なオプションは異なりますが、同名のオプションが指定された場合は下位の登録箇所にある値が優先されます。先程の例の `keep_tokens` オプションの扱われ方を確認してもらうと理解しやすいかと思います。 + +加えて、学習方法が対応している手法によっても指定可能なオプションが変化します。 + +* DreamBooth 方式専用のオプション +* fine tuning 方式専用のオプション +* caption dropout の手法が使える場合のオプション + +DreamBooth の手法と fine tuning の手法の両方とも利用可能な学習方法では、両者を併用することができます。 +併用する際の注意点として、DreamBooth 方式なのか fine tuning 方式なのかはデータセット単位で判別を行っているため、同じデータセット中に DreamBooth 方式のサブセットと fine tuning 方式のサブセットを混在させることはできません。 +つまり、これらを併用したい場合には異なる方式のサブセットが異なるデータセットに所属するように設定する必要があります。 + +プログラムの挙動としては、後述する `metadata_file` オプションが存在していたら fine tuning 方式のサブセットだと判断します。 +そのため、同一のデータセットに所属するサブセットについて言うと、「全てが `metadata_file` オプションを持つ」か「全てが `metadata_file` オプションを持たない」かのどちらかになっていれば問題ありません。 + +以下、利用可能なオプションを説明します。コマンドライン引数と名称が同一のオプションについては、基本的に説明を割愛します。他の README を参照してください。 + +### 全学習方法で共通のオプション + +学習方法によらずに指定可能なオプションです。 + +#### データセット向けオプション + +データセットの設定に関わるオプションです。`datasets.subsets` には記述できません。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * コマンドライン引数の `--train_batch_size` と同等です。 + +これらの設定はデータセットごとに固定です。 +つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 +例えば解像度が異なるデータセットを用意したい場合は、上に挙げた例のように別々のデータセットとして定義すれば別々の解像度を設定可能です。 + +#### サブセット向けオプション + +サブセットの設定に関わるオプションです。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| `random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | + +* `num_repeats` + * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 + +### DreamBooth 方式専用のオプション + +DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 + +#### サブセット向けオプション + +DreamBooth 方式のサブセットの設定に関わるオプションです。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o(必須) | +| `caption_extension` | `".txt"` | o | o | o | +| `class_tokens` | `“sks girl”` | - | - | o | +| `is_reg` | `false` | - | - | o | + +まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。 + +* `image_dir` + * 画像ディレクトリのパスを指定します。指定必須オプションです。 + * 画像はディレクトリ直下に置かれている必要があります。 +* `class_tokens` + * クラストークンを設定します。 + * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。 +* `is_reg` + * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。 + +### fine tuning 方式専用のオプション + +fine tuning 方式のオプションは、サブセット向けオプションのみ存在します。 + +#### サブセット向けオプション + +fine tuning 方式のサブセットの設定に関わるオプションです。 + +| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o | +| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必須) | + +* `image_dir` + * 画像ディレクトリのパスを指定します。DreamBooth の手法の方とは異なり指定は必須ではありませんが、設定することを推奨します。 + * 指定する必要がない状況としては、メタデータファイルの生成時に `--full_path` を付与して実行していた場合です。 + * 画像はディレクトリ直下に置かれている必要があります。 +* `metadata_file` + * サブセットで利用されるメタデータファイルのパスを指定します。指定必須オプションです。 + * コマンドライン引数の `--in_json` と同等です。 + * サブセットごとにメタデータファイルを指定する必要がある仕様上、ディレクトリを跨いだメタデータを1つのメタデータファイルとして作成することは避けた方が良いでしょう。画像ディレクトリごとにメタデータファイルを用意し、それらを別々のサブセットとして登録することを強く推奨します。 + +### caption dropout の手法が使える場合に指定可能なオプション + +caption dropout の手法が使える場合のオプションは、サブセット向けオプションのみ存在します。 +DreamBooth 方式か fine tuning 方式かに関わらず、caption dropout に対応している学習方法であれば指定可能です。 + +#### サブセット向けオプション + +caption dropout が使えるサブセットの設定に関わるオプションです。 + +| オプション名 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## 重複したサブセットが存在する時の挙動 + +DreamBooth 方式のデータセットの場合、その中にある `image_dir` が同一のサブセットは重複していると見なされます。 +fine tuning 方式のデータセットの場合は、その中にある `metadata_file` が同一のサブセットは重複していると見なされます。 +データセット中に重複したサブセットが存在する場合、2個目以降は無視されます。 + +一方、異なるデータセットに所属している場合は、重複しているとは見なされません。 +例えば、以下のように同一の `image_dir` を持つサブセットを別々のデータセットに入れた場合には、重複していないと見なします。 +これは、同じ画像でも異なる解像度で学習したい場合に役立ちます。 + +```toml +# 別々のデータセットに存在している場合は重複とは見なされず、両方とも学習に使われる + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## コマンドライン引数との併用 + +設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。 + +以下に挙げるコマンドライン引数のオプションは、設定ファイルを渡した場合には無視されます。 + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +以下に挙げるコマンドライン引数のオプションは、コマンドライン引数と設定ファイルで同時に指定された場合、コマンドライン引数の値よりも設定ファイルの値が優先されます。特に断りがなければ同名のオプションとなります。 + +| コマンドライン引数のオプション | 優先される設定ファイルのオプション | +| ---------------------------------- | ---------------------------------- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop` | | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## エラーの手引き + +現在、外部ライブラリを利用して設定ファイルの記述が正しいかどうかをチェックしているのですが、整備が行き届いておらずエラーメッセージがわかりづらいという問題があります。 +将来的にはこの問題の改善に取り組む予定です。 + +次善策として、頻出のエラーとその対処法について載せておきます。 +正しいはずなのにエラーが出る場合、エラー内容がどうしても分からない場合は、バグかもしれないのでご連絡ください。 + +* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されていないというエラーです。指定を忘れているか、オプション名を間違って記述している可能性が高いです。 + * `...` の箇所にはエラーが発生した場所が載っています。例えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のようなエラーが出たら、0 番目の `datasets` 中の 0 番目の `subsets` の設定に `image_dir` が存在しないということになります。 +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。 +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。 + + diff --git a/dreambooth_gui.py b/dreambooth_gui.py index c0f4472..6f353d8 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -36,6 +36,7 @@ from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -100,6 +101,10 @@ def save_configuration( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -201,6 +206,10 @@ def open_configuration( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -281,6 +290,10 @@ def train_model( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -454,8 +467,15 @@ def train_model( noise_offset=noise_offset, ) + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) + print(run_cmd) - + # Run the command if os.name == 'posix': os.system(run_cmd) @@ -654,11 +674,14 @@ def dreambooth_tab( inputs=[color_aug], outputs=[cache_latents], ) - # optimizer.change( - # set_legacy_8bitadam, - # inputs=[optimizer, use_8bit_adam], - # outputs=[optimizer, use_8bit_adam], - # ) + + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() + with gr.Tab('Tools'): gr.Markdown( 'This section provide Dreambooth tools to help setup your dataset...' @@ -740,6 +763,10 @@ def dreambooth_tab( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ] button_open_config.click( diff --git a/fine_tune.py b/fine_tune.py index 426fb09..1255759 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -13,7 +13,11 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util - +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) def collate_fn(examples): return examples[0] @@ -30,25 +34,36 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, - args.dataset_repeats, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) + else: + user_config = { + "datasets": [{ + "subsets": [{ + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + }] + }] + } - # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) - - train_dataset.make_buckets() + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: - train_util.debug_dataset(train_dataset) + train_util.debug_dataset(train_dataset_group) return - if len(train_dataset) == 0: + if len(train_dataset_group) == 0: print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") return + if cache_latents: + assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # acceleratorを準備する print("prepare accelerator") accelerator, unwrap_model = train_util.prepare_accelerator(args) @@ -109,7 +124,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -155,7 +170,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -199,7 +214,7 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset.num_train_images}") + print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -218,7 +233,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) for m in training_models: m.train() @@ -282,17 +297,13 @@ def train(args): progress_bar.update(1) global_step += 1 + train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - logs = {"avr_loss": loss_total / (step+1)} if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value - # print(lr_scheduler.optimizers) logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] - logs["d"] = lr_scheduler.optimizers[0].param_groups[0]['d'] - logs["lrD"] = lr_scheduler.optimizers[0].param_groups[0]['lr'] - logs["gsq_weighted"] = lr_scheduler.optimizers[0].param_groups[0]['gsq_weighted'] - accelerator.log(logs, step=global_step) # TODO moving averageにする @@ -315,6 +326,8 @@ def train(args): train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + is_main_process = accelerator.is_main_process if is_main_process: unet = unwrap_model(unet) @@ -342,6 +355,7 @@ if __name__ == '__main__': train_util.add_training_arguments(parser, False) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--diffusers_xformers", action='store_true', help='use xformers by diffusers / Diffusersでxformersを使用する') diff --git a/fine_tune_README_ja.md b/fine_tune_README_ja.md index f763490..9dcd34a 100644 --- a/fine_tune_README_ja.md +++ b/fine_tune_README_ja.md @@ -324,7 +324,7 @@ __※引数を都度書き換えて、別のメタデータファイルに書き ## 学習の実行 たとえば以下のように実行します。以下は省メモリ化のための設定です。 ``` -accelerate launch --num_cpu_threads_per_process 8 fine_tune.py +accelerate launch --num_cpu_threads_per_process 1 fine_tune.py --pretrained_model_name_or_path=model.ckpt --in_json meta_lat.json --train_data_dir=train_data @@ -336,7 +336,7 @@ accelerate launch --num_cpu_threads_per_process 8 fine_tune.py --save_every_n_epochs=4 ``` -accelerateのnum_cpu_threads_per_processにはCPUのコア数を指定するとよいようです。 +accelerateのnum_cpu_threads_per_processには通常は1を指定するとよいようです。 pretrained_model_name_or_pathに学習対象のモデルを指定します(Stable DiffusionのcheckpointかDiffusersのモデル)。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています(拡張子で自動判定)。 diff --git a/finetune_gui.py b/finetune_gui.py index 38a11fb..50ff678 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -27,6 +27,7 @@ from library.tensorboard_gui import ( stop_tensorboard, ) from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -98,6 +99,10 @@ def save_configuration( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -205,6 +210,10 @@ def open_config_file( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -291,6 +300,10 @@ def train_model( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # create caption json file if generate_caption_database: @@ -446,8 +459,15 @@ def train_model( noise_offset=noise_offset, ) + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) + print(run_cmd) - + # Run the command if os.name == 'posix': os.system(run_cmd) @@ -656,11 +676,13 @@ def finetune_tab(): inputs=[color_aug], outputs=[cache_latents], # Not applicable to fine_tune.py ) - # optimizer.change( - # set_legacy_8bitadam, - # inputs=[optimizer, use_8bit_adam], - # outputs=[optimizer, use_8bit_adam], - # ) + + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() button_run = gr.Button('Train model', variant='primary') @@ -737,6 +759,10 @@ def finetune_tab(): optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ] button_run.click(train_model, inputs=settings_list) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a2d5b94..6bab0bb 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -80,6 +80,7 @@ from PIL import Image from PIL.PngImagePlugin import PngInfo import library.model_util as model_util +import library.train_util as train_util import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo @@ -589,6 +590,8 @@ class PipelineLike(): latents: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, # return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, @@ -680,6 +683,9 @@ class PipelineLike(): else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + vae_batch_size = batch_size if vae_batch_size is None else ( + int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -793,7 +799,6 @@ class PipelineLike(): latents_dtype = text_embeddings.dtype init_latents_orig = None mask = None - noise = None if init_image is None: # get the initial random noise unless the user supplied it @@ -825,6 +830,8 @@ class PipelineLike(): if isinstance(init_image[0], PIL.Image.Image): init_image = [preprocess_image(im) for im in init_image] init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) # mask image to tensor if mask_image is not None: @@ -835,9 +842,24 @@ class PipelineLike(): # encode the init image into latents and scale the latents init_image = init_image.to(device=self.device, dtype=latents_dtype) - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents + if init_image.size()[2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size] + if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = 0.18215 * init_latents + if len(init_latents) == 1: init_latents = init_latents.repeat((batch_size, 1, 1, 1)) init_latents_orig = init_latents @@ -932,8 +954,19 @@ class PipelineLike(): if is_cancelled_callback is not None and is_cancelled_callback(): return None + if return_latents: + return (latents, False) + latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + if vae_batch_size >= batch_size: + image = self.vae.decode(latents).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample) + image = torch.cat(images) image = (image / 2 + 0.5).clamp(0, 1) @@ -1820,7 +1853,7 @@ def preprocess_mask(mask): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS) + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? @@ -1862,6 +1895,7 @@ class BatchDataExt(NamedTuple): class BatchData(NamedTuple): + return_latents: bool base: BatchDataBase ext: BatchDataExt @@ -1930,10 +1964,7 @@ def main(args): # tokenizerを読み込む print("loading tokenizer") if use_stable_diffusion_format: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) + tokenizer = train_util.load_tokenizer(args) # schedulerを用意する sched_init_args = {} @@ -2296,9 +2327,9 @@ def main(args): # highres_fixの処理 if highres_fix and not highres_1st: # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - print("process 1st stage1") + print("process 1st stage") batch_1st = [] - for base, ext in batch: + for _, base, ext in batch: width_1st = int(ext.width * args.highres_fix_scale + .5) height_1st = int(ext.height * args.highres_fix_scale + .5) width_1st = width_1st - width_1st % 32 @@ -2306,20 +2337,29 @@ def main(args): ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls) - batch_1st.append(BatchData(base, ext_1st)) + batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st)) images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage1") + print("process 2nd stage") + if args.highres_fix_latents_upscaling: + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True) + images_1st = images_1st.to(org_dtype) + batch_2nd = [] for i, (bd, image) in enumerate(zip(batch, images_1st)): - image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 - bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext) + if not args.highres_fix_latents_upscaling: + image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext) batch_2nd.append(bd_2nd) batch = batch_2nd # このバッチの情報を取り出す - (step_first, _, _, _, init_image, mask_image, _, guide_image), \ + return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \ (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0] noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) @@ -2353,7 +2393,7 @@ def main(args): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) @@ -2413,8 +2453,10 @@ def main(args): n.set_multiplier(m) images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code, - output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] - if highres_1st and not args.highres_fix_save_1st: + output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, return_latents=return_latents, + clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] + if highres_1st and not args.highres_fix_save_1st: # return images or latents return images # save image @@ -2612,9 +2654,9 @@ def main(args): print("Use previous image as guide image.") guide_image = prev_image - b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None)) - if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? process_batch(batch_data, highres_fix) batch_data.clear() @@ -2658,6 +2700,8 @@ if __name__ == '__main__': parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument("--vae_batch_size", type=float, default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率") parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") parser.add_argument('--sampler', type=str, default='ddim', choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver', @@ -2669,6 +2713,8 @@ if __name__ == '__main__': parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument("--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") + parser.add_argument("--tokenizer_cache_dir", type=str, default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)") # parser.add_argument("--replace_clip_l14_336", action='store_true', # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") parser.add_argument("--seed", type=int, default=None, @@ -2713,6 +2759,8 @@ if __name__ == '__main__': help="1st stage steps for highres fix / highres fixの最初のステージのステップ数") parser.add_argument("--highres_fix_save_1st", action='store_true', help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") + parser.add_argument("--highres_fix_latents_upscaling", action='store_true', + help="use latents upscaling for highres fix / highres fixでlatentで拡大する") parser.add_argument("--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する") diff --git a/kohya_gui.py b/kohya_gui.py index a0b8222..d643228 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -53,7 +53,7 @@ def UI(**kwargs): inbrowser = kwargs.get('inbrowser', False) share = kwargs.get('share', False) server_name = kwargs.get('listen') - + launch_kwargs['server_name'] = server_name if username and password: launch_kwargs['auth'] = (username, password) @@ -70,7 +70,10 @@ if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() parser.add_argument( - '--listen', type=str, default='127.0.0.1', help='IP to listen on for connections to Gradio' + '--listen', + type=str, + default='127.0.0.1', + help='IP to listen on for connections to Gradio', ) parser.add_argument( '--username', type=str, default='', help='Username for authentication' diff --git a/library/config_util.py b/library/config_util.py new file mode 100644 index 0000000..e62bfb8 --- /dev/null +++ b/library/config_util.py @@ -0,0 +1,527 @@ +import argparse +from dataclasses import ( + asdict, + dataclass, +) +import functools +from textwrap import dedent, indent +import json +from pathlib import Path +# from toolz import curry +from typing import ( + List, + Optional, + Sequence, + Tuple, + Union, +) + +import toml +import voluptuous +from voluptuous import ( + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, +) +from transformers import CLIPTokenizer + +from . import train_util +from .train_util import ( + DreamBoothSubset, + FineTuningSubset, + DreamBoothDataset, + FineTuningDataset, + DatasetGroup, +) + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + +# TODO: inherit Params class in Subset, Dataset + +@dataclass +class BaseSubsetParams: + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + keep_tokens: int = 0 + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + +@dataclass +class DreamBoothSubsetParams(BaseSubsetParams): + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + +@dataclass +class FineTuningSubsetParams(BaseSubsetParams): + metadata_file: Optional[str] = None + +@dataclass +class BaseDatasetParams: + tokenizer: CLIPTokenizer = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + debug_dataset: bool = False + +@dataclass +class DreamBoothDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + +@dataclass +class FineTuningDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + +@dataclass +class SubsetBlueprint: + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + +@dataclass +class DatasetBlueprint: + is_dreambooth: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + +@dataclass +class DatasetGroupBlueprint: + datasets: Sequence[DatasetBlueprint] +@dataclass +class Blueprint: + dataset_group: DatasetGroupBlueprint + + +class ConfigSanitizer: + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + if all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + self.dataset_schema = self.db_dataset_schema + else: + self.dataset_schema = self.ft_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema({ + "general": self.general_schema, + "datasets": [self.dataset_schema], + }) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + print("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged + + +class BlueprintGenerator: + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { + } + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + if is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks(subset_params_klass, + [subset_config, dataset_config, general_config, argparse_config, runtime_params]) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks(dataset_params_klass, + [dataset_config, general_config, argparse_config, runtime_params]) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value = None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value + + +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): + datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent(dedent(f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + else: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + # make buckets first because it determines the length of dataset + for i, dataset in enumerate(datasets): + print(f"[Dataset {i}]") + dataset.make_buckets() + + return DatasetGroup(datasets) + + +def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split('_') + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") + return 0, "" + caption_by_folder = '_'.join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) + + return subsets_config + + +def load_user_config(file: str) -> dict: + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファイルが見つかりません: {file}") + + if file.name.lower().endswith('.json'): + try: + config = json.load(file) + except Exception: + print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + raise + elif file.name.lower().endswith('.toml'): + try: + config = toml.load(file) + except Exception: + print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + + return config + + +# for config test +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + + print("[argparse_namespace]") + print(vars(argparse_namespace)) + + user_config = load_user_config(config_args.dataset_config) + + print("\n[user_config]") + print(user_config) + + sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) + + print("\n[sanitized_user_config]") + print(sanitized_user_config) + + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + + print("\n[blueprint]") + print(blueprint) diff --git a/library/sampler_gui.py b/library/sampler_gui.py new file mode 100644 index 0000000..9f0326d --- /dev/null +++ b/library/sampler_gui.py @@ -0,0 +1,96 @@ +import tempfile +import gradio as gr +from easygui import msgbox + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 + + +### +### Gradio common sampler GUI section +### + + +def sample_gradio_config(): + with gr.Accordion('Sample images config', open=False): + with gr.Row(): + sample_every_n_steps = gr.Number( + label='Sample every n steps', + value=0, + precision=0, + interactive=True, + ) + sample_every_n_epochs = gr.Number( + label='Sample every n epochs', + value=0, + precision=0, + interactive=True, + ) + sample_sampler = gr.Dropdown( + label='Sample sampler', + choices=[ + 'ddim', + 'pndm', + 'lms', + 'euler', + 'euler_a', + 'heun', + 'dpm_2', + 'dpm_2_a', + 'dpmsolver', + 'dpmsolver++', + 'dpmsingle', + 'k_lms', + 'k_euler', + 'k_euler_a', + 'k_dpm_2', + 'k_dpm_2_a', + ], + value='euler_a', + interactive=True, + ) + with gr.Row(): + sample_prompts = gr.Textbox( + lines=5, + label='Sample prompts', + interactive=True, + ) + return ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) + + +def run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, +): + run_cmd = '' + + if sample_every_n_epochs == 0 and sample_every_n_steps == 0: + return run_cmd + + # Create a temporary file and get its path + with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file: + # Write the contents of the variable to the file + temp_file.write(sample_prompts) + + # Get the path of the temporary file + sample_prompts_path = temp_file.name + + run_cmd += f' --sample_sampler={sample_sampler}' + run_cmd += f' --sample_prompts="{sample_prompts_path}"' + + if not sample_every_n_epochs == 0: + run_cmd += f' --sample_every_n_epochs="{sample_every_n_epochs}"' + + if not sample_every_n_steps == 0: + run_cmd += f' --sample_every_n_steps="{sample_every_n_steps}"' + + return run_cmd diff --git a/library/train_util.py b/library/train_util.py index 79fde4b..75176e1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,12 +3,19 @@ import argparse import importlib import json +import re import shutil import time -from typing import Dict, List, NamedTuple, Tuple -from typing import Optional, Union +from typing import ( + Dict, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) from accelerate import Accelerator -from torch.autograd.function import Function import glob import math import os @@ -25,7 +32,10 @@ from transformers import CLIPTokenizer import transformers import diffusers from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION -from diffusers import DDPMScheduler, StableDiffusionPipeline +from diffusers import (StableDiffusionPipeline, DDPMScheduler, + EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, + KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler) import albumentations as albu import numpy as np from PIL import Image @@ -37,1364 +47,1091 @@ import safetensors.torch import library.model_util as model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = 'openai/clip-vit-large-patch14' -V2_STABLE_DIFFUSION_PATH = 'stabilityai/stable-diffusion-2' # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ # checkpointファイル名 -EPOCH_STATE_NAME = '{}-{:06d}-state' -EPOCH_FILE_NAME = '{}-{:06d}' -EPOCH_DIFFUSERS_DIR_NAME = '{}-{:06d}' -LAST_STATE_NAME = '{}-state' -DEFAULT_EPOCH_NAME = 'epoch' -DEFAULT_LAST_OUTPUT_NAME = 'last' +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +DEFAULT_EPOCH_NAME = "epoch" +DEFAULT_LAST_OUTPUT_NAME = "last" # region dataset -IMAGE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp'] +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] # , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux? -class ImageInfo: - def __init__( - self, - image_key: str, - num_repeats: int, - caption: str, - is_reg: bool, - absolute_path: str, - ) -> None: - self.image_key: str = image_key - self.num_repeats: int = num_repeats - self.caption: str = caption - self.is_reg: bool = is_reg - self.absolute_path: str = absolute_path - self.image_size: Tuple[int, int] = None - self.resized_size: Tuple[int, int] = None - self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_npz_flipped: str = None +class ImageInfo(): + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: Tuple[int, int] = None + self.resized_size: Tuple[int, int] = None + self.bucket_reso: Tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None -class BucketManager: - def __init__( - self, no_upscale, max_reso, min_size, max_size, reso_steps - ) -> None: - self.no_upscale = no_upscale - if max_reso is None: - self.max_reso = None - self.max_area = None +class BucketManager(): + def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: + self.no_upscale = no_upscale + if max_reso is None: + self.max_reso = None + self.max_area = None + else: + self.max_reso = max_reso + self.max_area = max_reso[0] * max_reso[1] + self.min_size = min_size + self.max_size = max_size + self.reso_steps = reso_steps + + self.resos = [] + self.reso_to_id = {} + self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key + + def add_image(self, reso, image): + bucket_id = self.reso_to_id[reso] + self.buckets[bucket_id].append(image) + + def shuffle(self): + for bucket in self.buckets: + random.shuffle(bucket) + + def sort(self): + # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す + sorted_resos = self.resos.copy() + sorted_resos.sort() + + sorted_buckets = [] + sorted_reso_to_id = {} + for i, reso in enumerate(sorted_resos): + bucket_id = self.reso_to_id[reso] + sorted_buckets.append(self.buckets[bucket_id]) + sorted_reso_to_id[reso] = i + + self.resos = sorted_resos + self.buckets = sorted_buckets + self.reso_to_id = sorted_reso_to_id + + def make_buckets(self): + resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) + self.set_predefined_resos(resos) + + def set_predefined_resos(self, resos): + # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく + self.predefined_resos = resos.copy() + self.predefined_resos_set = set(resos) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) + + def add_if_new_reso(self, reso): + if reso not in self.reso_to_id: + bucket_id = len(self.resos) + self.reso_to_id[reso] = bucket_id + self.resos.append(reso) + self.buckets.append([]) + # print(reso, bucket_id, len(self.buckets)) + + def round_to_steps(self, x): + x = int(x + .5) + return x - x % self.reso_steps + + def select_bucket(self, image_width, image_height): + aspect_ratio = image_width / image_height + if not self.no_upscale: + # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する + reso = (image_width, image_height) + if reso in self.predefined_resos_set: + pass + else: + ar_errors = self.predefined_aspect_ratios - aspect_ratio + predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの + reso = self.predefined_resos[predefined_bucket_id] + + ar_reso = reso[0] / reso[1] + if aspect_ratio > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + + resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) + # print("use predef", image_width, image_height, reso, resized_size) + else: + if image_width * image_height > self.max_area: + # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める + resized_width = math.sqrt(self.max_area * aspect_ratio) + resized_height = self.max_area / resized_width + assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" + + # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ + # 元のbucketingと同じロジック + b_width_rounded = self.round_to_steps(resized_width) + b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) + ar_width_rounded = b_width_rounded / b_height_in_wr + + b_height_rounded = self.round_to_steps(resized_height) + b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) + ar_height_rounded = b_width_in_hr / b_height_rounded + + # print(b_width_rounded, b_height_in_wr, ar_width_rounded) + # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + + if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): + resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5)) else: - self.max_reso = max_reso - self.max_area = max_reso[0] * max_reso[1] - self.min_size = min_size - self.max_size = max_size - self.reso_steps = reso_steps + resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded) + # print(resized_size) + else: + resized_size = (image_width, image_height) # リサイズは不要 - self.resos = [] - self.reso_to_id = {} - self.buckets = ( - [] - ) # 前処理時は (image_key, image)、学習時は image_key + # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) + bucket_width = resized_size[0] - resized_size[0] % self.reso_steps + bucket_height = resized_size[1] - resized_size[1] % self.reso_steps + # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) - def add_image(self, reso, image): - bucket_id = self.reso_to_id[reso] - self.buckets[bucket_id].append(image) + reso = (bucket_width, bucket_height) - def shuffle(self): - for bucket in self.buckets: - random.shuffle(bucket) + self.add_if_new_reso(reso) - def sort(self): - # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す - sorted_resos = self.resos.copy() - sorted_resos.sort() - - sorted_buckets = [] - sorted_reso_to_id = {} - for i, reso in enumerate(sorted_resos): - bucket_id = self.reso_to_id[reso] - sorted_buckets.append(self.buckets[bucket_id]) - sorted_reso_to_id[reso] = i - - self.resos = sorted_resos - self.buckets = sorted_buckets - self.reso_to_id = sorted_reso_to_id - - def make_buckets(self): - resos = model_util.make_bucket_resolutions( - self.max_reso, self.min_size, self.max_size, self.reso_steps - ) - self.set_predefined_resos(resos) - - def set_predefined_resos(self, resos): - # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく - self.predefined_resos = resos.copy() - self.predefined_resos_set = set(resos) - self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) - - def add_if_new_reso(self, reso): - if reso not in self.reso_to_id: - bucket_id = len(self.resos) - self.reso_to_id[reso] = bucket_id - self.resos.append(reso) - self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) - - def round_to_steps(self, x): - x = int(x + 0.5) - return x - x % self.reso_steps - - def select_bucket(self, image_width, image_height): - aspect_ratio = image_width / image_height - if not self.no_upscale: - # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する - reso = (image_width, image_height) - if reso in self.predefined_resos_set: - pass - else: - ar_errors = self.predefined_aspect_ratios - aspect_ratio - predefined_bucket_id = np.abs( - ar_errors - ).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの - reso = self.predefined_resos[predefined_bucket_id] - - ar_reso = reso[0] / reso[1] - if aspect_ratio > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - - resized_size = ( - int(image_width * scale + 0.5), - int(image_height * scale + 0.5), - ) - # print("use predef", image_width, image_height, reso, resized_size) - else: - if image_width * image_height > self.max_area: - # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める - resized_width = math.sqrt(self.max_area * aspect_ratio) - resized_height = self.max_area / resized_width - assert ( - abs(resized_width / resized_height - aspect_ratio) < 1e-2 - ), 'aspect is illegal' - - # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ - # 元のbucketingと同じロジック - b_width_rounded = self.round_to_steps(resized_width) - b_height_in_wr = self.round_to_steps( - b_width_rounded / aspect_ratio - ) - ar_width_rounded = b_width_rounded / b_height_in_wr - - b_height_rounded = self.round_to_steps(resized_height) - b_width_in_hr = self.round_to_steps( - b_height_rounded * aspect_ratio - ) - ar_height_rounded = b_width_in_hr / b_height_rounded - - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) - - if abs(ar_width_rounded - aspect_ratio) < abs( - ar_height_rounded - aspect_ratio - ): - resized_size = ( - b_width_rounded, - int(b_width_rounded / aspect_ratio + 0.5), - ) - else: - resized_size = ( - int(b_height_rounded * aspect_ratio + 0.5), - b_height_rounded, - ) - # print(resized_size) - else: - resized_size = ( - image_width, - image_height, - ) # リサイズは不要 - - # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) - bucket_width = resized_size[0] - resized_size[0] % self.reso_steps - bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) - - reso = (bucket_width, bucket_height) - - self.add_if_new_reso(reso) - - ar_error = (reso[0] / reso[1]) - aspect_ratio - return reso, resized_size, ar_error + ar_error = (reso[0] / reso[1]) - aspect_ratio + return reso, resized_size, ar_error class BucketBatchIndex(NamedTuple): - bucket_index: int - bucket_batch_size: int - batch_index: int + bucket_index: int + bucket_batch_size: int + batch_index: int +class AugHelper: + def __init__(self): + # prepare all possible augmentators + color_aug_method = albu.OneOf([ + albu.HueSaturationValue(8, 0, 0, p=.5), + albu.RandomGamma((95, 105), p=.5), + ], p=.33) + flip_aug_method = albu.HorizontalFlip(p=0.5) + + # key: (use_color_aug, use_flip_aug) + self.augmentors = { + (True, True): albu.Compose([ + color_aug_method, + flip_aug_method, + ], p=1.), + (True, False): albu.Compose([ + color_aug_method, + ], p=1.), + (False, True): albu.Compose([ + flip_aug_method, + ], p=1.), + (False, False): None + } + + def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: + return self.augmentors[(use_color_aug, use_flip_aug)] + + +class BaseSubset: + def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None: + self.image_dir = image_dir + self.num_repeats = num_repeats + self.shuffle_caption = shuffle_caption + self.keep_tokens = keep_tokens + self.color_aug = color_aug + self.flip_aug = flip_aug + self.face_crop_aug_range = face_crop_aug_range + self.random_crop = random_crop + self.caption_dropout_rate = caption_dropout_rate + self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs + self.caption_tag_dropout_rate = caption_tag_dropout_rate + + self.img_count = 0 + + +class DreamBoothSubset(BaseSubset): + def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, + face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + + self.is_reg = is_reg + self.class_tokens = class_tokens + self.caption_extension = caption_extension + + def __eq__(self, other) -> bool: + if not isinstance(other, DreamBoothSubset): + return NotImplemented + return self.image_dir == other.image_dir + +class FineTuningSubset(BaseSubset): + def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: + assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" + + super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, + face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + + self.metadata_file = metadata_file + + def __eq__(self, other) -> bool: + if not isinstance(other, FineTuningSubset): + return NotImplemented + return self.metadata_file == other.metadata_file + class BaseDataset(torch.utils.data.Dataset): - def __init__( - self, - tokenizer, - max_token_length, - shuffle_caption, - shuffle_keep_tokens, - resolution, - flip_aug: bool, - color_aug: bool, - face_crop_aug_range, - random_crop, - debug_dataset: bool, - ) -> None: - super().__init__() - self.tokenizer: CLIPTokenizer = tokenizer - self.max_token_length = max_token_length - self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens - # width/height is used when enable_bucket==False - self.width, self.height = ( - (None, None) if resolution is None else resolution - ) - self.face_crop_aug_range = face_crop_aug_range - self.flip_aug = flip_aug - self.color_aug = color_aug - self.debug_dataset = debug_dataset - self.random_crop = random_crop - self.token_padding_disabled = False - self.dataset_dirs_info = {} - self.reg_dataset_dirs_info = {} - self.tag_frequency = {} + def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None: + super().__init__() + self.tokenizer = tokenizer + self.max_token_length = max_token_length + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution + self.debug_dataset = debug_dataset - self.enable_bucket = False - self.bucket_manager: BucketManager = None # not initialized - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None - self.bucket_no_upscale = None - self.bucket_info = None # for metadata + self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] - self.tokenizer_max_length = ( - self.tokenizer.model_max_length - if max_token_length is None - else max_token_length + 2 - ) + self.token_padding_disabled = False + self.tag_frequency = {} - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ - self.dropout_rate: float = 0 - self.dropout_every_n_epochs: int = None - self.tag_dropout_rate: float = 0 + self.enable_bucket = False + self.bucket_manager: BucketManager = None # not initialized + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None + self.bucket_no_upscale = None + self.bucket_info = None # for metadata + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + + # augmentation + self.aug_helper = AugHelper() + + self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) + + self.image_data: Dict[str, ImageInfo] = {} + self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} + + self.replacements = {} + + def set_current_epoch(self, epoch): + self.current_epoch = epoch + self.shuffle_buckets() + + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + tag = tag.strip() + if tag: + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 + + def disable_token_padding(self): + self.token_padding_disabled = True + + def add_replacement(self, str_from, str_to): + self.replacements[str_from] = str_to + + def process_caption(self, subset: BaseSubset, caption): + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate + is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 + + if is_drop_out: + caption = "" + else: + if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: + def dropout_tags(tokens): + if subset.caption_tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= subset.caption_tag_dropout_rate: + l.append(token) + return l + + fixed_tokens = [] + flex_tokens = [t.strip() for t in caption.strip().split(",")] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[:subset.keep_tokens] + flex_tokens = flex_tokens[subset.keep_tokens:] + + if subset.shuffle_caption: + random.shuffle(flex_tokens) + + flex_tokens = dropout_tags(flex_tokens) + + caption = ", ".join(fixed_tokens + flex_tokens) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) + + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer(caption, padding="max_length", truncation=True, + max_length=self.tokenizer_max_length, return_tensors="pt").input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = (input_ids[0].unsqueeze(0), + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): + ids_chunk = (input_ids[0].unsqueeze(0), # BOS + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo, subset: BaseSubset): + self.image_data[info.image_key] = info + self.image_to_subset[info.image_key] = subset + + def make_buckets(self): + ''' + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + ''' + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if self.enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketを作成し、画像をbucketに振り分ける + if self.enable_bucket: + if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み + self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height), + self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps) + if not self.bucket_no_upscale: + self.bucket_manager.make_buckets() + else: + print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます") + + img_ar_errors = [] + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height) + + # print(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) + + self.bucket_manager.sort() + else: + self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) + self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + + for image_info in self.image_data.values(): + for _ in range(image_info.num_repeats): + self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + + # bucket情報を表示、格納する + if self.enable_bucket: + self.bucket_info = {"buckets": {}} + print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): + count = len(bucket) + if count > 0: + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} + print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + print(f"mean ar error (without repeats): {mean_img_ar_error}") + + # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる + self.buckets_indices: List(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.bucket_manager.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) + + # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す + #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる + # + # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは + # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう + # # そのためバッチサイズを画像種類までに制限する + # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? + # # TO DO 正則化画像をepochまたがりで利用する仕組み + # num_of_image_types = len(set(bucket)) + # bucket_batch_size = min(self.batch_size, num_of_image_types) + # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) + # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # for batch_index in range(batch_count): + # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) + # ↑ここまで + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + self.bucket_manager.shuffle() + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): + image_height, image_width = image.shape[0:2] + + if image_width != resized_size[0] or image_height != resized_size[1]: + # リサイズする + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + image_height, image_width = image.shape[0:2] + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("w", trim_size, p) + image = image[:, p:p + reso[0]] + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("h", trim_size, p) + image = image[p:p + reso[1]] + + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def is_latent_cacheable(self): + return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + + def cache_latents(self, vae): + # TODO ここを高速化したい + print("caching latents.") + for info in tqdm(self.image_data.values()): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + if info.latents_flipped is not None: + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + image = self.load_image(info.absolute_path) + image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + if subset.flip_aug: + image = image[:, ::-1].copy() # cannot convert to Tensor without copy + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, subset: BaseSubset, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if subset.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + .5) + nw = int(width * scale + .5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + .5) + face_cy = int(face_cy * scale + .5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if subset.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1:p1 + target_size, :] + else: + image = image[:, p1:p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + if npz_file is None: + return None + return np.load(npz_file)['arr_0'] + + def __len__(self): + return self._length + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index:image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p:p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p:p + self.width] + + im_h, im_w = img.shape[0:2] + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # augmentation - flip_p = 0.5 if flip_aug else 0.0 - if color_aug: - # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る - self.aug = albu.Compose( - [ - albu.OneOf( - [ - albu.HueSaturationValue(8, 0, 0, p=0.5), - albu.RandomGamma((95, 105), p=0.5), - ], - p=0.33, - ), - albu.HorizontalFlip(p=flip_p), - ], - p=1.0, - ) - elif flip_aug: - self.aug = albu.Compose([albu.HorizontalFlip(p=flip_p)], p=1.0) - else: - self.aug = None - - self.image_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - self.image_data: Dict[str, ImageInfo] = {} - - self.replacements = {} - - def set_current_epoch(self, epoch): - self.current_epoch = epoch - - def set_caption_dropout( - self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate - ): - # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく) - self.dropout_rate = dropout_rate - self.dropout_every_n_epochs = dropout_every_n_epochs - self.tag_dropout_rate = tag_dropout_rate - - def set_tag_frequency(self, dir_name, captions): - frequency_for_dir = self.tag_frequency.get(dir_name, {}) - self.tag_frequency[dir_name] = frequency_for_dir - for caption in captions: - for tag in caption.split(','): - if tag and not tag.isspace(): - tag = tag.lower() - frequency = frequency_for_dir.get(tag, 0) - frequency_for_dir[tag] = frequency + 1 - - def disable_token_padding(self): - self.token_padding_disabled = True - - def add_replacement(self, str_from, str_to): - self.replacements[str_from] = str_to - - def process_caption(self, caption): - # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い - is_drop_out = ( - self.dropout_rate > 0 and random.random() < self.dropout_rate - ) - is_drop_out = ( - is_drop_out - or self.dropout_every_n_epochs - and self.current_epoch % self.dropout_every_n_epochs == 0 - ) - - if is_drop_out: - caption = '' - else: - if self.shuffle_caption or self.tag_dropout_rate > 0: - - def dropout_tags(tokens): - if self.tag_dropout_rate <= 0: - return tokens - l = [] - for token in tokens: - if random.random() >= self.tag_dropout_rate: - l.append(token) - return l - - tokens = [t.strip() for t in caption.strip().split(',')] - if self.shuffle_keep_tokens is None: - if self.shuffle_caption: - random.shuffle(tokens) - - tokens = dropout_tags(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[: self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens :] - - if self.shuffle_caption: - random.shuffle(tokens) - - tokens = dropout_tags(tokens) - - tokens = keep_tokens + tokens - caption = ', '.join(tokens) - - # textual inversion対応 - for str_from, str_to in self.replacements.items(): - if str_from == '': - # replace all - if type(str_to) == list: - caption = random.choice(str_to) - else: - caption = str_to - else: - caption = caption.replace(str_from, str_to) - - return caption - - def get_input_ids(self, caption): - input_ids = self.tokenizer( - caption, - padding='max_length', - truncation=True, - max_length=self.tokenizer_max_length, - return_tensors='pt', - ).input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range( - 1, - self.tokenizer_max_length - - self.tokenizer.model_max_length - + 2, - self.tokenizer.model_max_length - 2, - ): # (1, 152, 75) - ids_chunk = ( - input_ids[0].unsqueeze(0), - input_ids[i : i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0), - ) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range( - 1, - self.tokenizer_max_length - - self.tokenizer.model_max_length - + 2, - self.tokenizer.model_max_length - 2, - ): - ids_chunk = ( - input_ids[0].unsqueeze(0), # BOS - input_ids[i : i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0), - ) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ( - ids_chunk[-2] != self.tokenizer.eos_token_id - and ids_chunk[-2] != self.tokenizer.pad_token_id - ): - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - return input_ids - - def register_image(self, info: ImageInfo): - self.image_data[info.image_key] = info - - def make_buckets(self): - """ - bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - min_size and max_size are ignored when enable_bucket is False - """ - print('loading image sizes.') - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) - - if self.enable_bucket: - print('make buckets') - else: - print('prepare dataset') - - # bucketを作成し、画像をbucketに振り分ける - if self.enable_bucket: - if ( - self.bucket_manager is None - ): # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み - self.bucket_manager = BucketManager( - self.bucket_no_upscale, - (self.width, self.height), - self.min_bucket_reso, - self.max_bucket_reso, - self.bucket_reso_steps, - ) - if not self.bucket_no_upscale: - self.bucket_manager.make_buckets() - else: - print( - 'min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます' - ) - - img_ar_errors = [] - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - ( - image_info.bucket_reso, - image_info.resized_size, - ar_error, - ) = self.bucket_manager.select_bucket( - image_width, image_height - ) - - # print(image_info.image_key, image_info.bucket_reso) - img_ar_errors.append(abs(ar_error)) - - self.bucket_manager.sort() - else: - self.bucket_manager = BucketManager( - False, (self.width, self.height), None, None, None - ) - self.bucket_manager.set_predefined_resos( - [(self.width, self.height)] - ) # ひとつの固定サイズbucketのみ - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - ( - image_info.bucket_reso, - image_info.resized_size, - _, - ) = self.bucket_manager.select_bucket( - image_width, image_height - ) - - for image_info in self.image_data.values(): - for _ in range(image_info.num_repeats): - self.bucket_manager.add_image( - image_info.bucket_reso, image_info.image_key - ) - - # bucket情報を表示、格納する - if self.enable_bucket: - self.bucket_info = {'buckets': {}} - print( - 'number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)' - ) - for i, (reso, bucket) in enumerate( - zip(self.bucket_manager.resos, self.bucket_manager.buckets) - ): - count = len(bucket) - if count > 0: - self.bucket_info['buckets'][i] = { - 'resolution': reso, - 'count': len(bucket), - } - print( - f'bucket {i}: resolution {reso}, count: {len(bucket)}' - ) - - img_ar_errors = np.array(img_ar_errors) - mean_img_ar_error = np.mean(np.abs(img_ar_errors)) - self.bucket_info['mean_img_ar_error'] = mean_img_ar_error - print(f'mean ar error (without repeats): {mean_img_ar_error}') - - # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] - for bucket_index, bucket in enumerate(self.bucket_manager.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append( - BucketBatchIndex( - bucket_index, self.batch_size, batch_index - ) - ) - - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - self.bucket_manager.shuffle() - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == 'RGB': - image = image.convert('RGB') - img = np.array(image, np.uint8) - return img - - def trim_and_resize_if_required(self, image, reso, resized_size): - image_height, image_width = image.shape[0:2] - - if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - image = cv2.resize( - image, resized_size, interpolation=cv2.INTER_AREA - ) # INTER_AREAでやりたいのでcv2でリサイズ - - image_height, image_width = image.shape[0:2] - if image_width > reso[0]: - trim_size = image_width - reso[0] - p = ( - trim_size // 2 - if not self.random_crop - else random.randint(0, trim_size) - ) - # print("w", trim_size, p) - image = image[:, p : p + reso[0]] - if image_height > reso[1]: - trim_size = image_height - reso[1] - p = ( - trim_size // 2 - if not self.random_crop - else random.randint(0, trim_size) - ) - # print("h", trim_size, p) - image = image[p : p + reso[1]] - - assert ( - image.shape[0] == reso[1] and image.shape[1] == reso[0] - ), f'internal error, illegal trimmed size: {image.shape}, {reso}' - return image - - def cache_latents(self, vae): - # TODO ここを高速化したい - print('caching latents.') - for info in tqdm(self.image_data.values()): - if info.latents_npz is not None: - info.latents = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz( - info, True - ) # might be None - if info.latents_flipped is not None: - info.latents_flipped = torch.FloatTensor( - info.latents_flipped - ) - continue - - image = self.load_image(info.absolute_path) - image = self.trim_and_resize_if_required( - image, info.bucket_reso, info.resized_size - ) - - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to( - device=vae.device, dtype=vae.dtype - ) - info.latents = ( - vae.encode(img_tensor) - .latent_dist.sample() - .squeeze(0) - .to('cpu') - ) - - if self.flip_aug: - image = image[ - :, ::-1 - ].copy() # cannot convert to Tensor without copy - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to( - device=vae.device, dtype=vae.dtype - ) - info.latents_flipped = ( - vae.encode(img_tensor) - .latent_dist.sample() - .squeeze(0) - .to('cpu') - ) - - def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size - - def load_image_with_face_info(self, image_path: str): - img = self.load_image(image_path) - - face_cx = face_cy = face_w = face_h = 0 - if self.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split( - '_' - ) - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max( - self.height / height, self.width / width - ) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min( - 1.0, - max( - min_scale, - self.size / (face_size * self.face_crop_aug_range[1]), - ), - ) # 指定した顔最小サイズ - max_scale = min( - 1.0, - max( - min_scale, - self.size / (face_size * self.face_crop_aug_range[0]), - ), - ) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + 0.5) - nw = int(width * scale + 0.5) - assert ( - nh >= self.height and nw >= self.width - ), f'internal error. small scale {scale}, {width}*{height}' - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + 0.5) - face_cy = int(face_cy * scale + 0.5) - height, width = nh, nw - - # 顔を中心として448*640とかへ切り出す - for axis, (target_size, length, face_p) in enumerate( - zip((self.height, self.width), (height, width), (face_cy, face_cx)) - ): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if self.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max( - length - face_p, face_p - ) # 画像の端から顔中心までの距離の長いほう - p1 = ( - p1 - + (random.randint(0, range) + random.randint(0, range)) - - range - ) # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint( - -face_size // 20, +face_size // 20 - ) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1 : p1 + target_size, :] - else: - image = image[:, p1 : p1 + target_size] - - return image - - def load_latents_from_npz(self, image_info: ImageInfo, flipped): - npz_file = ( - image_info.latents_npz_flipped - if flipped - else image_info.latents_npz - ) - if npz_file is None: - return None - return np.load(npz_file)['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.bucket_manager.buckets[ - self.buckets_indices[index].bucket_index - ] - bucket_batch_size = self.buckets_indices[index].bucket_batch_size - image_index = ( - self.buckets_indices[index].batch_index * bucket_batch_size - ) - - loss_weights = [] - captions = [] - input_ids_list = [] - latents_list = [] - images = [] - - for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) - - # image/latentsを処理する - if image_info.latents is not None: - latents = ( - image_info.latents - if not self.flip_aug or random.random() < 0.5 - else image_info.latents_flipped - ) - image = None - elif image_info.latents_npz is not None: - latents = self.load_latents_from_npz( - image_info, self.flip_aug and random.random() >= 0.5 - ) - latents = torch.FloatTensor(latents) - image = None - else: - # 画像を読み込み、必要ならcropする - ( - img, - face_cx, - face_cy, - face_w, - face_h, - ) = self.load_image_with_face_info(image_info.absolute_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.trim_and_resize_if_required( - img, image_info.bucket_reso, image_info.resized_size - ) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target( - img, face_cx, face_cy, face_w, face_h - ) - elif im_h > self.height or im_w > self.width: - assert ( - self.random_crop - ), f'image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}' - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p : p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p : p + self.width] - - im_h, im_w = img.shape[0:2] - assert ( - im_h == self.height and im_w == self.width - ), f'image size is small / 画像サイズが小さいようです: {image_info.absolute_path}' - - # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] - - latents = None - image = self.image_transforms( - img - ) # -1.0~1.0のtorch.Tensorになる - - images.append(image) - latents_list.append(latents) - - caption = self.process_caption(image_info.caption) - captions.append(caption) - if ( - not self.token_padding_disabled - ): # this option might be omitted in future - input_ids_list.append(self.get_input_ids(caption)) - - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) - - if self.token_padding_disabled: - # padding=True means pad in the batch - example['input_ids'] = self.tokenizer( - captions, padding=True, truncation=True, return_tensors='pt' - ).input_ids - else: - # batch processing seems to be good - example['input_ids'] = torch.stack(input_ids_list) - - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images - - example['latents'] = ( - torch.stack(latents_list) if latents_list[0] is not None else None - ) - - if self.debug_dataset: - example['image_keys'] = bucket[ - image_index : image_index + self.batch_size - ] - example['captions'] = captions - return example + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)['image'] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(subset, image_info.caption) + captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future + input_ids_list.append(self.get_input_ids(caption)) + + example = {} + example['loss_weights'] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example['input_ids'] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example['images'] = images + + example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example['image_keys'] = bucket[image_index:image_index + self.batch_size] + example['captions'] = captions + return example class DreamBoothDataset(BaseDataset): - def __init__( - self, - batch_size, - train_data_dir, - reg_data_dir, - tokenizer, - max_token_length, - caption_extension, - shuffle_caption, - shuffle_keep_tokens, - resolution, - enable_bucket, - min_bucket_reso, - max_bucket_reso, - bucket_reso_steps, - bucket_no_upscale, - prior_loss_weight, - flip_aug, - color_aug, - face_crop_aug_range, - random_crop, - debug_dataset, - ) -> None: - super().__init__( - tokenizer, - max_token_length, - shuffle_caption, - shuffle_keep_tokens, - resolution, - flip_aug, - color_aug, - face_crop_aug_range, - random_crop, - debug_dataset, - ) + def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - assert ( - resolution is not None - ), f'resolution is required / resolution(解像度)指定は必須です' + assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - self.batch_size = batch_size - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.latents_cache = None + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.latents_cache = None - self.enable_bucket = enable_bucket - if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f'min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください' - assert ( - max(resolution) <= max_bucket_reso - ), f'max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください' - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None # この情報は使われない - self.bucket_no_upscale = False + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None # この情報は使われない + self.bucket_no_upscale = False - def read_caption(img_path): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split('_') - if len(tokens) >= 5: - base_name_face_det = '_'.join(tokens[:-4]) - cap_paths = [ - base_name + caption_extension, - base_name_face_det + caption_extension, - ] + def read_caption(img_path, caption_extension): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, 'rt', encoding='utf-8') as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - print( - f'illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}' - ) - raise e - assert ( - len(lines) > 0 - ), f'caption file is empty / キャプションファイルが空です: {cap_path}' - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(dir): - if not os.path.isdir(dir): - # print(f"ignore file: {dir}") - return 0, [], [] - - tokens = os.path.basename(dir).split('_') + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding='utf-8') as f: try: - n_repeats = int(tokens[0]) - except ValueError as e: - print( - f'ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}' - ) - return 0, [], [] + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption - caption_by_folder = '_'.join(tokens[1:]) - img_paths = glob_images(dir, '*') - print( - f'found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files' - ) + def load_dreambooth_dir(subset: DreamBoothSubset): + if not os.path.isdir(subset.image_dir): + print(f"not directory: {subset.image_dir}") + return [], [] - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path) - captions.append( - caption_by_folder if cap_for_img is None else cap_for_img - ) + img_paths = glob_images(subset.image_dir, "*") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - self.set_tag_frequency( - os.path.basename(dir), captions - ) # タグ頻度を記録 + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None and subset.class_tokens is None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + captions.append("") + else: + captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 - return n_repeats, img_paths, captions + return img_paths, captions - print('prepare train images.') - train_dirs = os.listdir(train_data_dir) - num_train_images = 0 - for dir in train_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir( - os.path.join(train_data_dir, dir) - ) - num_train_images += n_repeats * len(img_paths) + print("prepare images.") + num_train_images = 0 + num_reg_images = 0 + reg_infos: List[ImageInfo] = [] + for subset in subsets: + if subset.num_repeats < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") + continue - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, n_repeats, caption, False, img_path) - self.register_image(info) + if subset in self.subsets: + print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") + continue - self.dataset_dirs_info[os.path.basename(dir)] = { - 'n_repeats': n_repeats, - 'img_count': len(img_paths), - } + img_paths, captions = load_dreambooth_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + continue - print(f'{num_train_images} train images with repeating.') - self.num_train_images = num_train_images + if subset.is_reg: + num_reg_images += subset.num_repeats * len(img_paths) + else: + num_train_images += subset.num_repeats * len(img_paths) - # reg imageは数を数えて学習画像と同じ枚数にする - num_reg_images = 0 - if reg_data_dir: - print('prepare reg images.') - reg_infos: List[ImageInfo] = [] + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if subset.is_reg: + reg_infos.append(info) + else: + self.register_image(info, subset) - reg_dirs = os.listdir(reg_data_dir) - for dir in reg_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir( - os.path.join(reg_data_dir, dir) - ) - num_reg_images += n_repeats * len(img_paths) + subset.img_count = len(img_paths) + self.subsets.append(subset) - for img_path, caption in zip(img_paths, captions): - info = ImageInfo( - img_path, n_repeats, caption, True, img_path - ) - reg_infos.append(info) + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images - self.reg_dataset_dirs_info[os.path.basename(dir)] = { - 'n_repeats': n_repeats, - 'img_count': len(img_paths), - } + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - print(f'{num_reg_images} reg images.') - if num_train_images < num_reg_images: - print( - 'some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります' - ) + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") + else: + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + n = 0 + first_loop = True + while n < num_train_images: + for info in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: + break + first_loop = False - if num_reg_images == 0: - print('no regularization images / 正則化画像が見つかりませんでした') - else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する - n = 0 - first_loop = True - while n < num_train_images: - for info in reg_infos: - if first_loop: - self.register_image(info) - n += info.num_repeats - else: - info.num_repeats += 1 - n += 1 - if n >= num_train_images: - break - first_loop = False - - self.num_reg_images = num_reg_images + self.num_reg_images = num_reg_images class FineTuningDataset(BaseDataset): - def __init__( - self, - json_file_name, - batch_size, - train_data_dir, - tokenizer, - max_token_length, - shuffle_caption, - shuffle_keep_tokens, - resolution, - enable_bucket, - min_bucket_reso, - max_bucket_reso, - bucket_reso_steps, - bucket_no_upscale, - flip_aug, - color_aug, - face_crop_aug_range, - random_crop, - dataset_repeats, - debug_dataset, - ) -> None: - super().__init__( - tokenizer, - max_token_length, - shuffle_caption, - shuffle_keep_tokens, - resolution, - flip_aug, - color_aug, - face_crop_aug_range, - random_crop, - debug_dataset, - ) + def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - # メタデータを読み込む - if os.path.exists(json_file_name): - print(f'loading existing metadata: {json_file_name}') - with open(json_file_name, 'rt', encoding='utf-8') as f: - metadata = json.load(f) + self.batch_size = batch_size + + self.num_train_images = 0 + self.num_reg_images = 0 + + for subset in subsets: + if subset.num_repeats < 1: + print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") + continue + + if subset in self.subsets: + print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") + continue + + # メタデータを読み込む + if os.path.exists(subset.metadata_file): + print(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") + + if len(metadata) < 1: + print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + continue + + tags_list = [] + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key else: - raise ValueError( - f'no metadata / メタデータファイルがありません: {json_file_name}' - ) + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" + abs_path = abs_path[0] - self.metadata = metadata - self.train_data_dir = train_data_dir - self.batch_size = batch_size + caption = img_md.get('caption') + tags = img_md.get('tags') + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ', ' + tags + tags_list.append(tags) + assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" - tags_list = [] - for image_key, img_md in metadata.items(): - # path情報を作る - if os.path.exists(image_key): - abs_path = image_key - else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(train_data_dir, image_key) - assert len(abs_path) >= 1, f'no image / 画像がありません: {image_key}' - abs_path = abs_path[0] + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info.image_size = img_md.get('train_resolution') - caption = img_md.get('caption') - tags = img_md.get('tags') - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - tags_list.append(tags) - assert ( - caption is not None and len(caption) > 0 - ), f'caption or tag is required / キャプションまたはタグは必須です:{abs_path}' + if not subset.color_aug and not subset.random_crop: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) - image_info = ImageInfo( - image_key, dataset_repeats, caption, False, abs_path - ) - image_info.image_size = img_md.get('train_resolution') + self.register_image(image_info, subset) - if not self.color_aug and not self.random_crop: - # if npz exists, use them - ( - image_info.latents_npz, - image_info.latents_npz_flipped, - ) = self.image_key_to_npz_file(image_key) + self.num_train_images += len(metadata) * subset.num_repeats - self.register_image(image_info) - self.num_train_images = len(metadata) * dataset_repeats - self.num_reg_images = 0 + # TODO do not record tag freq when no tag + self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) + subset.img_count = len(metadata) + self.subsets.append(subset) - # TODO do not record tag freq when no tag - self.set_tag_frequency(os.path.basename(json_file_name), tags_list) - self.dataset_dirs_info[os.path.basename(json_file_name)] = { - 'n_repeats': dataset_repeats, - 'img_count': len(metadata), - } + # check existence of all npz files + use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets]) + if use_npz_latents: + flip_aug_in_subset = False + npz_any = False + npz_all = True - # check existence of all npz files - use_npz_latents = not (self.color_aug or self.random_crop) - if use_npz_latents: - npz_any = False - npz_all = True - for image_info in self.image_data.values(): - has_npz = image_info.latents_npz is not None - npz_any = npz_any or has_npz + for image_info in self.image_data.values(): + subset = self.image_to_subset[image_info.image_key] - if self.flip_aug: - has_npz = ( - has_npz and image_info.latents_npz_flipped is not None - ) - npz_all = npz_all and has_npz + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz - if npz_any and not npz_all: - break + if subset.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + flip_aug_in_subset = True + npz_all = npz_all and has_npz - if not npz_any: - use_npz_latents = False - print( - f'npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します' - ) - elif not npz_all: - use_npz_latents = False - print( - f'some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します' - ) - if self.flip_aug: - print('maybe no flipped files / 反転されたnpzファイルがないのかもしれません') - # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + if npz_any and not npz_all: + break - # check min/max bucket size - sizes = set() - resos = set() - for image_info in self.image_data.values(): - if image_info.image_size is None: - sizes = None # not calculated - break - sizes.add(image_info.image_size[0]) - sizes.add(image_info.image_size[1]) - resos.add(tuple(image_info.image_size)) + if not npz_any: + use_npz_latents = False + print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + elif not npz_all: + use_npz_latents = False + print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + if flip_aug_in_subset: + print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + # else: + # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") - if sizes is None: - if use_npz_latents: - use_npz_latents = False - print( - f'npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します' - ) + # check min/max bucket size + sizes = set() + resos = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + resos.add(tuple(image_info.image_size)) - assert ( - resolution is not None - ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" + if sizes is None: + if use_npz_latents: + use_npz_latents = False + print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") - self.enable_bucket = enable_bucket - if self.enable_bucket: - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - if not enable_bucket: - print( - 'metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします' - ) - print('using bucket info in metadata / メタデータ内のbucket情報を使います') - self.enable_bucket = True + assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" - assert ( - not bucket_no_upscale - ), 'if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません' + self.enable_bucket = enable_bucket + if self.enable_bucket: + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + if not enable_bucket: + print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + print("using bucket info in metadata / メタデータ内のbucket情報を使います") + self.enable_bucket = True - # bucket情報を初期化しておく、make_bucketsで再作成しない - self.bucket_manager = BucketManager(False, None, None, None, None) - self.bucket_manager.set_predefined_resos(resos) + assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" - # npz情報をきれいにしておく - if not use_npz_latents: - for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None + # bucket情報を初期化しておく、make_bucketsで再作成しない + self.bucket_manager = BucketManager(False, None, None, None, None) + self.bucket_manager.set_predefined_resos(resos) - def image_key_to_npz_file(self, image_key): - base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + '.npz' + # npz情報をきれいにしておく + if not use_npz_latents: + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None - if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + '_flip.npz' - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip + def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + '.npz' - # image_key is relative path - npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - npz_file_flip = os.path.join( - self.train_data_dir, image_key + '_flip.npz' - ) + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + '_flip.npz' + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None + # image_key is relative path + npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz') + npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz') - return npz_file_norm, npz_file_flip + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + + +# behave as Dataset mock +class DatasetGroup(torch.utils.data.ConcatDataset): + def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): + self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] + + super().__init__(datasets) + + self.image_data = {} + self.num_train_images = 0 + self.num_reg_images = 0 + + # simply concat together + # TODO: handling image_data key duplication among dataset + # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. + for dataset in datasets: + self.image_data.update(dataset.image_data) + self.num_train_images += dataset.num_train_images + self.num_reg_images += dataset.num_reg_images + + def add_replacement(self, str_from, str_to): + for dataset in self.datasets: + dataset.add_replacement(str_from, str_to) + + # def make_buckets(self): + # for dataset in self.datasets: + # dataset.make_buckets() + + def cache_latents(self, vae): + for i, dataset in enumerate(self.datasets): + print(f"[Dataset {i}]") + dataset.cache_latents(vae) + + def is_latent_cacheable(self) -> bool: + return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) + + def disable_token_padding(self): + for dataset in self.datasets: + dataset.disable_token_padding() def debug_dataset(train_dataset, show_input_ids=False): - print( - f'Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}' - ) - print('Escape for exit. / Escキーで中断、終了します') + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + print("Escape for exit. / Escキーで中断、終了します") - train_dataset.set_current_epoch(1) - k = 0 - for i, example in enumerate(train_dataset): - if example['latents'] is not None: - print( - f"sample has latents from npz file: {example['latents'].size()}" - ) - for j, (ik, cap, lw, iid) in enumerate( - zip( - example['image_keys'], - example['captions'], - example['loss_weights'], - example['input_ids'], - ) - ): - print( - f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"' - ) - if show_input_ids: - print(f'input ids: {iid}') - if example['images'] is not None: - im = example['images'][j] - print(f'image size: {im.size()}') - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose( - im, (1, 2, 0) - ) # c,H,W -> H,W,c - im = im[ - :, :, ::-1 - ] # RGB -> BGR (OpenCV) - if os.name == 'nt': # only windows - cv2.imshow('img', im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27 or (example['images'] is None and i >= 8): - break + train_dataset.set_current_epoch(1) + k = 0 + indices = list(range(len(train_dataset))) + random.shuffle(indices) + for i, idx in enumerate(indices): + example = train_dataset[idx] + if example['latents'] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example['images'] is not None: + im = example['images'][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == 'nt': # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break + if k == 27 or (example['images'] is None and i >= 8): + break -def glob_images(directory, base='*'): - img_paths = [] - for ext in IMAGE_EXTENSIONS: - if base == '*': - img_paths.extend( - glob.glob(os.path.join(glob.escape(directory), base + ext)) - ) - else: - img_paths.extend( - glob.glob(glob.escape(os.path.join(directory, base + ext))) - ) - # img_paths = list(set(img_paths)) # 重複を排除 - # img_paths.sort() - return img_paths +def glob_images(directory, base="*"): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == '*': + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + # img_paths = list(set(img_paths)) # 重複を排除 + # img_paths.sort() + return img_paths def glob_images_pathlib(dir_path, recursive): - image_paths = [] - if recursive: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.rglob('*' + ext)) - else: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.glob('*' + ext)) - # image_paths = list(set(image_paths)) # 重複を排除 - # image_paths.sort() - return image_paths - + image_paths = [] + if recursive: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.rglob('*' + ext)) + else: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.glob('*' + ext)) + # image_paths = list(set(image_paths)) # 重複を排除 + # image_paths.sort() + return image_paths # endregion @@ -1416,92 +1153,86 @@ EPSILON = 1e-6 def exists(val): - return val is not None + return val is not None def default(val, d): - return val if exists(val) else d + return val if exists(val) else d def model_hash(filename): - """Old model hash used by stable-diffusion-webui""" - try: - with open(filename, 'rb') as file: - m = hashlib.sha256() + """Old model hash used by stable-diffusion-webui""" + try: + with open(filename, "rb") as file: + m = hashlib.sha256() - file.seek(0x100000) - m.update(file.read(0x10000)) - return m.hexdigest()[0:8] - except FileNotFoundError: - return 'NOFILE' + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return 'NOFILE' def calculate_sha256(filename): - """New model hash used by stable-diffusion-webui""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 + """New model hash used by stable-diffusion-webui""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 - with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(blksize), b''): - hash_sha256.update(chunk) + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) - return hash_sha256.hexdigest() + return hash_sha256.hexdigest() def precalculate_safetensors_hashes(tensors, metadata): - """Precalculate the model hashes needed by sd-webui-additional-networks to - save time on indexing the model later.""" + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" - # Because writing user metadata to the file can change the result of - # sd_models.model_hash(), only retain the training metadata for purposes of - # calculating the hash, as they are meant to be immutable - metadata = {k: v for k, v in metadata.items() if k.startswith('ss_')} + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} - bytes = safetensors.torch.save(tensors, metadata) - b = BytesIO(bytes) + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) - model_hash = addnet_hash_safetensors(b) - legacy_hash = addnet_hash_legacy(b) - return model_hash, legacy_hash + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash def addnet_hash_legacy(b): - """Old model hash used by sd-webui-additional-networks for .safetensors format files""" - m = hashlib.sha256() + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() - b.seek(0x100000) - m.update(b.read(0x10000)) - return m.hexdigest()[0:8] + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] def addnet_hash_safetensors(b): - """New model hash used by sd-webui-additional-networks for .safetensors format files""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 - b.seek(0) - header = b.read(8) - n = int.from_bytes(header, 'little') + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") - offset = n + 8 - b.seek(offset) - for chunk in iter(lambda: b.read(blksize), b''): - hash_sha256.update(chunk) + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) - return hash_sha256.hexdigest() + return hash_sha256.hexdigest() def get_git_revision_hash() -> str: - try: - return ( - subprocess.check_output( - ['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__) - ) - .decode('ascii') - .strip() - ) - except: - return '(unknown)' + try: + return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip() + except: + return "(unknown)" # flash attention forwards and backwards @@ -1510,744 +1241,422 @@ def get_git_revision_hash() -> str: class FlashAttentionFunction(torch.autograd.function.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" + @ staticmethod + @ torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """ Algorithm 2 in the paper """ - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - o = torch.zeros_like(q) - all_row_sums = torch.zeros( - (*q.shape[:-1], 1), dtype=dtype, device=device - ) - all_row_maxes = torch.full( - (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device - ) + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - scale = q.shape[-1] ** -0.5 + scale = (q.shape[-1] ** -0.5) - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, 'b n -> b 1 1 n') + mask = mask.split(q_bucket_size, dim=-1) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate( - row_splits - ): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = ( - einsum('... i d, ... j d -> ... i j', qc, kc) * scale - ) + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) - if causal and q_start_index < ( - k_start_index + k_bucket_size - 1 - ): - causal_mask = torch.ones( - (qc.shape[-2], kc.shape[-2]), - dtype=torch.bool, - device=device, - ).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.) - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( - min=EPSILON - ) + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - exp_values = einsum( - '... i j, ... j d -> ... i d', exp_weights, vc - ) + exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp( - block_row_maxes - new_row_maxes - ) + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - new_row_sums = ( - exp_row_max_diff * row_sums - + exp_block_row_max_diff * block_row_sums - ) + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( - (exp_block_row_max_diff / new_row_sums) * exp_values - ) + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - return o + return o - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" + @ staticmethod + @ torch.no_grad() + def backward(ctx, do): + """ Algorithm 4 in the paper """ - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors - device = q.device + device = q.device - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2) + ) - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = ( - einsum('... i d, ... j d -> ... i j', qc, kc) * scale - ) + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - if causal and q_start_index < ( - k_start_index + k_bucket_size - 1 - ): - causal_mask = torch.ones( - (qc.shape[-2], kc.shape[-2]), - dtype=torch.bool, - device=device, - ).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) - exp_attn_weights = torch.exp(attn_weights - mc) + exp_attn_weights = torch.exp(attn_weights - mc) - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.) - p = exp_attn_weights / lc + p = exp_attn_weights / lc - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) + dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) + dp = einsum('... i d, ... j d -> ... i j', doc, vc) - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) + dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None -def replace_unet_modules( - unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, - mem_eff_attn, - xformers, -): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() def replace_unet_cross_attn_to_memory_efficient(): - print( - 'Replace CrossAttention.forward to use FlashAttention (not xformers)' - ) - flash_func = FlashAttentionFunction + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 - h = self.heads - q = self.to_q(x) + h = self.heads + q = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) + context = context if context is not None else x + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x - q, k, v = map( - lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v) - ) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - out = flash_func.apply( - q, k, v, mask, False, q_bucket_size, k_bucket_size - ) + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, 'b h n d -> b n (h d)') - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + diffusers.models.attention.CrossAttention.forward = forward_flash_attn def replace_unet_cross_attn_to_xformers(): - print('Replace CrossAttention.forward to use xformers') - try: - import xformers.ops - except ImportError: - raise ImportError('No xformers / xformersがインストールされていないようです') + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + context = default(context, x) + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) - q, k, v = map( - lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), - (q_in, k_in, v_in), - ) - del q_in, k_in, v_in + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None - ) # 最適なのを選んでくれる + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + diffusers.models.attention.CrossAttention.forward = forward_xformers # endregion # region arguments - def add_sd_models_arguments(parser: argparse.ArgumentParser): - # for pretrained models - parser.add_argument( - '--v2', - action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む', - ) - parser.add_argument( - '--v_parameterization', - action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする', - ) - parser.add_argument( - '--pretrained_model_name_or_path', - type=str, - default=None, - help='pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル', - ) + # for pretrained models + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + parser.add_argument("--v_parameterization", action='store_true', + help='enable v-parameterization training / v-parameterization学習を有効にする') + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") + parser.add_argument("--tokenizer_cache_dir", type=str, default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)") def add_optimizer_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - '--optimizer_type', - type=str, - default='', - help='Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor', - ) + parser.add_argument("--optimizer_type", type=str, default="", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor") - # backward compatibility - parser.add_argument( - '--use_8bit_adam', - action='store_true', - help='use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)', - ) - parser.add_argument( - '--use_lion_optimizer', - action='store_true', - help='use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)', - ) + # backward compatibility + parser.add_argument("--use_8bit_adam", action="store_true", + help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + parser.add_argument("--use_lion_optimizer", action="store_true", + help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)") - parser.add_argument( - '--learning_rate', - type=float, - default=2.0e-6, - help='learning rate / 学習率', - ) - parser.add_argument( - '--max_grad_norm', - default=1.0, - type=float, - help='Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない', - ) + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない") - parser.add_argument( - '--optimizer_args', - type=str, - default=None, - nargs='*', - help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', - ) + parser.add_argument("--optimizer_args", type=str, default=None, nargs='*', + help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")") - parser.add_argument( - '--lr_scheduler', - type=str, - default='constant', - help='scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor', - ) - parser.add_argument( - '--lr_warmup_steps', - type=int, - default=0, - help='Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)', - ) - parser.add_argument( - '--lr_scheduler_num_cycles', - type=int, - default=1, - help='Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数', - ) - parser.add_argument( - '--lr_scheduler_power', - type=float, - default=1, - help='Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power', - ) + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数") + parser.add_argument("--lr_scheduler_power", type=float, default=1, + help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power") -def add_training_arguments( - parser: argparse.ArgumentParser, support_dreambooth: bool -): - parser.add_argument( - '--output_dir', - type=str, - default=None, - help='directory to output trained model / 学習後のモデル出力先ディレクトリ', - ) - parser.add_argument( - '--output_name', - type=str, - default=None, - help='base name of trained model file / 学習後のモデルの拡張子を除くファイル名', - ) - parser.add_argument( - '--save_precision', - type=str, - default=None, - choices=[None, 'float', 'fp16', 'bf16'], - help='precision in saving / 保存時に精度を変更して保存する', - ) - parser.add_argument( - '--save_every_n_epochs', - type=int, - default=None, - help='save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する', - ) - parser.add_argument( - '--save_n_epoch_ratio', - type=int, - default=None, - help='save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)', - ) - parser.add_argument( - '--save_last_n_epochs', - type=int, - default=None, - help='save last N checkpoints / 最大Nエポック保存する', - ) - parser.add_argument( - '--save_last_n_epochs_state', - type=int, - default=None, - help='save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)', - ) - parser.add_argument( - '--save_state', - action='store_true', - help='save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する', - ) - parser.add_argument( - '--resume', - type=str, - default=None, - help='saved state to resume training / 学習再開するモデルのstate', - ) +def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): + parser.add_argument("--output_dir", type=str, default=None, + help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--output_name", type=str, default=None, + help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") + parser.add_argument("--save_every_n_epochs", type=int, default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_n_epoch_ratio", type=int, default=None, + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)") + parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") + parser.add_argument("--save_last_n_epochs_state", type=int, default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") + parser.add_argument("--save_state", action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument( - '--train_batch_size', - type=int, - default=1, - help='batch size for training / 学習時のバッチサイズ', - ) - parser.add_argument( - '--max_token_length', - type=int, - default=None, - choices=[None, 150, 225], - help='max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)', - ) - parser.add_argument( - '--mem_eff_attn', - action='store_true', - help='use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う', - ) - parser.add_argument( - '--xformers', - action='store_true', - help='use xformers for CrossAttention / CrossAttentionにxformersを使う', - ) - parser.add_argument( - '--vae', - type=str, - default=None, - help='path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ', - ) + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") + parser.add_argument("--mem_eff_attn", action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") + parser.add_argument("--xformers", action="store_true", + help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--vae", type=str, default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument( - '--max_train_steps', - type=int, - default=1600, - help='training steps / 学習ステップ数', - ) - parser.add_argument( - '--max_train_epochs', - type=int, - default=None, - help='training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)', - ) - parser.add_argument( - '--max_data_loader_n_workers', - type=int, - default=8, - help='max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)', - ) - parser.add_argument( - '--persistent_data_loader_workers', - action='store_true', - help='persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)', - ) - parser.add_argument( - '--seed', - type=int, - default=None, - help='random seed for training / 学習時の乱数のseed', - ) - parser.add_argument( - '--gradient_checkpointing', - action='store_true', - help='enable gradient checkpointing / grandient checkpointingを有効にする', - ) - parser.add_argument( - '--gradient_accumulation_steps', - type=int, - default=1, - help='Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数', - ) - parser.add_argument( - '--mixed_precision', - type=str, - default='no', - choices=['no', 'fp16', 'bf16'], - help='use mixed precision / 混合精度を使う場合、その精度', - ) - parser.add_argument( - '--full_fp16', - action='store_true', - help='fp16 training including gradients / 勾配も含めてfp16で学習する', - ) - parser.add_argument( - '--clip_skip', - type=int, - default=None, - help='use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)', - ) - parser.add_argument( - '--logging_dir', - type=str, - default=None, - help='enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する', - ) - parser.add_argument( - '--log_prefix', - type=str, - default=None, - help='add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列', - ) - parser.add_argument( - '--noise_offset', - type=float, - default=None, - help='enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)', - ) - parser.add_argument( - '--lowram', - action='store_true', - help='enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)', - ) + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument("--max_train_epochs", type=int, default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") + parser.add_argument("--max_data_loader_n_workers", type=int, default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") + parser.add_argument("--persistent_data_loader_workers", action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)") + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument("--gradient_checkpointing", action="store_true", + help="enable gradient checkpointing / grandient checkpointingを有効にする") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") + parser.add_argument("--mixed_precision", type=str, default="no", + choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--clip_skip", type=int, default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") + parser.add_argument("--logging_dir", type=str, default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument("--noise_offset", type=float, default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") + parser.add_argument("--lowram", action="store_true", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)") - if support_dreambooth: - # DreamBooth training - parser.add_argument( - '--prior_loss_weight', - type=float, - default=1.0, - help='loss weight for regularization images / 正則化画像のlossの重み', - ) + parser.add_argument("--sample_every_n_steps", type=int, default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する") + parser.add_argument("--sample_every_n_epochs", type=int, default=None, + help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)") + parser.add_argument("--sample_prompts", type=str, default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル") + parser.add_argument('--sample_sampler', type=str, default='ddim', + choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver', + 'dpmsolver++', 'dpmsingle', + 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'], + help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類') + + if support_dreambooth: + # DreamBooth training + parser.add_argument("--prior_loss_weight", type=float, default=1.0, + help="loss weight for regularization images / 正則化画像のlossの重み") def verify_training_args(args: argparse.Namespace): - if args.v_parameterization and not args.v2: - print( - 'v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません' - ) - if args.v2 and args.clip_skip is not None: - print( - 'v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません' - ) + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") -def add_dataset_arguments( - parser: argparse.ArgumentParser, - support_dreambooth: bool, - support_caption: bool, - support_caption_dropout: bool, -): - # dataset common - parser.add_argument( - '--train_data_dir', - type=str, - default=None, - help='directory for train images / 学習画像データのディレクトリ', - ) - parser.add_argument( - '--shuffle_caption', - action='store_true', - help='shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする', - ) - parser.add_argument( - '--caption_extension', - type=str, - default='.caption', - help='extension of caption files / 読み込むcaptionファイルの拡張子', - ) - parser.add_argument( - '--caption_extention', - type=str, - default=None, - help='extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)', - ) - parser.add_argument( - '--keep_tokens', - type=int, - default=None, - help='keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す', - ) - parser.add_argument( - '--color_aug', - action='store_true', - help='enable weak color augmentation / 学習時に色合いのaugmentationを有効にする', - ) - parser.add_argument( - '--flip_aug', - action='store_true', - help='enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする', - ) - parser.add_argument( - '--face_crop_aug_range', - type=str, - default=None, - help='enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)', - ) - parser.add_argument( - '--random_crop', - action='store_true', - help='enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)', - ) - parser.add_argument( - '--debug_dataset', - action='store_true', - help='show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)', - ) - parser.add_argument( - '--resolution', - type=str, - default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)", - ) - parser.add_argument( - '--cache_latents', - action='store_true', - help='cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)', - ) - parser.add_argument( - '--enable_bucket', - action='store_true', - help='enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする', - ) - parser.add_argument( - '--min_bucket_reso', - type=int, - default=256, - help='minimum resolution for buckets / bucketの最小解像度', - ) - parser.add_argument( - '--max_bucket_reso', - type=int, - default=1024, - help='maximum resolution for buckets / bucketの最大解像度', - ) - parser.add_argument( - '--bucket_reso_steps', - type=int, - default=64, - help='steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します', - ) - parser.add_argument( - '--bucket_no_upscale', - action='store_true', - help='make bucket for each image without upscaling / 画像を拡大せずbucketを作成します', - ) +def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool): + # dataset common + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--shuffle_caption", action="store_true", + help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") + parser.add_argument("--keep_tokens", type=int, default=0, + help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)") + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument("--face_crop_aug_range", type=str, default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") + parser.add_argument("--random_crop", action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") + parser.add_argument("--debug_dataset", action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") + parser.add_argument("--resolution", type=str, default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") + parser.add_argument("--cache_latents", action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") + parser.add_argument("--enable_bucket", action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument("--bucket_reso_steps", type=int, default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") + parser.add_argument("--bucket_no_upscale", action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") - if support_caption_dropout: - # Textual Inversion はcaptionのdropoutをsupportしない - # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに - parser.add_argument( - '--caption_dropout_rate', - type=float, - default=0, - help='Rate out dropout caption(0.0~1.0) / captionをdropoutする割合', - ) - parser.add_argument( - '--caption_dropout_every_n_epochs', - type=int, - default=None, - help='Dropout all captions every N epochs / captionを指定エポックごとにdropoutする', - ) - parser.add_argument( - '--caption_tag_dropout_rate', - type=float, - default=0, - help='Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合', - ) + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument("--caption_dropout_rate", type=float, default=0.0, + help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") + parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") + parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0, + help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合") - if support_dreambooth: - # DreamBooth dataset - parser.add_argument( - '--reg_data_dir', - type=str, - default=None, - help='directory for regularization images / 正則化画像データのディレクトリ', - ) + if support_dreambooth: + # DreamBooth dataset + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - if support_caption: - # caption dataset - parser.add_argument( - '--in_json', - type=str, - default=None, - help='json metadata for dataset / データセットのmetadataのjsonファイル', - ) - parser.add_argument( - '--dataset_repeats', - type=int, - default=1, - help='repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数', - ) + if support_caption: + # caption dataset + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument("--dataset_repeats", type=int, default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") def add_sd_saving_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - '--save_model_as', - type=str, - default=None, - choices=[ - None, - 'ckpt', - 'safetensors', - 'diffusers', - 'diffusers_safetensors', - ], - help='format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)', - ) - parser.add_argument( - '--use_safetensors', - action='store_true', - help='use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)', - ) - + parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") + parser.add_argument("--use_safetensors", action='store_true', + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") # endregion @@ -2255,230 +1664,171 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" - optimizer_type = args.optimizer_type - if args.use_8bit_adam: - assert ( - not args.use_lion_optimizer - ), 'both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています' - assert ( - optimizer_type is None or optimizer_type == '' - ), 'both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています' - optimizer_type = 'AdamW8bit' + optimizer_type = args.optimizer_type + if args.use_8bit_adam: + assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています" + assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "AdamW8bit" - elif args.use_lion_optimizer: - assert ( - optimizer_type is None or optimizer_type == '' - ), 'both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています' - optimizer_type = 'Lion' + elif args.use_lion_optimizer: + assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "Lion" - if optimizer_type is None or optimizer_type == '': - optimizer_type = 'AdamW' - optimizer_type = optimizer_type.lower() + if optimizer_type is None or optimizer_type == "": + optimizer_type = "AdamW" + optimizer_type = optimizer_type.lower() - # 引数を分解する:boolとfloat、tupleのみ対応 - optimizer_kwargs = {} - if args.optimizer_args is not None and len(args.optimizer_args) > 0: - for arg in args.optimizer_args: - key, value = arg.split('=') + # 引数を分解する:boolとfloat、tupleのみ対応 + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split('=') - value = value.split(',') - for i in range(len(value)): - if value[i].lower() == 'true' or value[i].lower() == 'false': - value[i] = value[i].lower() == 'true' - else: - value[i] = float(value[i]) - if len(value) == 1: - value = value[0] - else: - value = tuple(value) - - optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) - - lr = args.learning_rate - - if optimizer_type == 'AdamW8bit'.lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - 'No bitsand bytes / bitsandbytesがインストールされていないようです' - ) - print(f'use 8-bit AdamW optimizer | {optimizer_kwargs}') - optimizer_class = bnb.optim.AdamW8bit - optimizer = optimizer_class( - trainable_params, lr=lr, **optimizer_kwargs - ) - - elif optimizer_type == 'SGDNesterov8bit'.lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - 'No bitsand bytes / bitsandbytesがインストールされていないようです' - ) - print(f'use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}') - if 'momentum' not in optimizer_kwargs: - print( - f'8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します' - ) - optimizer_kwargs['momentum'] = 0.9 - - optimizer_class = bnb.optim.SGD8bit - optimizer = optimizer_class( - trainable_params, lr=lr, nesterov=True, **optimizer_kwargs - ) - - elif optimizer_type == 'Lion'.lower(): - try: - import lion_pytorch - except ImportError: - raise ImportError( - 'No lion_pytorch / lion_pytorch がインストールされていないようです' - ) - print(f'use Lion optimizer | {optimizer_kwargs}') - optimizer_class = lion_pytorch.Lion - optimizer = optimizer_class( - trainable_params, lr=lr, **optimizer_kwargs - ) - - elif optimizer_type == 'SGDNesterov'.lower(): - print(f'use SGD with Nesterov optimizer | {optimizer_kwargs}') - if 'momentum' not in optimizer_kwargs: - print( - f'SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します' - ) - optimizer_kwargs['momentum'] = 0.9 - - optimizer_class = torch.optim.SGD - optimizer = optimizer_class( - trainable_params, lr=lr, nesterov=True, **optimizer_kwargs - ) - - elif optimizer_type == 'DAdaptation'.lower(): - try: - import dadaptation - except ImportError: - raise ImportError('No dadaptation / dadaptation がインストールされていないようです') - print(f'use D-Adaptation Adam optimizer | {optimizer_kwargs}') - - min_lr = lr - if ( - type(trainable_params) == list - and type(trainable_params[0]) == dict - ): - for group in trainable_params: - min_lr = min(min_lr, group.get('lr', lr)) - - if min_lr <= 0.1: - print( - f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}' - ) - print('recommend option: lr=1.0 / 推奨は1.0です') - - optimizer_class = dadaptation.DAdaptAdam - optimizer = optimizer_class( - trainable_params, lr=lr, **optimizer_kwargs - ) - - elif optimizer_type == 'Adafactor'.lower(): - # 引数を確認して適宜補正する - if 'relative_step' not in optimizer_kwargs: - optimizer_kwargs['relative_step'] = True # default - if not optimizer_kwargs['relative_step'] and optimizer_kwargs.get( - 'warmup_init', False - ): - print( - f'set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします' - ) - optimizer_kwargs['relative_step'] = True - print(f'use Adafactor optimizer | {optimizer_kwargs}') - - if optimizer_kwargs['relative_step']: - print(f'relative_step is true / relative_stepがtrueです') - if lr != 0.0: - print( - f'learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます' - ) - args.learning_rate = None - - # trainable_paramsがgroupだった時の処理:lrを削除する - if ( - type(trainable_params) == list - and type(trainable_params[0]) == dict - ): - has_group_lr = False - for group in trainable_params: - p = group.pop('lr', None) - has_group_lr = has_group_lr or (p is not None) - - if has_group_lr: - # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print( - f'unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます' - ) - args.unet_lr = None - args.text_encoder_lr = None - - if args.lr_scheduler != 'adafactor': - print( - f'use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します' - ) - args.lr_scheduler = f'adafactor:{lr}' # ちょっと微妙だけど - - lr = None + value = value.split(",") + for i in range(len(value)): + if value[i].lower() == "true" or value[i].lower() == "false": + value[i] = (value[i].lower() == "true") else: - if args.max_grad_norm != 0.0: - print( - f'because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません' - ) - if args.lr_scheduler != 'constant_with_warmup': - print( - f'constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません' - ) - if optimizer_kwargs.get('clip_threshold', 1.0) != 1.0: - print( - f'clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません' - ) + value[i] = float(value[i]) + if len(value) == 1: + value = value[0] + else: + value = tuple(value) - optimizer_class = transformers.optimization.Adafactor - optimizer = optimizer_class( - trainable_params, lr=lr, **optimizer_kwargs - ) + optimizer_kwargs[key] = value + # print("optkwargs:", optimizer_kwargs) - elif optimizer_type == 'AdamW'.lower(): - print(f'use AdamW optimizer | {optimizer_kwargs}') - optimizer_class = torch.optim.AdamW - optimizer = optimizer_class( - trainable_params, lr=lr, **optimizer_kwargs - ) + lr = args.learning_rate + if optimizer_type == "AdamW8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion".lower(): + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print(f"use Lion optimizer | {optimizer_kwargs}") + optimizer_class = lion_pytorch.Lion + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov".lower(): + print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = torch.optim.SGD + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "DAdaptation".lower(): + try: + import dadaptation + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + + min_lr = lr + if type(trainable_params) == list and type(trainable_params[0]) == dict: + for group in trainable_params: + min_lr = min(min_lr, group.get("lr", lr)) + + if min_lr <= 0.1: + print( + f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}') + print('recommend option: lr=1.0 / 推奨は1.0です') + + optimizer_class = dadaptation.DAdaptAdam + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # 引数を確認して適宜補正する + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + optimizer_kwargs["relative_step"] = True + print(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + print(f"relative_step is true / relative_stepがtrueです") + if lr != 0.0: + print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + args.learning_rate = None + + # trainable_paramsがgroupだった時の処理:lrを削除する + if type(trainable_params) == list and type(trainable_params[0]) == dict: + has_group_lr = False + for group in trainable_params: + p = group.pop("lr", None) + has_group_lr = has_group_lr or (p is not None) + + if has_group_lr: + # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない + print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + args.unet_lr = None + args.text_encoder_lr = None + + if args.lr_scheduler != "adafactor": + print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど + + lr = None else: - # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f'use {optimizer_type} | {optimizer_kwargs}') - if '.' not in optimizer_type: - optimizer_module = torch.optim - else: - values = optimizer_type.split('.') - optimizer_module = importlib.import_module('.'.join(values[:-1])) - optimizer_type = values[-1] + if args.max_grad_norm != 0.0: + print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません") + if args.lr_scheduler != "constant_with_warmup": + print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") - optimizer_class = getattr(optimizer_module, optimizer_type) - optimizer = optimizer_class( - trainable_params, lr=lr, **optimizer_kwargs - ) + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - optimizer_name = ( - optimizer_class.__module__ + '.' + optimizer_class.__name__ - ) - optimizer_args = ','.join( - [f'{k}={v}' for k, v in optimizer_kwargs.items()] - ) + elif optimizer_type == "AdamW".lower(): + print(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - return optimizer_name, optimizer_args, optimizer + else: + # 任意のoptimizerを使う + optimizer_type = args.optimizer_type # lowerでないやつ(微妙) + print(f"use {optimizer_type} | {optimizer_kwargs}") + if "." not in optimizer_type: + optimizer_module = torch.optim + else: + values = optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ + optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + + return optimizer_name, optimizer_args, optimizer # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler @@ -2495,560 +1845,545 @@ def get_scheduler_fix( num_cycles: int = 1, power: float = 1.0, ): - """ - Unified API to get any scheduler from its name. - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_cycles (`int`, *optional*): - The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. - power (`float`, *optional*, defaults to 1.0): - Power factor. See `POLYNOMIAL` scheduler - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - """ - if name.startswith('adafactor'): - assert ( - type(optimizer) == transformers.optimization.Adafactor - ), f'adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください' - initial_lr = float(name.split(':')[1]) - # print("adafactor scheduler init lr", initial_lr) - return transformers.optimization.AdafactorSchedule( - optimizer, initial_lr - ) + """ + Unified API to get any scheduler from its name. + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_cycles (`int`, *optional*): + The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. + power (`float`, *optional*, defaults to 1.0): + Power factor. See `POLYNOMIAL` scheduler + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + """ + if name.startswith("adafactor"): + assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" + initial_lr = float(name.split(':')[1]) + # print("adafactor scheduler init lr", initial_lr) + return transformers.optimization.AdafactorSchedule(optimizer, initial_lr) - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError( - f'{name} requires `num_warmup_steps`, please provide that argument.' - ) + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError( - f'{name} requires `num_training_steps`, please provide that argument.' - ) - - if name == SchedulerType.COSINE_WITH_RESTARTS: - return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - num_cycles=num_cycles, - ) - - if name == SchedulerType.POLYNOMIAL: - return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - power=power, - ) + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + if name == SchedulerType.COSINE_WITH_RESTARTS: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles ) + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power + ) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): - # backward compatibility - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - args.caption_extention = None + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None - if args.cache_latents: - assert ( - not args.color_aug - ), 'when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません' - assert ( - not args.random_crop - ), 'when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません' + # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" + if args.resolution is not None: + args.resolution = tuple([int(r) for r in args.resolution.split(',')]) + if len(args.resolution) == 1: + args.resolution = (args.resolution[0], args.resolution[0]) + assert len(args.resolution) == 2, \ + f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" - if args.resolution is not None: - args.resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(args.resolution) == 1: - args.resolution = (args.resolution[0], args.resolution[0]) - assert ( - len(args.resolution) == 2 - ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + if args.face_crop_aug_range is not None: + args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) + assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \ + f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + args.face_crop_aug_range = None - if args.face_crop_aug_range is not None: - args.face_crop_aug_range = tuple( - [float(r) for r in args.face_crop_aug_range.split(',')] - ) - assert ( - len(args.face_crop_aug_range) == 2 - and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1] - ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - args.face_crop_aug_range = None - - if support_metadata: - if args.in_json is not None and (args.color_aug or args.random_crop): - print( - f'latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます' - ) + if support_metadata: + if args.in_json is not None and (args.color_aug or args.random_crop): + print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます") def load_tokenizer(args: argparse.Namespace): - print('prepare tokenizer') + print("prepare tokenizer") + original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH + + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_')) + if os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: if args.v2: - tokenizer = CLIPTokenizer.from_pretrained( - V2_STABLE_DIFFUSION_PATH, subfolder='tokenizer' - ) + tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) - if args.max_token_length is not None: - print(f'update token length: {args.max_token_length}') - return tokenizer + tokenizer = CLIPTokenizer.from_pretrained(original_path) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer def prepare_accelerator(args: argparse.Namespace): - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = 'tensorboard' - log_prefix = '' if args.log_prefix is None else args.log_prefix - logging_dir = ( - args.logging_dir - + '/' - + log_prefix - + time.strftime('%Y%m%d%H%M%S', time.localtime()) - ) + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=log_with, - logging_dir=logging_dir, - ) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, + log_with=log_with, logging_dir=logging_dir) - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model('dummy', True) - print('Using accelerator 0.15.0 or above.') - except TypeError: - accelerator_0_15 = False + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) - return accelerator, unwrap_model + return accelerator, unwrap_model def prepare_dtype(args: argparse.Namespace): - weight_dtype = torch.float32 - if args.mixed_precision == 'fp16': - weight_dtype = torch.float16 - elif args.mixed_precision == 'bf16': - weight_dtype = torch.bfloat16 + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 - save_dtype = None - if args.save_precision == 'fp16': - save_dtype = torch.float16 - elif args.save_precision == 'bf16': - save_dtype = torch.bfloat16 - elif args.save_precision == 'float': - save_dtype = torch.float32 + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 - return weight_dtype, save_dtype + return weight_dtype, save_dtype def load_target_model(args: argparse.Namespace, weight_dtype): - name_or_path = args.pretrained_model_name_or_path - name_or_path = ( - os.readlink(name_or_path) - if os.path.islink(name_or_path) - else name_or_path - ) - load_stable_diffusion_format = os.path.isfile( - name_or_path - ) # determine SD or Diffusers - if load_stable_diffusion_format: - print('load StableDiffusion checkpoint') - ( - text_encoder, - vae, - unet, - ) = model_util.load_models_from_stable_diffusion_checkpoint( - args.v2, name_or_path - ) - else: - print('load Diffusers pretrained models') - try: - pipe = StableDiffusionPipeline.from_pretrained( - name_or_path, tokenizer=None, safety_checker=None - ) - except EnvironmentError as ex: - print( - f'model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}' - ) - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe + name_or_path = args.pretrained_model_name_or_path + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) + else: + print("load Diffusers pretrained models") + try: + pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) + except EnvironmentError as ex: + print( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}") + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print('additional VAE loaded') + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") - return text_encoder, vae, unet, load_stable_diffusion_format + return text_encoder, vae, unet, load_stable_diffusion_format def patch_accelerator_for_fp16_training(accelerator): - org_unscale_grads = accelerator.scaler._unscale_grads_ + org_unscale_grads = accelerator.scaler._unscale_grads_ - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer -def get_hidden_states( - args: argparse.Namespace, - input_ids, - tokenizer, - text_encoder, - weight_dtype=None, -): - # with no_token_padding, the length is not max length, return result immediately - if input_ids.size()[-1] != tokenizer.model_max_length: - return text_encoder(input_ids)[0] +def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] - b_size = input_ids.size()[0] - input_ids = input_ids.reshape( - (-1, tokenizer.model_max_length) - ) # batch_size*3, 77 + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) else: - enc_out = text_encoder( - input_ids, output_hidden_states=True, return_dict=True - ) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm( - encoder_hidden_states - ) + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape( - (b_size, -1, encoder_hidden_states.shape[-1]) - ) + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [ - encoder_hidden_states[:, 0].unsqueeze(1) - ] # - for i in range( - 1, args.max_token_length, tokenizer.model_max_length - ): - chunk = encoder_hidden_states[ - :, i : i + tokenizer.model_max_length - 2 - ] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if ( - input_ids[j, 1] == tokenizer.eos_token - ): # 空、つまり ...のパターン - chunk[j, 0] = chunk[ - j, 1 - ] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append( - encoder_hidden_states[:, -1].unsqueeze(1) - ) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... へ戻す - states_list = [ - encoder_hidden_states[:, 0].unsqueeze(1) - ] # - for i in range( - 1, args.max_token_length, tokenizer.model_max_length - ): - states_list.append( - encoder_hidden_states[ - :, i : i + tokenizer.model_max_length - 2 - ] - ) # の後から の前まで - states_list.append( - encoder_hidden_states[:, -1].unsqueeze(1) - ) # - encoder_hidden_states = torch.cat(states_list, dim=1) - - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) - - return encoder_hidden_states + return encoder_hidden_states def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): - model_name = ( - DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name - ) - ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + ( - '.safetensors' if use_safetensors else '.ckpt' - ) - return model_name, ckpt_name + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name -def save_on_epoch_end( - args: argparse.Namespace, - save_func, - remove_old_func, - epoch_no: int, - num_train_epochs: int, -): - saving = ( - epoch_no % args.save_every_n_epochs == 0 - and epoch_no < num_train_epochs - ) - if saving: - os.makedirs(args.output_dir, exist_ok=True) - save_func() +def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + if saving: + os.makedirs(args.output_dir, exist_ok=True) + save_func() - if args.save_last_n_epochs is not None: - remove_epoch_no = ( - epoch_no - args.save_every_n_epochs * args.save_last_n_epochs - ) - remove_old_func(remove_epoch_no) - return saving + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving -def save_sd_model_on_epoch_end( - args: argparse.Namespace, - accelerator, - src_path: str, - save_stable_diffusion_format: bool, - use_safetensors: bool, - save_dtype: torch.dtype, - epoch: int, - num_train_epochs: int, - global_step: int, - text_encoder, - unet, - vae, -): - epoch_no = epoch + 1 - model_name, ckpt_name = get_epoch_ckpt_name( - args, use_safetensors, epoch_no - ) +def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) - if save_stable_diffusion_format: + if save_stable_diffusion_format: + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch_no, global_step, save_dtype, vae) - def save_sd(): - ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f'saving checkpoint: {ckpt_file}') - model_util.save_stable_diffusion_checkpoint( - args.v2, - ckpt_file, - text_encoder, - unet, - src_path, - epoch_no, - global_step, - save_dtype, - vae, - ) + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) - def remove_sd(old_epoch_no): - _, old_ckpt_name = get_epoch_ckpt_name( - args, use_safetensors, old_epoch_no - ) - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f'removing old checkpoint: {old_ckpt_file}') - os.remove(old_ckpt_file) + save_func = save_sd + remove_old_func = remove_sd + else: + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) - save_func = save_sd - remove_old_func = remove_sd - else: + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") + shutil.rmtree(out_dir_old) - def save_du(): - out_dir = os.path.join( - args.output_dir, - EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no), - ) - print(f'saving model: {out_dir}') - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint( - args.v2, - out_dir, - text_encoder, - unet, - src_path, - vae=vae, - use_safetensors=use_safetensors, - ) + save_func = save_du + remove_old_func = remove_du - def remove_du(old_epoch_no): - out_dir_old = os.path.join( - args.output_dir, - EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no), - ) - if os.path.exists(out_dir_old): - print(f'removing old model: {out_dir_old}') - shutil.rmtree(out_dir_old) - - save_func = save_du - remove_old_func = remove_du - - saving = save_on_epoch_end( - args, save_func, remove_old_func, epoch_no, num_train_epochs - ) - if saving and args.save_state: - save_state_on_epoch_end(args, accelerator, model_name, epoch_no) + saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if saving and args.save_state: + save_state_on_epoch_end(args, accelerator, model_name, epoch_no) -def save_state_on_epoch_end( - args: argparse.Namespace, accelerator, model_name, epoch_no -): - print('saving state.') - accelerator.save_state( - os.path.join( - args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no) - ) - ) +def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) - last_n_epochs = ( - args.save_last_n_epochs_state - if args.save_last_n_epochs_state - else args.save_last_n_epochs - ) - if last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs - state_dir_old = os.path.join( - args.output_dir, - EPOCH_STATE_NAME.format(model_name, remove_epoch_no), - ) - if os.path.exists(state_dir_old): - print(f'removing old state: {state_dir_old}') - shutil.rmtree(state_dir_old) + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) -def save_sd_model_on_train_end( - args: argparse.Namespace, - src_path: str, - save_stable_diffusion_format: bool, - use_safetensors: bool, - save_dtype: torch.dtype, - epoch: int, - global_step: int, - text_encoder, - unet, - vae, -): - model_name = ( - DEFAULT_LAST_OUTPUT_NAME - if args.output_name is None - else args.output_name - ) +def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - if save_stable_diffusion_format: - os.makedirs(args.output_dir, exist_ok=True) + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) - ckpt_name = model_name + ( - '.safetensors' if use_safetensors else '.ckpt' - ) - ckpt_file = os.path.join(args.output_dir, ckpt_name) + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) - print( - f'save trained model as StableDiffusion checkpoint to {ckpt_file}' - ) - model_util.save_stable_diffusion_checkpoint( - args.v2, - ckpt_file, - text_encoder, - unet, - src_path, - epoch, - global_step, - save_dtype, - vae, - ) - else: - out_dir = os.path.join(args.output_dir, model_name) - os.makedirs(out_dir, exist_ok=True) + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch, global_step, save_dtype, vae) + else: + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) - print(f'save trained model as Diffusers to {out_dir}') - model_util.save_diffusers_checkpoint( - args.v2, - out_dir, - text_encoder, - unet, - src_path, - vae=vae, - use_safetensors=use_safetensors, - ) + print(f"save trained model as Diffusers to {out_dir}") + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) def save_state_on_train_end(args: argparse.Namespace, accelerator): - print('saving last state.') - os.makedirs(args.output_dir, exist_ok=True) - model_name = ( - DEFAULT_LAST_OUTPUT_NAME - if args.output_name is None - else args.output_name - ) - accelerator.save_state( - os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) - ) + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = 'scaled_linear' + + +def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None): + """ + 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない + clip skipは対応した + """ + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0: + return + + print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + # ここでCUDAのキャッシュクリアとかしたほうがいいのか…… + + org_vae_device = vae.device # CPUにいるはず + vae.to(device) + + # clip skip 対応のための wrapper を作る + if args.clip_skip is None: + text_encoder_or_wrapper = text_encoder + else: + class Wrapper(): + def __init__(self, tenc) -> None: + self.tenc = tenc + self.config = {} + super().__init__() + + def __call__(self, input_ids, attention_mask): + enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states) + pooled_output = enc_out['pooler_output'] + return encoder_hidden_states, pooled_output # 1st output is only used + + text_encoder_or_wrapper = Wrapper(text_encoder) + + # read prompts + with open(args.sample_prompts, 'rt', encoding='utf-8') as f: + prompts = f.readlines() + + # schedulerを用意する + sched_init_args = {} + if args.sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif args.sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms': + scheduler_cls = LMSDiscreteScheduler + elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler': + scheduler_cls = EulerDiscreteScheduler + elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a': + scheduler_cls = EulerAncestralDiscreteScheduler + elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args['algorithm_type'] = args.sample_sampler + elif args.sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif args.sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2': + scheduler_cls = KDPM2DiscreteScheduler + elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a': + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler + + if args.v_parameterization: + sched_init_args['prediction_type'] = 'v_prediction' + + scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args) + + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + scheduler.config.clip_sample = True + + pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer, + scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False) + pipeline.to(device) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() + + with torch.no_grad(): + with accelerator.autocast(): + for i, prompt in enumerate(prompts): + prompt = prompt.strip() + if len(prompt) == 0 or prompt[0] == '#': + continue + + # subset of gen_img_diffusers + prompt_args = prompt.split(' --') + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + for parg in prompt_args: + try: + m = re.match(r'w (\d+)', parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue + + m = re.match(r'h (\d+)', parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue + + m = re.match(r'd (\d+)', parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue + + m = re.match(r's (\d+)', parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue + + m = re.match(r'n (.+)', parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] + + ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + + image.save(os.path.join(save_dir, img_filename)) + + torch.set_rng_state(rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + vae.to(org_vae_device) + # endregion # region 前処理用 class ImageLoadingDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] + def __getitem__(self, idx): + img_path = self.images[idx] - try: - image = Image.open(img_path).convert('RGB') - # convert to tensor temporarily so dataloader will accept it - tensor_pil = transforms.functional.pil_to_tensor(image) - except Exception as e: - print( - f'Could not load image path / 画像を読み込めません: {img_path}, error: {e}' - ) - return None + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor_pil = transforms.functional.pil_to_tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None - return (tensor_pil, img_path) + return (tensor_pil, img_path) # endregion diff --git a/lora_gui.py b/lora_gui.py index d475040..c951144 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -40,6 +40,7 @@ from library.utilities import utilities_tab from library.merge_lora_gui import gradio_merge_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -112,9 +113,13 @@ def save_configuration( optimizer, optimizer_args, noise_offset, - LoRA_type='Standard', - conv_dim=0, - conv_alpha=0, + LoRA_type, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -223,9 +228,13 @@ def open_configuration( optimizer, optimizer_args, noise_offset, - LoRA_type='Standard', - conv_dim=0, - conv_alpha=0, + LoRA_type, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -323,6 +332,10 @@ def train_model( LoRA_type, conv_dim, conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -544,8 +557,15 @@ def train_model( noise_offset=noise_offset, ) + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) + print(run_cmd) - + # Run the command if os.name == 'posix': os.system(run_cmd) @@ -826,11 +846,12 @@ def lora_tab( outputs=[cache_latents], ) - # optimizer.change( - # set_legacy_8bitadam, - # inputs=[optimizer, use_8bit_adam], - # outputs=[optimizer, use_8bit_adam], - # ) + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() with gr.Tab('Tools'): gr.Markdown( @@ -927,6 +948,10 @@ def lora_tab( LoRA_type, conv_dim, conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ] button_open_config.click( diff --git a/requirements.txt b/requirements.txt index 4277d7d..27bde40 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,9 @@ pytorch-lightning==1.9.0 safetensors==0.2.6 tensorboard==2.10.1 tk==0.1.0 +toml==0.10.2 transformers==4.26.0 +voluptuous==0.13.1 # for BLIP captioning fairscale==0.4.13 requests==2.28.2 diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 089ac0a..62fe8b6 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -36,6 +36,7 @@ from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -106,6 +107,10 @@ def save_configuration( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -213,6 +218,10 @@ def open_configuration( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -299,6 +308,10 @@ def train_model( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -496,8 +509,15 @@ def train_model( elif template == 'style template': run_cmd += f' --use_style_template' + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) + print(run_cmd) - + # Run the command if os.name == 'posix': os.system(run_cmd) @@ -740,11 +760,14 @@ def ti_tab( inputs=[color_aug], outputs=[cache_latents], ) - # optimizer.change( - # set_legacy_8bitadam, - # inputs=[optimizer, use_8bit_adam], - # outputs=[optimizer, use_8bit_adam], - # ) + + ( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + ) = sample_gradio_config() + with gr.Tab('Tools'): gr.Markdown( 'This section provide Dreambooth tools to help setup your dataset...' @@ -832,6 +855,10 @@ def ti_tab( optimizer, optimizer_args, noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, ] button_open_config.click( diff --git a/train_README-ja.md b/train_README-ja.md new file mode 100644 index 0000000..bf0d9f9 --- /dev/null +++ b/train_README-ja.md @@ -0,0 +1,619 @@ +当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やスクリプトオプションについて説明します。 + +# 概要 + +あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 + + +以下について説明します。 + +1. 学習データの準備について(設定ファイルを用いる新形式) +1. Aspect Ratio Bucketingについて +1. 以前の指定形式(設定ファイルを用いずコマンドラインから指定) +1. fine tuning 方式のメタデータ準備:キャプションニングなど + +1.だけ実行すればとりあえず学習は可能です(学習については各スクリプトのドキュメントを参照)。2.以降は必要に応じて参照してください。 + + + +# 学習データの準備について + +任意のフォルダ(複数でも可)に学習データの画像ファイルを用意しておきます。`.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` をサポートします。リサイズなどの前処理は基本的に必要ありません。 + +ただし学習解像度(後述)よりも極端に小さい画像は使わないか、あらかじめ超解像AIなどで拡大しておくことをお勧めします。また極端に大きな画像(3000x3000ピクセル程度?)よりも大きな画像はエラーになる場合があるようですので事前に縮小してください。 + +学習時には、モデルに学ばせる画像データを整理し、スクリプトに対して指定する必要があります。学習データの数、学習対象、キャプション(画像の説明)が用意できるか否かなどにより、いくつかの方法で学習データを指定できます。以下の方式があります(それぞれの名前は一般的なものではなく、当リポジトリ独自の定義です)。正則化画像については後述します。 + +1. DreamBooth、class+identifier方式(正則化画像使用可) + + 特定の単語 (identifier) に学習対象を紐づけるように学習します。キャプションを用意する必要はありません。たとえば特定のキャラを学ばせる場合に使うとキャプションを用意する必要がない分、手軽ですが、髪型や服装、背景など学習データの全要素が identifier に紐づけられて学習されるため、生成時のプロンプトで服が変えられない、といった事態も起こりえます。 + +1. DreamBooth、キャプション方式(正則化画像使用可) + + 画像ごとにキャプションが記録されたテキストファイルを用意して学習します。たとえば特定のキャラを学ばせると、画像の詳細をキャプションに記述することで(白い服を着たキャラA、赤い服を着たキャラA、など)キャラとそれ以外の要素が分離され、より厳密にモデルがキャラだけを学ぶことが期待できます。 + +1. fine tuning方式(正則化画像使用不可) + + あらかじめキャプションをメタデータファイルにまとめます。タグとキャプションを分けて管理したり、学習を高速化するためlatentsを事前キャッシュしたりなどの機能をサポートします(いずれも別文書で説明しています)。 + +学習したいものと使用できる指定方法の組み合わせは以下の通りです。 + +| 学習対象または方法 | スクリプト | DB / class+identifier | DB / キャプション | fine tuning | +| ----- | ----- | ----- | ----- | ----- | +| モデルをfine tuning | `fine_tune.py`| x | x | o | +| モデルをDreamBooth | `train_db.py`| o | o | x | +| LoRA | `train_network.py`| o | o | o | +| Textual Invesion | `train_textual_inversion.py`| o | o | o | + +## どれを選ぶか + +LoRA、Textual Inversionについては、手軽にキャプションファイルを用意せずに学習したい場合はDreamBooth class+identifier、用意できるならDreamBooth キャプション方式がよいでしょう。学習データの枚数が多く、かつ正則化画像を使用しない場合はfine tuning方式も検討してください。 + +DreamBoothについても同様ですが、fine tuning方式は使えません。fine tuningの場合はfine tuning方式のみです。 + +# 各方式の指定方法について + +ここではそれぞれの指定方法で典型的なパターンについてだけ説明します。より詳細な指定方法については [データセット設定](./config_README-ja.md) をご覧ください。 + +# DreamBooth、class+identifier方式(正則化画像使用可) + +この方式では、各画像は `class identifier` というキャプションで学習されたのと同じことになります(`shs dog` など)。 + +## step 1. identifierとclassを決める + +学ばせたい対象を結びつける単語identifierと、対象の属するclassを決めます。 + +(instanceなどいろいろな呼び方がありますが、とりあえず元の論文に合わせます。) + +以下ごく簡単に説明します(詳しくは調べてください)。 + +classは学習対象の一般的な種別です。たとえば特定の犬種を学ばせる場合には、classはdogになります。アニメキャラならモデルによりboyやgirl、1boyや1girlになるでしょう。 + +identifierは学習対象を識別して学習するためのものです。任意の単語で構いませんが、元論文によると「tokinizerで1トークンになる3文字以下でレアな単語」が良いとのことです。 + +identifierとclassを使い、たとえば「shs dog」などでモデルを学習することで、学習させたい対象をclassから識別して学習できます。 + +画像生成時には「shs dog」とすれば学ばせた犬種の画像が生成されます。 + +(identifierとして私が最近使っているものを参考までに挙げると、``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny`` などです。本当は Danbooru Tag に含まれないやつがより望ましいです。) + +## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する + +正則化画像とは、前述のclass全体が、学習対象に引っ張られることを防ぐための画像です(language drift)。正則化画像を使わないと、たとえば `shs 1girl` で特定のキャラクタを学ばせると、単なる `1girl` というプロンプトで生成してもそのキャラに似てきます。これは `1girl` が学習時のキャプションに含まれているためです。 + +学習対象の画像と正則化画像を同時に学ばせることで、class は class のままで留まり、identifier をプロンプトにつけた時だけ学習対象が生成されるようになります。 + +LoRAやDreamBoothで特定のキャラだけ出てくればよい場合は、正則化画像を用いなくても良いといえます。 + +Textual Inversionでは用いなくてよいでしょう(学ばせる token string がキャプションに含まれない場合はなにも学習されないため)。 + +正則化画像としては、学習対象のモデルで、class 名だけで生成した画像を用いるのが一般的です(たとえば `1girl`)。ただし生成画像の品質が悪い場合には、プロンプトを工夫したり、ネットから別途ダウンロードした画像を用いることもできます。 + +(正則化画像も学習されるため、その品質はモデルに影響します。) + +一般的には数百枚程度、用意するのが望ましいようです(枚数が少ないと class 画像が一般化されずそれらの特徴を学んでしまいます)。 + +生成画像を使う場合、通常、生成画像のサイズは学習解像度(より正確にはbucketの解像度、後述)にあわせてください。 + +## step 2. 設定ファイルの記述 + +テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。 + +(`#` で始まっている部分はコメントですので、このままコピペしてそのままでもよいですし、削除しても問題ありません。) + +```toml +[general] +enable_bucket = true # Aspect Ratio Bucketingを使うか否か + +[[datasets]] +resolution = 512 # 学習解像度 +batch_size = 4 # バッチサイズ + + [[datasets.subsets]] + image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定 + class_tokens = 'hoge girl' # identifier class を指定 + num_repeats = 10 # 学習用画像の繰り返し回数 + + # 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定 + class_tokens = 'girl' # class を指定 + num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい +``` + +基本的には以下を場所のみ書き換えれば学習できます。 + +1. 学習解像度 + + 数値1つを指定すると正方形(`512`なら512x512)、鍵カッコカンマ区切りで2つ指定すると横×縦(`[512,768]`なら512x768)になります。SD1.x系ではもともとの学習解像度は512です。`[512,768]` 等の大きめの解像度を指定すると縦長、横長画像生成時の破綻を小さくできるかもしれません。SD2.x 768系では `768` です。 + +1. バッチサイズ + + 同時に何件のデータを学習するかを指定します。GPUのVRAMサイズ、学習解像度によって変わってきます。またfine tuning/DreamBooth/LoRA等でも変わってきますので、詳しくは各スクリプトの説明をご覧ください。 + +1. フォルダ指定 + + 学習用画像、正則化画像(使用する場合のみ)のフォルダを指定します。画像データが含まれているフォルダそのものを指定します。 + +1. identifier と class の指定 + + 前述のサンプルの通りです。 + +1. 繰り返し回数 + + 後述します。 + +### 繰り返し回数について + +繰り返し回数は、正則化画像の枚数と学習用画像の枚数を調整するために用いられます。正則化画像の枚数は学習用画像よりも多いため、学習用画像を繰り返して枚数を合わせ、1対1の比率で学習できるようにします。 + +繰り返し回数は「 __学習用画像の繰り返し回数×学習用画像の枚数≧正則化画像の繰り返し回数×正則化画像の枚数__ 」となるように指定してください。 + +(1 epoch(データが一周すると1 epoch)のデータ数が「学習用画像の繰り返し回数×学習用画像の枚数」となります。正則化画像の枚数がそれより多いと、余った部分の正則化画像は使用されません。) + +## step 3. 学習 + +それぞれのドキュメントを参考に学習を行ってください。 + +# DreamBooth、キャプション方式(正則化画像使用可) + +この方式では各画像はキャプションで学習されます。 + +## step 1. キャプションファイルを準備する + +学習用画像のフォルダに、画像と同じファイル名で、拡張子 `.caption`(設定で変えられます)のファイルを置いてください。それぞれのファイルは1行のみとしてください。エンコーディングは `UTF-8` です。 + +## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する + +class+identifier形式と同様です。なお正則化画像にもキャプションを付けることができますが、通常は不要でしょう。 + +## step 2. 設定ファイルの記述 + +テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。 + +```toml +[general] +enable_bucket = true # Aspect Ratio Bucketingを使うか否か + +[[datasets]] +resolution = 512 # 学習解像度 +batch_size = 4 # バッチサイズ + + [[datasets.subsets]] + image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定 + caption_extension = '.caption' # キャプションファイルの拡張子 .txt を使う場合には書き換える + num_repeats = 10 # 学習用画像の繰り返し回数 + + # 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定 + class_tokens = 'girl' # class を指定 + num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい +``` + +基本的には以下を場所のみ書き換えれば学習できます。特に記述がない部分は class+identifier 方式と同じです。 + +1. 学習解像度 +1. バッチサイズ +1. フォルダ指定 +1. キャプションファイルの拡張子 + + 任意の拡張子を指定できます。 +1. 繰り返し回数 + +## step 3. 学習 + +それぞれのドキュメントを参考に学習を行ってください。 + +# fine tuning 方式 + +## step 1. メタデータを準備する + +キャプションやタグをまとめた管理用ファイルをメタデータと呼びます。json形式で拡張子は `.json` + です。作成方法は長くなりますのでこの文書の末尾に書きました。 + +## step 2. 設定ファイルの記述 + +テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。 + +```toml +[general] +shuffle_caption = true +keep_tokens = 1 + +[[datasets]] +resolution = 512 # 学習解像度 +batch_size = 4 # バッチサイズ + + [[datasets.subsets]] + image_dir = 'C:\piyo' # 学習用画像を入れたフォルダを指定 + metadata_file = 'C:\piyo\piyo_md.json' # メタデータファイル名 +``` + +基本的には以下を場所のみ書き換えれば学習できます。特に記述がない部分は DreamBooth, class+identifier 方式と同じです。 + +1. 学習解像度 +1. バッチサイズ +1. フォルダ指定 +1. メタデータファイル名 + + 後述の方法で作成したメタデータファイルを指定します。 + + +## step 3. 学習 + +それぞれのドキュメントを参考に学習を行ってください。 + +# Aspect Ratio Bucketing について + +Stable Diffusion のv1は512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくキャプションと画像の関係が学習されることが期待されます。 + +また任意の解像度で学習するため、事前に画像データの縦横比を統一しておく必要がなくなります。 + +設定で有効、向こうが切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。 + +学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位(デフォルト、変更可)で縦横に調整、作成されます。 + +機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。 + +# 以前のデータ指定方法 + +フォルダ名で繰り返し回数を指定する方法です。 + +## step 1. 学習用画像の準備 + +学習用画像を格納するフォルダを作成します。 __さらにその中に__ 、以下の名前でディレクトリを作成します。 + +``` +<繰り返し回数>_ +``` + +間の``_``を忘れないでください。 + +たとえば「sls frog」というプロンプトで、データを20回繰り返す場合、「20_sls frog」となります。以下のようになります。 + +![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png) + +### 複数class、複数対象(identifier)の学習 + +方法は単純で、学習用画像のフォルダ内に ``繰り返し回数_ `` のフォルダを複数、正則化画像フォルダにも同様に ``繰り返し回数_`` のフォルダを複数、用意してください。 + +たとえば「sls frog」と「cpc rabbit」を同時に学習する場合、以下のようになります。 + +![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png) + +classがひとつで対象が複数の場合、正則化画像フォルダはひとつで構いません。たとえば1girlにキャラAとキャラBがいる場合は次のようにします。 + +- train_girls + - 10_sls 1girl + - 10_cpc 1girl +- reg_girls + - 1_1girl + +### DreamBoothでキャプションを使う + +学習用画像、正則化画像のフォルダに、画像と同じファイル名で、拡張子.caption(オプションで変えられます)のファイルを置くと、そのファイルからキャプションを読み込みプロンプトとして学習します。 + +※それらの画像の学習に、フォルダ名(identifier class)は使用されなくなります。 + +キャプションファイルの拡張子はデフォルトで.captionです。学習スクリプトの `--caption_extension` オプションで変更できます。`--shuffle_caption` オプションで学習時のキャプションについて、カンマ区切りの各部分をシャッフルしながら学習します。 + +## step 2. 正則化画像の準備 + +正則化画像を使う場合の手順です。 + +正則化画像を格納するフォルダを作成します。 __さらにその中に__ ``<繰り返し回数>_`` という名前でディレクトリを作成します。 + +たとえば「frog」というプロンプトで、データを繰り返さない(1回だけ)場合、以下のようになります。 + +![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) + + +## step 3. 学習の実行 + +各学習スクリプトを実行します。 `--train_data_dir` オプションで前述の学習用データのフォルダを(__画像を含むフォルダではなく、その親フォルダ__)、`--reg_data_dir` オプションで正則化画像のフォルダ(__画像を含むフォルダではなく、その親フォルダ__)を指定してください。 + + + +# メタデータファイルの作成 + +## 教師データの用意 + +前述のように学習させたい画像データを用意し、任意のフォルダに入れてください。 + +たとえば以下のように画像を格納します。 + +![教師データフォルダのスクショ](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png) + +## 自動キャプショニング + +キャプションを使わずタグだけで学習する場合はスキップしてください。 + +また手動でキャプションを用意する場合、キャプションは教師データ画像と同じディレクトリに、同じファイル名、拡張子.caption等で用意してください。各ファイルは1行のみのテキストファイルとします。 + +### BLIPによるキャプショニング + +最新版ではBLIPのダウンロード、重みのダウンロード、仮想環境の追加は不要になりました。そのままで動作します。 + +finetuneフォルダ内のmake_captions.pyを実行します。 + +``` +python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ> +``` + +バッチサイズ8、教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 + +``` +python finetune\make_captions.py --batch_size 8 ..\train_data +``` + +キャプションファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.captionで作成されます。 + +batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。 +max_lengthオプションでキャプションの最大長を指定できます。デフォルトは75です。モデルをトークン長225で学習する場合には長くしても良いかもしれません。 +caption_extensionオプションでキャプションの拡張子を変更できます。デフォルトは.captionです(.txtにすると後述のDeepDanbooruと競合します)。 + +複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 + +なお、推論にランダム性があるため、実行するたびに結果が変わります。固定する場合には--seedオプションで `--seed 42` のように乱数seedを指定してください。 + +その他のオプションは `--help` でヘルプをご参照ください(パラメータの意味についてはドキュメントがまとまっていないようで、ソースを見るしかないようです)。 + +デフォルトでは拡張子.captionでキャプションファイルが生成されます。 + +![captionが生成されたフォルダ](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png) + +たとえば以下のようなキャプションが付きます。 + +![キャプションと画像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png) + +## DeepDanbooruによるタグ付け + +danbooruタグのタグ付け自体を行わない場合は「キャプションとタグ情報の前処理」に進んでください。 + +タグ付けはDeepDanbooruまたはWD14Taggerで行います。WD14Taggerのほうが精度が良いようです。WD14Taggerでタグ付けする場合は、次の章へ進んでください。 + +### 環境整備 + +DeepDanbooru https://github.com/KichangKim/DeepDanbooru を作業フォルダにcloneしてくるか、zipをダウンロードして展開します。私はzipで展開しました。 +またDeepDanbooruのReleasesのページ https://github.com/KichangKim/DeepDanbooru/releases の「DeepDanbooru Pretrained Model v3-20211112-sgd-e28」のAssetsから、deepdanbooru-v3-20211112-sgd-e28.zipをダウンロードしてきてDeepDanbooruのフォルダに展開します。 + +以下からダウンロードします。Assetsをクリックして開き、そこからダウンロードします。 + +![DeepDanbooruダウンロードページ](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png) + +以下のようなこういうディレクトリ構造にしてください + +![DeepDanbooruのディレクトリ構造](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png) + +Diffusersの環境に必要なライブラリをインストールします。DeepDanbooruのフォルダに移動してインストールします(実質的にはtensorflow-ioが追加されるだけだと思います)。 + +``` +pip install -r requirements.txt +``` + +続いてDeepDanbooru自体をインストールします。 + +``` +pip install . +``` + +以上でタグ付けの環境整備は完了です。 + +### タグ付けの実施 +DeepDanbooruのフォルダに移動し、deepdanbooruを実行してタグ付けを行います。 + +``` +deepdanbooru evaluate <教師データフォルダ> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 + +``` +deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。1件ずつ処理されるためわりと遅いです。 + +複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 + +以下のように生成されます。 + +![DeepDanbooruの生成ファイル](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png) + +こんな感じにタグが付きます(すごい情報量……)。 + +![DeepDanbooruタグと画像](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png) + +## WD14Taggerによるタグ付け + +DeepDanbooruの代わりにWD14Taggerを用いる手順です。 + +Automatic1111氏のWebUIで使用しているtaggerを利用します。こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。 + +最初の環境整備で必要なモジュールはインストール済みです。また重みはHugging Faceから自動的にダウンロードしてきます。 + +### タグ付けの実施 + +スクリプトを実行してタグ付けを行います。 +``` +python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ> +``` + +教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 +``` +python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data +``` + +初回起動時にはモデルファイルがwd14_tagger_modelフォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。以下のようになります。 + +![ダウンロードされたファイル](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png) + +タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。 + +![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) + +![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) + +threshオプションで、判定されたタグのconfidence(確信度)がいくつ以上でタグをつけるかが指定できます。デフォルトはWD14Taggerのサンプルと同じ0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。 + +batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。caption_extensionオプションでタグファイルの拡張子を変更できます。デフォルトは.txtです。 + +model_dirオプションでモデルの保存先フォルダを指定できます。 + +またforce_downloadオプションを指定すると保存先フォルダがあってもモデルを再ダウンロードします。 + +複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 + +## キャプションとタグ情報の前処理 + +スクリプトから処理しやすいようにキャプションとタグをメタデータとしてひとつのファイルにまとめます。 + +### キャプションの前処理 + +キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。`--full_path` オプションを指定してメタデータに画像ファイルの場所をフルパスで格納します。このオプションを省略すると相対パスで記録されますが、フォルダ指定が `.toml` ファイル内で別途必要になります。 + +``` +python merge_captions_to_metadata.py --full_apth <教師データフォルダ> +  --in_json <読み込むメタデータファイル名> <メタデータファイル名> +``` + +メタデータファイル名は任意の名前です。 +教師データがtrain_data、読み込むメタデータファイルなし、メタデータファイルがmeta_cap.jsonの場合、以下のようになります。 + +``` +python merge_captions_to_metadata.py --full_path train_data meta_cap.json +``` + +caption_extensionオプションでキャプションの拡張子を指定できます。 + +複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。 + +``` +python merge_captions_to_metadata.py --full_path + train_data1 meta_cap1.json +python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json + train_data2 meta_cap2.json +``` + +in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。 + +__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__ + +### タグの前処理 + +同様にタグもメタデータにまとめます(タグを学習に使わない場合は実行不要です)。 +``` +python merge_dd_tags_to_metadata.py --full_path <教師データフォルダ> + --in_json <読み込むメタデータファイル名> <書き込むメタデータファイル名> +``` + +先と同じディレクトリ構成で、meta_cap.jsonを読み、meta_cap_dd.jsonに書きだす場合、以下となります。 +``` +python merge_dd_tags_to_metadata.py --full_path train_data --in_json meta_cap.json meta_cap_dd.json +``` + +複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。 + +``` +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json + train_data1 meta_cap_dd1.json +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json + train_data2 meta_cap_dd2.json +``` + +in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。 + +__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__ + +### キャプションとタグのクリーニング + +ここまででメタデータファイルにキャプションとDeepDanbooruのタグがまとめられています。ただ自動キャプショニングにしたキャプションは表記ゆれなどがあり微妙(※)ですし、タグにはアンダースコアが含まれていたりratingが付いていたりしますので(DeepDanbooruの場合)、エディタの置換機能などを用いてキャプションとタグのクリーニングをしたほうがいいでしょう。 + +※たとえばアニメ絵の少女を学習する場合、キャプションにはgirl/girls/woman/womenなどのばらつきがあります。また「anime girl」なども単に「girl」としたほうが適切かもしれません。 + +クリーニング用のスクリプトが用意してありますので、スクリプトの内容を状況に応じて編集してお使いください。 + +(教師データフォルダの指定は不要になりました。メタデータ内の全データをクリーニングします。) + +``` +python clean_captions_and_tags.py <読み込むメタデータファイル名> <書き込むメタデータファイル名> +``` + +--in_jsonは付きませんのでご注意ください。たとえば次のようになります。 + +``` +python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json +``` + +以上でキャプションとタグの前処理は完了です。 + +## latentsの事前取得 + +※ このステップは必須ではありません。省略しても学習時にlatentsを取得しながら学習できます。 +また学習時に `random_crop` や `color_aug` などを行う場合にはlatentsの事前取得はできません(画像を毎回変えながら学習するため)。事前取得をしない場合、ここまでのメタデータで学習できます。 + +あらかじめ画像の潜在表現を取得しディスクに保存しておきます。それにより、学習を高速に進めることができます。あわせてbucketing(教師データをアスペクト比に応じて分類する)を行います。 + +作業フォルダで以下のように入力してください。 +``` +python prepare_buckets_latents.py --full_path <教師データフォルダ> + <読み込むメタデータファイル名> <書き込むメタデータファイル名> + + --batch_size <バッチサイズ> + --max_resolution <解像度 幅,高さ> + --mixed_precision <精度> +``` + +モデルがmodel.ckpt、バッチサイズ4、学習解像度は512\*512、精度no(float32)で、meta_clean.jsonからメタデータを読み込み、meta_lat.jsonに書き込む場合、以下のようになります。 + +``` +python prepare_buckets_latents.py --full_path + train_data meta_clean.json meta_lat.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no +``` + +教師データフォルダにnumpyのnpz形式でlatentsが保存されます。 + +解像度の最小サイズを--min_bucket_resoオプションで、最大サイズを--max_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。たとえば最小サイズに384を指定すると、256\*1024や320\*768などの解像度は使わなくなります。 +解像度を768\*768のように大きくした場合、最大サイズに1280などを指定すると良いでしょう。 + +--flip_augオプションを指定すると左右反転のaugmentation(データ拡張)を行います。疑似的にデータ量を二倍に増やすことができますが、データが左右対称でない場合に指定すると(例えばキャラクタの外見、髪型など)学習がうまく行かなくなります。 + + +(反転した画像についてもlatentsを取得し、\*\_flip.npzファイルを保存する単純な実装です。fline_tune.pyには特にオプション指定は必要ありません。\_flip付きのファイルがある場合、flip付き・なしのファイルを、ランダムに読み込みます。) + +バッチサイズはVRAM 12GBでももう少し増やせるかもしれません。 +解像度は64で割り切れる数字で、"幅,高さ"で指定します。解像度はfine tuning時のメモリサイズに直結します。VRAM 12GBでは512,512が限界と思われます(※)。16GBなら512,704や512,768まで上げられるかもしれません。なお256,256等にしてもVRAM 8GBでは厳しいようです(パラメータやoptimizerなどは解像度に関係せず一定のメモリが必要なため)。 + +※batch size 1の学習で12GB VRAM、640,640で動いたとの報告もありました。 + +以下のようにbucketingの結果が表示されます。 + +![bucketingの結果](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png) + +複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。 +``` +python prepare_buckets_latents.py --full_path + train_data1 meta_clean.json meta_lat1.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +python prepare_buckets_latents.py --full_path + train_data2 meta_lat1.json meta_lat2.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +``` +読み込み元と書き込み先を同じにすることも可能ですが別々の方が安全です。 + +__※引数を都度書き換えて、別のメタデータファイルに書き込むと安全です。__ + diff --git a/train_db.py b/train_db.py index 03fba1a..a302117 100644 --- a/train_db.py +++ b/train_db.py @@ -15,7 +15,11 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util -from library.train_util import DreamBoothDataset +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) def collate_fn(examples): @@ -33,24 +37,33 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, - tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) + else: + user_config = { + "datasets": [{ + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) + }] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.no_token_padding: - train_dataset.disable_token_padding() - - # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) - - train_dataset.make_buckets() + train_dataset_group.disable_token_padding() if args.debug_dataset: - train_util.debug_dataset(train_dataset) + train_util.debug_dataset(train_dataset_group) return + if cache_latents: + assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # acceleratorを準備する print("prepare accelerator") @@ -91,7 +104,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -126,7 +139,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -176,8 +189,8 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -198,7 +211,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -278,6 +291,8 @@ def train(args): progress_bar.update(1) global_step += 1 + train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} @@ -309,6 +324,8 @@ def train(args): train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + is_main_process = accelerator.is_main_process if is_main_process: unet = unwrap_model(unet) @@ -336,6 +353,7 @@ if __name__ == '__main__': train_util.add_training_arguments(parser, True) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") diff --git a/train_network.py b/train_network.py index 0ba290a..387590b 100644 --- a/train_network.py +++ b/train_network.py @@ -1,4 +1,3 @@ -from torch.cuda.amp import autocast from torch.nn.parallel import DistributedDataParallel as DDP import importlib import argparse @@ -12,11 +11,17 @@ import json from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler import library.train_util as train_util -from library.train_util import DreamBoothDataset, FineTuningDataset +from library.train_util import ( + DreamBoothDataset, +) +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) def collate_fn(examples): @@ -49,6 +54,7 @@ def train(args): cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None if args.seed is not None: set_seed(args.seed) @@ -56,35 +62,47 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # データセットを準備する - if use_dreambooth_method: - print("Use DreamBooth method.") - train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, - tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, - args.random_crop, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) else: - print("Train with captions.") - train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, - args.dataset_repeats, args.debug_dataset) + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [{ + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) + }] + } + else: + print("Train with captions.") + user_config = { + "datasets": [{ + "subsets": [{ + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + }] + }] + } - # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) - - train_dataset.make_buckets() + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: - train_util.debug_dataset(train_dataset) + train_util.debug_dataset(train_dataset_group) return - if len(train_dataset) == 0: + if len(train_dataset_group) == 0: print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)") return + if cache_latents: + assert train_dataset_group.is_latent_cacheable( + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # acceleratorを準備する print("prepare accelerator") accelerator, unwrap_model = train_util.prepare_accelerator(args) @@ -109,7 +127,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -153,7 +171,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -231,17 +249,19 @@ def train(args): args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + # TODO refactor metadata creation and move to util metadata = { "ss_session_id": session_id, # random integer indicating which group of epochs the model came from "ss_training_started_at": training_started_at, # unix timestamp @@ -249,12 +269,10 @@ def train(args): "ss_learning_rate": args.learning_rate, "ss_text_encoder_lr": args.text_encoder_lr, "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset.num_train_images, # includes repeating - "ss_num_reg_images": train_dataset.num_reg_images, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, "ss_gradient_checkpointing": args.gradient_checkpointing, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_max_train_steps": args.max_train_steps, @@ -266,26 +284,12 @@ def train(args): "ss_mixed_precision": args.mixed_precision, "ss_full_fp16": bool(args.full_fp16), "ss_v2": bool(args.v2), - "ss_resolution": args.resolution, "ss_clip_skip": args.clip_skip, "ss_max_token_length": args.max_token_length, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), "ss_cache_latents": bool(args.cache_latents), - "ss_enable_bucket": bool(train_dataset.enable_bucket), - "ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale), - "ss_min_bucket_reso": train_dataset.min_bucket_reso, - "ss_max_bucket_reso": train_dataset.max_bucket_reso, "ss_seed": args.seed, "ss_lowram": args.lowram, - "ss_keep_tokens": args.keep_tokens, "ss_noise_offset": args.noise_offset, - "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), - "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), - "ss_tag_frequency": json.dumps(train_dataset.tag_frequency), - "ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_training_comment": args.training_comment, # will not be updated after training "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), @@ -297,6 +301,132 @@ def train(args): "ss_prior_loss_weight": args.prior_loss_weight, } + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } + + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } + + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir + + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count + } + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える + # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない + # なので、ここで複数datasetの回数を合算してもあまり意味はない + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert len( + train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count + } + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count + } + + metadata.update({ + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + }) + # uncomment if another network is added # for key, value in net_kwargs.items(): # metadata["ss_arg_" + key] = value @@ -332,7 +462,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) metadata["ss_epoch"] = str(epoch+1) @@ -400,6 +530,8 @@ def train(args): progress_bar.update(1) global_step += 1 + train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + current_loss = loss.detach().item() if epoch == 0: loss_list.append(current_loss) @@ -445,6 +577,8 @@ def train(args): if saving and args.save_state: train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # end of epoch metadata["ss_epoch"] = str(num_train_epochs) @@ -480,6 +614,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"], diff --git a/train_textual_inversion.py b/train_textual_inversion.py index b4ddd76..d91a78f 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -11,7 +11,11 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util -from library.train_util import DreamBoothDataset, FineTuningDataset +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) imagenet_templates_small = [ "a photo of a {}", @@ -79,7 +83,6 @@ def train(args): train_util.prepare_dataset_args(args, True) cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) @@ -139,21 +142,35 @@ def train(args): print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - if use_dreambooth_method: - print("Use DreamBooth method.") - train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, - tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) else: - print("Train with captions.") - train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, - args.dataset_repeats, args.debug_dataset) + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [{ + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) + }] + } + else: + print("Train with captions.") + user_config = { + "datasets": [{ + "subsets": [{ + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + }] + }] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: @@ -163,20 +180,25 @@ def train(args): captions = [] for tmpl in templates: captions.append(tmpl.format(replace_to)) - train_dataset.add_replacement("", captions) - elif args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset.add_replacement(args.token_string, replace_to) - - train_dataset.make_buckets() + train_dataset_group.add_replacement("", captions) + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None if args.debug_dataset: - train_util.debug_dataset(train_dataset, show_input_ids=True) + train_util.debug_dataset(train_dataset_group, show_input_ids=True) return - if len(train_dataset) == 0: + if len(train_dataset_group) == 0: print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return + if cache_latents: + assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -186,7 +208,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.cache_latents(vae) + train_dataset_group.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -205,7 +227,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -263,8 +285,8 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -283,12 +305,11 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset.set_current_epoch(epoch + 1) + train_dataset_group.set_current_epoch(epoch + 1) text_encoder.train() loss_total = 0 - bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): with torch.no_grad(): @@ -354,6 +375,9 @@ def train(args): progress_bar.update(1) global_step += 1 + train_util.sample_images(accelerator, args, None, global_step, accelerator.device, + vae, tokenizer, text_encoder, unet, prompt_replacement) + current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} @@ -376,8 +400,6 @@ def train(args): accelerator.wait_for_everyone() updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() - # d = updated_embs - bef_epo_embs - # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) if args.save_every_n_epochs is not None: model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name @@ -399,6 +421,9 @@ def train(args): if saving and args.save_state: train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, + vae, tokenizer, text_encoder, unet, prompt_replacement) + # end of epoch is_main_process = accelerator.is_main_process @@ -474,6 +499,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")