From 2deddd5f3c59f79dca6cd770c6bf2bc95ca97c41 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Thu, 9 Mar 2023 11:06:59 -0500 Subject: [PATCH 1/4] Update to sd-script latest update --- README-ja.md | 5 +- README.md | 41 +- fine_tune_README_ja.md | 477 +----- gen_img_diffusers.py | 28 +- library/extract_lora_gui.py | 17 +- library/model_util.py | 2216 +++++++++++--------------- library/train_util.py | 10 +- networks/check_lora_weights.py | 2 +- networks/extract_lora_from_models.py | 50 +- networks/lora.py | 18 +- networks/merge_lora.py | 16 +- networks/resize_lora.py | 99 +- networks/svd_merge_lora.py | 45 +- tools/lycoris_locon_extract.py | 115 ++ train_README-ja.md | 359 ++++- train_db_README-ja.md | 301 +--- train_network.py | 20 +- train_network_README-ja.md | 150 +- train_textual_inversion.py | 7 +- train_ti_README-ja.md | 80 +- 20 files changed, 1886 insertions(+), 2170 deletions(-) create mode 100644 tools/lycoris_locon_extract.py diff --git a/README-ja.md b/README-ja.md index 064464c..47aaf16 100644 --- a/README-ja.md +++ b/README-ja.md @@ -16,9 +16,10 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma 当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください(将来的にはすべてこちらへ移すかもしれません)。 +* [学習について、共通編](./train_README-ja.md) : データ整備やオプションなど + * [データセット設定](./config_README-ja.md) * [DreamBoothの学習について](./train_db_README-ja.md) * [fine-tuningのガイド](./fine_tune_README_ja.md): -BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます * [LoRAの学習について](./train_network_README-ja.md) * [Textual Inversionの学習について](./train_ti_README-ja.md) * note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e) @@ -131,6 +132,8 @@ pip install --use-pep517 --upgrade -r requirements.txt LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。 +Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。 + ## ライセンス スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。 diff --git a/README.md b/README.md index 24ad0b7..a468a4e 100644 --- a/README.md +++ b/README.md @@ -176,13 +176,25 @@ This will store your a backup file with your current locally installed pip packa ## Change History -* 2023/03/05 (v21.2.0): +* 2023/03/09 (v21.2.0): - Fix issue https://github.com/bmaltais/kohya_ss/issues/335 - Add option to print LoRA trainer command without executing it - Add support for samples during trainin via a new `Sample images config` accordion in the `Training parameters` tab. - Added new `Additional parameters` under the `Advanced Configuration` section of the `Training parameters` tab to allow for the specifications of parameters not handles by the GUI. - Added support for sample as a new Accordion under the `Training parameters` tab. More info about the prompt options can be found here: https://github.com/kohya-ss/sd-scripts/issues/256#issuecomment-1455005709 - - There may be problems due to major changes. If you cannot revert back to a previous version when problems occur (`git checkout `). + - There may be problems due to major changes. If you cannot revert back to the previous version when problems occur, please do not update for a while. + - Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254 + - `train_network.py` supports LoRA for Conv2d-3x3 (extended to conv2d with a kernel size not 1x1). + - Same as a current version of [LoCon](https://github.com/KohakuBlueleaf/LoCon). __Thank you very much KohakuBlueleaf for your help!__ + - LoCon will be enhanced in the future. Compatibility for future versions is not guaranteed. + - Specify `--network_args` option like: `--network_args "conv_dim=4" "conv_alpha=1"` + - [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) version 0.5.0 or later is required to use 'LoRA for Conv2d-3x3' in Stable Diffusion web UI. + - __Stable Diffusion web UI built-in LoRA does not support 'LoRA for Conv2d-3x3' now. Consider carefully whether or not to use it.__ + - Merging/extracting scripts also support LoRA for Conv2d-3x3. + - Free CUDA memory after sample generation to reduce VRAM usage, issue https://github.com/kohya-ss/sd-scripts/issues/260 + - Empty caption doesn't cause error now, issue https://github.com/kohya-ss/sd-scripts/issues/258 + - Fix sample generation is crashing in Textual Inversion training when using templates, or if height/width is not divisible by 8. + - Update documents (Japanese only). - Dependencies are updated, Please [upgrade](#upgrade) the repo. - Add detail dataset config feature by extra config file. Thanks to fur0ut0 for this great contribution! - Documentation is [here](https://github-com.translate.goog/kohya-ss/sd-scripts/blob/main/config_README-ja.md) (only in Japanese currently.) @@ -197,6 +209,31 @@ This will store your a backup file with your current locally installed pip packa - Add `--tokenizer_cache_dir` to each training and generation scripts to cache Tokenizer locally from Diffusers. - Scripts will support offline training/generation after caching. - Support letents upscaling for highres. fix, and VAE batch size in `gen_img_diffusers.py` (no documentation yet.) + + - Sample image generation: + A prompt file might look like this, for example + + ``` + # prompt 1 + masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 + + # prompt 2 + masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 + ``` + + Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. + + * `--n` Negative prompt up to the next option. + * `--w` Specifies the width of the generated image. + * `--h` Specifies the height of the generated image. + * `--d` Specifies the seed of the generated image. + * `--l` Specifies the CFG scale of the generated image. + * `--s` Specifies the number of steps in the generation. + + The prompt weighting such as `( )` and `[ ]` are not working. + + Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. + * 2023/03/05 (v21.1.5): - Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount! - Improve how custom preset is set and handles. diff --git a/fine_tune_README_ja.md b/fine_tune_README_ja.md index 9dcd34a..686947c 100644 --- a/fine_tune_README_ja.md +++ b/fine_tune_README_ja.md @@ -1,6 +1,9 @@ -NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(v1.4/1.5の場合)環境等に対応したfine tuningです。 +NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません) + +[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 + +# 概要 -## 概要 Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。 * CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。 @@ -13,19 +16,24 @@ Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。Nov デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。 -## 追加機能について -### CLIPの出力の変更 +# 追加機能について + +## CLIPの出力の変更 + プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。 元のまま、最後の層の出力を用いることも可能です。 + ※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。 -### 正方形以外の解像度での学習 +## 正方形以外の解像度での学習 + Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。 学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。 機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。 -### トークン長の75から225への拡張 +## トークン長の75から225への拡張 + Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。 ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。 @@ -33,296 +41,67 @@ Stable Diffusionでは最大75トークン(開始・終了を含むと77トー ※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。 -## 環境整備 +# 学習の手順 -このリポジトリの[README](./README-ja.md)を参照してください。 +あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 -## 教師データの用意 - -学習させたい画像データを用意し、任意のフォルダに入れてください。リサイズ等の事前の準備は必要ありません。 -ただし学習解像度よりもサイズが小さい画像については、超解像などで品質を保ったまま拡大しておくことをお勧めします。 - -複数の教師データフォルダにも対応しています。前処理をそれぞれのフォルダに対して実行する形となります。 - -たとえば以下のように画像を格納します。 - -![教師データフォルダのスクショ](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png) - -## 自動キャプショニング -キャプションを使わずタグだけで学習する場合はスキップしてください。 - -また手動でキャプションを用意する場合、キャプションは教師データ画像と同じディレクトリに、同じファイル名、拡張子.caption等で用意してください。各ファイルは1行のみのテキストファイルとします。 - -### BLIPによるキャプショニング - -最新版ではBLIPのダウンロード、重みのダウンロード、仮想環境の追加は不要になりました。そのままで動作します。 - -finetuneフォルダ内のmake_captions.pyを実行します。 - -``` -python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ> -``` - -バッチサイズ8、教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 - -``` -python finetune\make_captions.py --batch_size 8 ..\train_data -``` - -キャプションファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.captionで作成されます。 - -batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。 -max_lengthオプションでキャプションの最大長を指定できます。デフォルトは75です。モデルをトークン長225で学習する場合には長くしても良いかもしれません。 -caption_extensionオプションでキャプションの拡張子を変更できます。デフォルトは.captionです(.txtにすると後述のDeepDanbooruと競合します)。 - -複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 - -なお、推論にランダム性があるため、実行するたびに結果が変わります。固定する場合には--seedオプションで「--seed 42」のように乱数seedを指定してください。 - -その他のオプションは--helpでヘルプをご参照ください(パラメータの意味についてはドキュメントがまとまっていないようで、ソースを見るしかないようです)。 - -デフォルトでは拡張子.captionでキャプションファイルが生成されます。 - -![captionが生成されたフォルダ](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png) - -たとえば以下のようなキャプションが付きます。 - -![キャプションと画像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png) - -## DeepDanbooruによるタグ付け -danbooruタグのタグ付け自体を行わない場合は「キャプションとタグ情報の前処理」に進んでください。 - -タグ付けはDeepDanbooruまたはWD14Taggerで行います。WD14Taggerのほうが精度が良いようです。WD14Taggerでタグ付けする場合は、次の章へ進んでください。 - -### 環境整備 -DeepDanbooru https://github.com/KichangKim/DeepDanbooru を作業フォルダにcloneしてくるか、zipをダウンロードして展開します。私はzipで展開しました。 -またDeepDanbooruのReleasesのページ https://github.com/KichangKim/DeepDanbooru/releases の「DeepDanbooru Pretrained Model v3-20211112-sgd-e28」のAssetsから、deepdanbooru-v3-20211112-sgd-e28.zipをダウンロードしてきてDeepDanbooruのフォルダに展開します。 - -以下からダウンロードします。Assetsをクリックして開き、そこからダウンロードします。 - -![DeepDanbooruダウンロードページ](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png) - -以下のようなこういうディレクトリ構造にしてください - -![DeepDanbooruのディレクトリ構造](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png) - -Diffusersの環境に必要なライブラリをインストールします。DeepDanbooruのフォルダに移動してインストールします(実質的にはtensorflow-ioが追加されるだけだと思います)。 - -``` -pip install -r requirements.txt -``` - -続いてDeepDanbooru自体をインストールします。 - -``` -pip install . -``` - -以上でタグ付けの環境整備は完了です。 - -### タグ付けの実施 -DeepDanbooruのフォルダに移動し、deepdanbooruを実行してタグ付けを行います。 - -``` -deepdanbooru evaluate <教師データフォルダ> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt -``` - -教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 - -``` -deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt -``` - -タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。1件ずつ処理されるためわりと遅いです。 - -複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 - -以下のように生成されます。 - -![DeepDanbooruの生成ファイル](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png) - -こんな感じにタグが付きます(すごい情報量……)。 - -![DeepDanbooruタグと画像](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png) - -## WD14Taggerによるタグ付け -DeepDanbooruの代わりにWD14Taggerを用いる手順です。 - -Automatic1111氏のWebUIで使用しているtaggerを利用します。こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。 - -最初の環境整備で必要なモジュールはインストール済みです。また重みはHugging Faceから自動的にダウンロードしてきます。 - -### タグ付けの実施 -スクリプトを実行してタグ付けを行います。 -``` -python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ> -``` - -教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 -``` -python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data -``` - -初回起動時にはモデルファイルがwd14_tagger_modelフォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。以下のようになります。 - -![ダウンロードされたファイル](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png) - -タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。 - -![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) - -![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) - -threshオプションで、判定されたタグのconfidence(確信度)がいくつ以上でタグをつけるかが指定できます。デフォルトはWD14Taggerのサンプルと同じ0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。 -batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。caption_extensionオプションでタグファイルの拡張子を変更できます。デフォルトは.txtです。 -model_dirオプションでモデルの保存先フォルダを指定できます。 -またforce_downloadオプションを指定すると保存先フォルダがあってもモデルを再ダウンロードします。 - -複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 - -## キャプションとタグ情報の前処理 - -スクリプトから処理しやすいようにキャプションとタグをメタデータとしてひとつのファイルにまとめます。 - -### キャプションの前処理 - -キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。 - -``` -python merge_captions_to_metadata.py <教師データフォルダ> -  --in_json <読み込むメタデータファイル名> - <メタデータファイル名> -``` - -メタデータファイル名は任意の名前です。 -教師データがtrain_data、読み込むメタデータファイルなし、メタデータファイルがmeta_cap.jsonの場合、以下のようになります。 - -``` -python merge_captions_to_metadata.py train_data meta_cap.json -``` - -caption_extensionオプションでキャプションの拡張子を指定できます。 - -複数の教師データフォルダがある場合には、full_path引数を指定してください(メタデータにフルパスで情報を持つようになります)。そして、それぞれのフォルダに対して実行してください。 - -``` -python merge_captions_to_metadata.py --full_path - train_data1 meta_cap1.json -python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json - train_data2 meta_cap2.json -``` - -in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。 - -__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__ - -### タグの前処理 - -同様にタグもメタデータにまとめます(タグを学習に使わない場合は実行不要です)。 -``` -python merge_dd_tags_to_metadata.py <教師データフォルダ> - --in_json <読み込むメタデータファイル名> - <書き込むメタデータファイル名> -``` - -先と同じディレクトリ構成で、meta_cap.jsonを読み、meta_cap_dd.jsonに書きだす場合、以下となります。 -``` -python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json -``` - -複数の教師データフォルダがある場合には、full_path引数を指定してください。そして、それぞれのフォルダに対して実行してください。 - -``` -python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json - train_data1 meta_cap_dd1.json -python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json - train_data2 meta_cap_dd2.json -``` - -in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。 - -__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__ - -### キャプションとタグのクリーニング -ここまででメタデータファイルにキャプションとDeepDanbooruのタグがまとめられています。ただ自動キャプショニングにしたキャプションは表記ゆれなどがあり微妙(※)ですし、タグにはアンダースコアが含まれていたりratingが付いていたりしますので(DeepDanbooruの場合)、エディタの置換機能などを用いてキャプションとタグのクリーニングをしたほうがいいでしょう。 - -※たとえばアニメ絵の少女を学習する場合、キャプションにはgirl/girls/woman/womenなどのばらつきがあります。また「anime girl」なども単に「girl」としたほうが適切かもしれません。 - -クリーニング用のスクリプトが用意してありますので、スクリプトの内容を状況に応じて編集してお使いください。 - -(教師データフォルダの指定は不要になりました。メタデータ内の全データをクリーニングします。) - -``` -python clean_captions_and_tags.py <読み込むメタデータファイル名> <書き込むメタデータファイル名> -``` - ---in_jsonは付きませんのでご注意ください。たとえば次のようになります。 - -``` -python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json -``` - -以上でキャプションとタグの前処理は完了です。 - -## latentsの事前取得 - -学習を高速に進めるためあらかじめ画像の潜在表現を取得しディスクに保存しておきます。あわせてbucketing(教師データをアスペクト比に応じて分類する)を行います。 - -作業フォルダで以下のように入力してください。 -``` -python prepare_buckets_latents.py <教師データフォルダ> - <読み込むメタデータファイル名> <書き込むメタデータファイル名> - - --batch_size <バッチサイズ> - --max_resolution <解像度 幅,高さ> - --mixed_precision <精度> -``` - -モデルがmodel.ckpt、バッチサイズ4、学習解像度は512\*512、精度no(float32)で、meta_clean.jsonからメタデータを読み込み、meta_lat.jsonに書き込む場合、以下のようになります。 - -``` -python prepare_buckets_latents.py - train_data meta_clean.json meta_lat.json model.ckpt - --batch_size 4 --max_resolution 512,512 --mixed_precision no -``` - -教師データフォルダにnumpyのnpz形式でlatentsが保存されます。 - -Stable Diffusion 2.0のモデルを読み込む場合は--v2オプションを指定してください(--v_parameterizationは不要です)。 - -解像度の最小サイズを--min_bucket_resoオプションで、最大サイズを--max_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。たとえば最小サイズに384を指定すると、256\*1024や320\*768などの解像度は使わなくなります。 -解像度を768\*768のように大きくした場合、最大サイズに1280などを指定すると良いでしょう。 - ---flip_augオプションを指定すると左右反転のaugmentation(データ拡張)を行います。疑似的にデータ量を二倍に増やすことができますが、データが左右対称でない場合に指定すると(例えばキャラクタの外見、髪型など)学習がうまく行かなくなります。 -(反転した画像についてもlatentsを取得し、\*\_flip.npzファイルを保存する単純な実装です。fline_tune.pyには特にオプション指定は必要ありません。\_flip付きのファイルがある場合、flip付き・なしのファイルを、ランダムに読み込みます。) - -バッチサイズはVRAM 12GBでももう少し増やせるかもしれません。 -解像度は64で割り切れる数字で、"幅,高さ"で指定します。解像度はfine tuning時のメモリサイズに直結します。VRAM 12GBでは512,512が限界と思われます(※)。16GBなら512,704や512,768まで上げられるかもしれません。なお256,256等にしてもVRAM 8GBでは厳しいようです(パラメータやoptimizerなどは解像度に関係せず一定のメモリが必要なため)。 - -※batch size 1の学習で12GB VRAM、640,640で動いたとの報告もありました。 - -以下のようにbucketingの結果が表示されます。 - -![bucketingの結果](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png) - -複数の教師データフォルダがある場合には、full_path引数を指定してください。そして、それぞれのフォルダに対して実行してください。 -``` -python prepare_buckets_latents.py --full_path - train_data1 meta_clean.json meta_lat1.json model.ckpt - --batch_size 4 --max_resolution 512,512 --mixed_precision no - -python prepare_buckets_latents.py --full_path - train_data2 meta_lat1.json meta_lat2.json model.ckpt - --batch_size 4 --max_resolution 512,512 --mixed_precision no - -``` -読み込み元と書き込み先を同じにすることも可能ですが別々の方が安全です。 - -__※引数を都度書き換えて、別のメタデータファイルに書き込むと安全です。__ +## データの準備 +[学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。 ## 学習の実行 -たとえば以下のように実行します。以下は省メモリ化のための設定です。 +たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。 + +``` +accelerate launch --num_cpu_threads_per_process 1 fine_tune.py + --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> + --output_dir=<学習したモデルの出力先フォルダ> + --output_name=<学習したモデル出力時のファイル名> + --dataset_config=<データ準備で作成した.tomlファイル> + --save_model_as=safetensors + --learning_rate=5e-6 --max_train_steps=10000 + --use_8bit_adam --xformers --gradient_checkpointing + --mixed_precision=fp16 +``` + +`num_cpu_threads_per_process` には通常は1を指定するとよいようです。 + +`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 + +`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 + +`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 + +学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。 + +省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 + +オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 + +`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 + +ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。 + +### よく使われるオプションについて + +以下の場合にはオプションに関するドキュメントを参照してください。 + +- Stable Diffusion 2.xまたはそこからの派生モデルを学習する +- clip skipを2以上を前提としたモデルを学習する +- 75トークンを超えたキャプションで学習する + +### バッチサイズについて + +モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。 + +### 学習率について + +1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。 + +### 以前の形式のデータセット指定をした場合のコマンドライン + +解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。 + ``` accelerate launch --num_cpu_threads_per_process 1 fine_tune.py --pretrained_model_name_or_path=model.ckpt @@ -336,76 +115,7 @@ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py --save_every_n_epochs=4 ``` -accelerateのnum_cpu_threads_per_processには通常は1を指定するとよいようです。 - -pretrained_model_name_or_pathに学習対象のモデルを指定します(Stable DiffusionのcheckpointかDiffusersのモデル)。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています(拡張子で自動判定)。 - -in_jsonにlatentをキャッシュしたときのメタデータファイルを指定します。 - -train_data_dirに教師データのフォルダを、output_dirに学習後のモデルの出力先フォルダを指定します。 - -shuffle_captionを指定すると、キャプション、タグをカンマ区切りされた単位でシャッフルして学習します(Waifu Diffusion v1.3で行っている手法です)。 -(先頭のトークンのいくつかをシャッフルせずに固定できます。その他のオプションのkeep_tokensをご覧ください。) - -train_batch_sizeにバッチサイズを指定します。VRAM 12GBでは1か2程度を指定してください。解像度によっても指定可能な数は変わってきます。 -学習に使用される実際のデータ量は「バッチサイズ×ステップ数」です。バッチサイズを増やした時には、それに応じてステップ数を下げることが可能です。 - -learning_rateに学習率を指定します。たとえばWaifu Diffusion v1.3は5e-6のようです。 -max_train_stepsにステップ数を指定します。 - -use_8bit_adamを指定すると8-bit Adam Optimizerを使用します。省メモリ化、高速化されますが精度は下がる可能性があります。 - -xformersを指定するとCrossAttentionを置換して省メモリ化、高速化します。 -※11/9時点ではfloat32の学習ではxformersがエラーになるため、bf16/fp16を使うか、代わりにmem_eff_attnを指定して省メモリ版CrossAttentionを使ってください(速度はxformersに劣ります)。 - -gradient_checkpointingで勾配の途中保存を有効にします。速度は遅くなりますが使用メモリ量が減ります。 - -mixed_precisionで混合精度を使うか否かを指定します。"fp16"または"bf16"を指定すると省メモリになりますが精度は劣ります。 -"fp16"と"bf16"は使用メモリ量はほぼ同じで、bf16の方が学習結果は良くなるとの話もあります(試した範囲ではあまり違いは感じられませんでした)。 -"no"を指定すると使用しません(float32になります)。 - -※bf16で学習したcheckpointをAUTOMATIC1111氏のWeb UIで読み込むとエラーになるようです。これはデータ型のbfloat16がWeb UIのモデルsafety checkerでエラーとなるためのようです。save_precisionオプションを指定してfp16またはfloat32形式で保存してください。またはsafetensors形式で保管しても良さそうです。 - -save_every_n_epochsを指定するとそのエポックだけ経過するたびに学習中のモデルを保存します。 - -### Stable Diffusion 2.0対応 -Hugging Faceのstable-diffusion-2-baseを使う場合は--v2オプションを、stable-diffusion-2または768-v-ema.ckptを使う場合は--v2と--v_parameterizationの両方のオプションを指定してください。 - -### メモリに余裕がある場合に精度や速度を上げる -まずgradient_checkpointingを外すと速度が上がります。ただし設定できるバッチサイズが減りますので、精度と速度のバランスを見ながら設定してください。 - -バッチサイズを増やすと速度、精度が上がります。メモリが足りる範囲で、1データ当たりの速度を確認しながら増やしてください(メモリがぎりぎりになるとかえって速度が落ちることがあります)。 - -### 使用するCLIP出力の変更 -clip_skipオプションに2を指定すると、後ろから二番目の層の出力を用います。1またはオプション省略時は最後の層を用います。 -学習したモデルはAutomatic1111氏のWeb UIで推論できるはずです。 - -※SD2.0はデフォルトで後ろから二番目の層を使うため、SD2.0の学習では指定しないでください。 - -学習対象のモデルがもともと二番目の層を使うように学習されている場合は、2を指定するとよいでしょう。 - -そうではなく最後の層を使用していた場合はモデル全体がそれを前提に学習されています。そのため改めて二番目の層を使用して学習すると、望ましい学習結果を得るにはある程度の枚数の教師データ、長めの学習が必要になるかもしれません。 - -### トークン長の拡張 -max_token_lengthに150または225を指定することでトークン長を拡張して学習できます。 -学習したモデルはAutomatic1111氏のWeb UIで推論できるはずです。 - -clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。 - -### 学習ログの保存 -logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 - -たとえば--logging_dir=logsと指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。 -また--log_prefixオプションを指定すると、日時の前に指定した文字列が追加されます。「--logging_dir=logs --log_prefix=fine_tune_style1」などとして識別用にお使いください。 - -TensorBoardでログを確認するには、別のコマンドプロンプトを開き、作業フォルダで以下のように入力します(tensorboardはDiffusersのインストール時にあわせてインストールされると思いますが、もし入っていないならpip install tensorboardで入れてください)。 -``` -tensorboard --logdir=logs -``` - -### Hypernetworkの学習 -別の記事で解説予定です。 - + -### その他のオプション +# fine tuning特有のその他の主なオプション -#### keep_tokens -数値を指定するとキャプションの先頭から、指定した数だけのトークン(カンマ区切りの文字列)をシャッフルせず固定します。 +すべてのオプションについては別文書を参照してください。 -キャプションとタグが両方ある場合、学習時のプロンプトは「キャプション,タグ1,タグ2……」のように連結されますので、「--keep_tokens=1」とすれば、学習時にキャプションが必ず先頭に来るようになります。 - -#### dataset_repeats -データセットの枚数が極端に少ない場合、epochがすぐに終わってしまうため(epochの区切りで少し時間が掛かります)、数値を指定してデータを何倍かしてepochを長めにしてください。 - -#### train_text_encoder +## `train_text_encoder` Text Encoderも学習対象とします。メモリ使用量が若干増加します。 通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。 -#### save_precision -checkpoint保存時のデータ形式をfloat、fp16、bf16から指定できます(未指定時は学習中のデータ形式と同じ)。ディスク容量が節約できますがモデルによる生成結果は変わってきます。またfloatやfp16を指定すると、1111氏のWeb UIでも読めるようになるはずです。 - -※VAEについては元のcheckpointのデータ形式のままになりますので、fp16でもモデルサイズが2GB強まで小さくならない場合があります。 - -#### save_model_as -モデルの保存形式を指定します。ckpt、safetensors、diffusers、diffusers_safetensorsのいずれかを指定してください。 - -Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。 - -#### use_safetensors -このオプションを指定するとsafetensors形式でcheckpointを保存します。保存形式はデフォルト(読み込んだ形式と同じ)になります。 - -#### save_stateとresume -save_stateオプションで、途中保存時および最終保存時に、checkpointに加えてoptimizer等の学習状態をフォルダに保存します。これにより中断してから学習再開したときの精度低下が避けられます(optimizerは状態を持ちながら最適化をしていくため、その状態がリセットされると再び初期状態から最適化を行わなくてはなりません)。なお、Accelerateの仕様でステップ数は保存されません。 - -スクリプト起動時、resumeオプションで状態の保存されたフォルダを指定すると再開できます。 - -学習状態は一回の保存あたり5GB程度になりますのでディスク容量にご注意ください。 - -#### gradient_accumulation_steps -指定したステップ数だけまとめて勾配を更新します。バッチサイズを増やすのと同様の効果がありますが、メモリを若干消費します。 - -※Accelerateの仕様で学習モデルが複数の場合には対応していないとのことですので、Text Encoderを学習対象にして、このオプションに2以上の値を指定するとエラーになるかもしれません。 - -#### lr_scheduler / lr_warmup_steps -lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。 - -lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。詳細については各自お調べください。 - -#### diffusers_xformers +## `diffusers_xformers` スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。 diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 6bab0bb..8a18517 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1649,10 +1649,11 @@ def get_unweighted_text_embeddings( if pad == eos: # v1 text_input_chunk[:, -1] = text_input[0, -1] else: # v2 - if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある - text_input_chunk[:, -1] = eos - if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD - text_input_chunk[:, 1] = eos + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos if clip_skip is None or clip_skip == 1: text_embedding = pipe.text_encoder(text_input_chunk)[0] @@ -2276,13 +2277,26 @@ def main(args): mask_images = l # 画像サイズにオプション指定があるときはリサイズする - if init_images is not None and args.W is not None and args.H is not None: - print(f"resize img2img source images to {args.W}*{args.H}") - init_images = resize_images(init_images, (args.W, args.H)) + if args.W is not None and args.H is not None: + if init_images is not None: + print(f"resize img2img source images to {args.W}*{args.H}") + init_images = resize_images(init_images, (args.W, args.H)) if mask_images is not None: print(f"resize img2img mask images to {args.W}*{args.H}") mask_images = resize_images(mask_images, (args.W, args.H)) + if networks and mask_images: + # mask を領域情報として流用する、現在は1枚だけ対応 + # TODO 複数のnetwork classの混在時の考慮 + print("use mask as region") + # import cv2 + # for i in range(3): + # cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) + # cv2.waitKey() + # cv2.destroyAllWindows() + networks[0].__class__.set_regions(networks, np.array(mask_images[0])) + mask_images = None + prev_image = None # for VGG16 guided if args.guide_image_path is not None: print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 91897e4..7c991a2 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -22,6 +22,7 @@ def extract_lora( save_precision, dim, v2, + conv_dim, ): # Check for caption_text_input if model_tuned == '': @@ -49,6 +50,8 @@ def extract_lora( run_cmd += f' --model_org "{model_org}"' run_cmd += f' --model_tuned "{model_tuned}"' run_cmd += f' --dim {dim}' + if conv_dim > 0: + run_cmd += f' --conv_dim {conv_dim}' if v2: run_cmd += f' --v2' @@ -71,7 +74,7 @@ def gradio_extract_lora_tab(): gr.Markdown( 'This utility can extract a LoRA network from a finetuned model.' ) - lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False) model_ext_name = gr.Textbox(value='Model types', visible=False) @@ -133,7 +136,15 @@ def gradio_extract_lora_tab(): maximum=1024, label='Network Dimension', value=128, - step=4, + step=1, + interactive=True, + ) + conv_dim = gr.Slider( + minimum=0, + maximum=1024, + label='Conv Dimension', + value=0, + step=1, interactive=True, ) v2 = gr.Checkbox(label='v2', value=False, interactive=True) @@ -142,6 +153,6 @@ def gradio_extract_lora_tab(): extract_button.click( extract_lora, - inputs=[model_tuned, model_org, save_to, save_precision, dim, v2], + inputs=[model_tuned, model_org, save_to, save_precision, dim, v2, conv_dim], show_progress=False, ) diff --git a/library/model_util.py b/library/model_util.py index 53f51c7..d1020c0 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -4,13 +4,8 @@ import math import os import torch -from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from safetensors.torch import load_file, save_file # DiffUsers版StableDiffusionのモデルパラメータ @@ -41,8 +36,8 @@ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] V2_UNET_PARAMS_CONTEXT_DIM = 1024 # Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_V1 = 'runwayml/stable-diffusion-v1-5' -DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1' +DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" +DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" # region StableDiffusion->Diffusersの変換コード @@ -50,853 +45,596 @@ DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1' def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return '.'.join(path.split('.')[n_shave_prefix_segments:]) - else: - return '.'.join(path.split('.')[:n_shave_prefix_segments]) + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace('in_layers.0', 'norm1') - new_item = new_item.replace('in_layers.2', 'conv1') + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") - new_item = new_item.replace('out_layers.0', 'norm2') - new_item = new_item.replace('out_layers.3', 'conv2') + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") - new_item = new_item.replace('emb_layers.1', 'time_emb_proj') - new_item = new_item.replace('skip_connection', 'conv_shortcut') + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace('nin_shortcut', 'conv_shortcut') - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace('norm.weight', 'group_norm.weight') - new_item = new_item.replace('norm.bias', 'group_norm.bias') + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace('q.weight', 'query.weight') - new_item = new_item.replace('q.bias', 'query.bias') + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") - new_item = new_item.replace('k.weight', 'key.weight') - new_item = new_item.replace('k.bias', 'key.bias') + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") - new_item = new_item.replace('v.weight', 'value.weight') - new_item = new_item.replace('v.bias', 'value.bias') + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") - new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def assign_to_checkpoint( - paths, - checkpoint, - old_checkpoint, - attention_paths_to_split=None, - additional_replacements=None, - config=None, + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. - Assigns the weights to the new checkpoint. - """ - assert isinstance( - paths, list - ), "Paths should be a list of dicts containing 'old' and 'new' keys." + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 - target_shape = ( - (-1, channels) if len(old_tensor.shape) == 3 else (-1) - ) + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - num_heads = old_tensor.shape[0] // config['num_head_channels'] // 3 + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - old_tensor = old_tensor.reshape( - (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] - ) - query, key, value = old_tensor.split(channels // num_heads, dim=1) + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map['query']] = query.reshape(target_shape) - checkpoint[path_map['key']] = key.reshape(target_shape) - checkpoint[path_map['value']] = value.reshape(target_shape) + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) - for path in paths: - new_path = path['new'] + for path in paths: + new_path = path["new"] - # These have already been assigned - if ( - attention_paths_to_split is not None - and new_path in attention_paths_to_split - ): - continue + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue - # Global renaming happens here - new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0') - new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0') - new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1') + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace( - replacement['old'], replacement['new'] - ) + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) - # proj_attn.weight has to be converted from conv 1D to linear - if 'proj_attn.weight' in new_path: - checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path['old']] + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ['query.weight', 'key.weight', 'value.weight'] - for key in keys: - if '.'.join(key.split('.')[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif 'proj_attn.weight' in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ['proj_in.weight', 'proj_out.weight'] - for key in keys: - if '.'.join(key.split('.')[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ - # extract state_dict for UNet - unet_state_dict = {} - unet_key = 'model.diffusion_model.' - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, '')] = checkpoint.pop(key) + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - new_checkpoint = {} + new_checkpoint = {} - new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict[ - 'time_embed.0.weight' - ] - new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict[ - 'time_embed.0.bias' - ] - new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict[ - 'time_embed.2.weight' - ] - new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict[ - 'time_embed.2.bias' + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - new_checkpoint['conv_in.weight'] = unet_state_dict[ - 'input_blocks.0.0.weight' - ] - new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias'] + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) - new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight'] - new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias'] - new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight'] - new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias'] - - # Retrieves the keys for the input blocks only - num_input_blocks = len( - { - '.'.join(layer.split('.')[:2]) - for layer in unet_state_dict - if 'input_blocks' in layer - } + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) - input_blocks = { - layer_id: [ - key - for key in unet_state_dict - if f'input_blocks.{layer_id}.' in key - ] - for layer_id in range(num_input_blocks) - } - # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - { - '.'.join(layer.split('.')[:2]) - for layer in unet_state_dict - if 'middle_block' in layer - } - ) - middle_blocks = { - layer_id: [ - key - for key in unet_state_dict - if f'middle_block.{layer_id}.' in key - ] - for layer_id in range(num_middle_blocks) - } + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) - # Retrieves the keys for the output blocks only - num_output_blocks = len( - { - '.'.join(layer.split('.')[:2]) - for layer in unet_state_dict - if 'output_blocks' in layer - } - ) - output_blocks = { - layer_id: [ - key - for key in unet_state_dict - if f'output_blocks.{layer_id}.' in key - ] - for layer_id in range(num_output_blocks) - } + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config['layers_per_block'] + 1) - layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1) + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - resnets = [ - key - for key in input_blocks[i] - if f'input_blocks.{i}.0' in key - and f'input_blocks.{i}.0.op' not in key + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" ] - attentions = [ - key for key in input_blocks[i] if f'input_blocks.{i}.1' in key + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" ] - if f'input_blocks.{i}.0.op.weight' in unet_state_dict: - new_checkpoint[ - f'down_blocks.{block_id}.downsamplers.0.conv.weight' - ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight') - new_checkpoint[ - f'down_blocks.{block_id}.downsamplers.0.conv.bias' - ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias') + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] - paths = renew_resnet_paths(resnets) + if len(attentions): + paths = renew_attention_paths(attentions) meta_path = { - 'old': f'input_blocks.{i}.0', - 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}', + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - 'old': f'input_blocks.{i}.1', - 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) + new_checkpoint[new_path] = unet_state_dict[old_path] - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] + # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する + if v2: + linear_transformer_to_conv(new_checkpoint) - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint( - resnet_0_paths, new_checkpoint, unet_state_dict, config=config - ) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint( - resnet_1_paths, new_checkpoint, unet_state_dict, config=config - ) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'} - assign_to_checkpoint( - attentions_paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - for i in range(num_output_blocks): - block_id = i // (config['layers_per_block'] + 1) - layer_in_block_id = i % (config['layers_per_block'] + 1) - output_block_layers = [ - shave_segments(name, 2) for name in output_blocks[i] - ] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split('.')[0], shave_segments( - layer, 1 - ) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [ - key - for key in output_blocks[i] - if f'output_blocks.{i}.0' in key - ] - attentions = [ - key - for key in output_blocks[i] - if f'output_blocks.{i}.1' in key - ] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = { - 'old': f'output_blocks.{i}.0', - 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ['conv.bias', 'conv.weight'] in output_block_list.values(): - index = list(output_block_list.values()).index( - ['conv.bias', 'conv.weight'] - ) - new_checkpoint[ - f'up_blocks.{block_id}.upsamplers.0.conv.bias' - ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias'] - new_checkpoint[ - f'up_blocks.{block_id}.upsamplers.0.conv.weight' - ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight'] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - 'old': f'output_blocks.{i}.1', - 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) - else: - resnet_0_paths = renew_resnet_paths( - output_block_layers, n_shave_prefix_segments=1 - ) - for path in resnet_0_paths: - old_path = '.'.join(['output_blocks', str(i), path['old']]) - new_path = '.'.join( - [ - 'up_blocks', - str(block_id), - 'resnets', - str(layer_in_block_id), - path['new'], - ] - ) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) - - return new_checkpoint + return new_checkpoint def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = 'first_stage_model.' - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, '')] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint - new_checkpoint = {} + new_checkpoint = {} - new_checkpoint['encoder.conv_in.weight'] = vae_state_dict[ - 'encoder.conv_in.weight' - ] - new_checkpoint['encoder.conv_in.bias'] = vae_state_dict[ - 'encoder.conv_in.bias' - ] - new_checkpoint['encoder.conv_out.weight'] = vae_state_dict[ - 'encoder.conv_out.weight' - ] - new_checkpoint['encoder.conv_out.bias'] = vae_state_dict[ - 'encoder.conv_out.bias' - ] - new_checkpoint['encoder.conv_norm_out.weight'] = vae_state_dict[ - 'encoder.norm_out.weight' - ] - new_checkpoint['encoder.conv_norm_out.bias'] = vae_state_dict[ - 'encoder.norm_out.bias' + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key ] - new_checkpoint['decoder.conv_in.weight'] = vae_state_dict[ - 'decoder.conv_in.weight' - ] - new_checkpoint['decoder.conv_in.bias'] = vae_state_dict[ - 'decoder.conv_in.bias' - ] - new_checkpoint['decoder.conv_out.weight'] = vae_state_dict[ - 'decoder.conv_out.weight' - ] - new_checkpoint['decoder.conv_out.bias'] = vae_state_dict[ - 'decoder.conv_out.bias' - ] - new_checkpoint['decoder.conv_norm_out.weight'] = vae_state_dict[ - 'decoder.norm_out.weight' - ] - new_checkpoint['decoder.conv_norm_out.bias'] = vae_state_dict[ - 'decoder.norm_out.bias' - ] + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] - new_checkpoint['quant_conv.weight'] = vae_state_dict['quant_conv.weight'] - new_checkpoint['quant_conv.bias'] = vae_state_dict['quant_conv.bias'] - new_checkpoint['post_quant_conv.weight'] = vae_state_dict[ - 'post_quant_conv.weight' - ] - new_checkpoint['post_quant_conv.bias'] = vae_state_dict[ - 'post_quant_conv.bias' - ] + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len( - { - '.'.join(layer.split('.')[:3]) - for layer in vae_state_dict - if 'encoder.down' in layer - } - ) - down_blocks = { - layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] - for layer_id in range(num_down_blocks) - } + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len( - { - '.'.join(layer.split('.')[:3]) - for layer in vae_state_dict - if 'decoder.up' in layer - } - ) - up_blocks = { - layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] - for layer_id in range(num_up_blocks) - } + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - for i in range(num_down_blocks): - resnets = [ - key - for key in down_blocks[i] - if f'down.{i}' in key and f'down.{i}.downsample' not in key - ] - - if f'encoder.down.{i}.downsample.conv.weight' in vae_state_dict: - new_checkpoint[ - f'encoder.down_blocks.{i}.downsamplers.0.conv.weight' - ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.weight') - new_checkpoint[ - f'encoder.down_blocks.{i}.downsamplers.0.conv.bias' - ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.bias') - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'down.{i}.block', - 'new': f'down_blocks.{i}.resnets', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_resnets = [key for key in vae_state_dict if 'encoder.mid.block' in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [ - key for key in mid_resnets if f'encoder.mid.block_{i}' in key - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'mid.block_{i}', - 'new': f'mid_block.resnets.{i - 1}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_attentions = [ - key for key in vae_state_dict if 'encoder.mid.attn' in key - ] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key - for key in up_blocks[block_id] - if f'up.{block_id}' in key and f'up.{block_id}.upsample' not in key - ] - - if f'decoder.up.{block_id}.upsample.conv.weight' in vae_state_dict: - new_checkpoint[ - f'decoder.up_blocks.{i}.upsamplers.0.conv.weight' - ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.weight'] - new_checkpoint[ - f'decoder.up_blocks.{i}.upsamplers.0.conv.bias' - ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.bias'] - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'up.{block_id}.block', - 'new': f'up_blocks.{i}.resnets', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_resnets = [key for key in vae_state_dict if 'decoder.mid.block' in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [ - key for key in mid_resnets if f'decoder.mid.block_{i}' in key - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'mid.block_{i}', - 'new': f'mid_block.resnets.{i - 1}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_attentions = [ - key for key in vae_state_dict if 'decoder.mid.attn' in key - ] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # unet_params = original_config.model.params.unet_config.params - block_out_channels = [ - UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT - ] + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = ( - 'CrossAttnDownBlock2D' - if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS - else 'DownBlock2D' - ) - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = ( - 'CrossAttnUpBlock2D' - if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS - else 'UpBlock2D' - ) - up_block_types.append(block_type) - resolution //= 2 + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM - if not v2 - else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS - if not v2 - else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + ) - return config + return config def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ['DownEncoderBlock2D'] * len(block_out_channels) - up_block_types = ['UpDecoderBlock2D'] * len(block_out_channels) + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith('cond_stage_model.transformer'): - text_model_dict[ - key[len('cond_stage_model.transformer.') :] - ] = checkpoint[key] - return text_model_dict + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] + return text_model_dict def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith('cond_stage_model'): - return None + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None - # common conversion - key = key.replace( - 'cond_stage_model.model.transformer.', 'text_model.encoder.' - ) - key = key.replace('cond_stage_model.model.', 'text_model.') + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") - if 'resblocks' in key: - # resblocks conversion - key = key.replace('.resblocks.', '.layers.') - if '.ln_' in key: - key = key.replace('.ln_', '.layer_norm') - elif '.mlp.' in key: - key = key.replace('.c_fc.', '.fc1.') - key = key.replace('.c_proj.', '.fc2.') - elif '.attn.out_proj' in key: - key = key.replace('.attn.out_proj.', '.self_attn.out_proj.') - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f'unexpected key in SD: {key}') - elif '.positional_embedding' in key: - key = key.replace( - '.positional_embedding', - '.embeddings.position_embedding.weight', - ) - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace( - '.token_embedding.weight', '.embeddings.token_embedding.weight' - ) - elif '.ln_final' in key: - key = key.replace('.ln_final', '.final_layer_norm') - return key + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif '.attn.out_proj' in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif '.attn.in_proj' in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif '.positional_embedding' in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif '.text_projection' in key: + key = None # 使われない??? + elif '.logit_scale' in key: + key = None # 使われない??? + elif '.token_embedding' in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif '.ln_final' in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if '.resblocks.23.' in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) + # attnの変換 + for key in keys: + if '.resblocks.23.' in key: + continue + if '.resblocks' in key and '.attn.in_proj_' in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) - key_suffix = '.weight' if 'weight' in key else '.bias' - key_pfx = key.replace( - 'cond_stage_model.model.transformer.resblocks.', - 'text_model.encoder.layers.', - ) - key_pfx = key_pfx.replace('_weight', '') - key_pfx = key_pfx.replace('_bias', '') - key_pfx = key_pfx.replace('.attn.in_proj', '.self_attn.') - new_sd[key_pfx + 'q_proj' + key_suffix] = values[0] - new_sd[key_pfx + 'k_proj' + key_suffix] = values[1] - new_sd[key_pfx + 'v_proj' + key_suffix] = values[2] + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # rename or add position_ids - ANOTHER_POSITION_IDS_KEY = ( - 'text_model.encoder.text_model.embeddings.position_ids' - ) - if ANOTHER_POSITION_IDS_KEY in new_sd: - # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] - del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - - new_sd['text_model.embeddings.position_ids'] = position_ids - return new_sd + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd # endregion @@ -904,643 +642,543 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): # region Diffusers->StableDiffusion の変換コード # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) - def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ['proj_in.weight', 'proj_out.weight'] - for key in keys: - if '.'.join(key.split('.')[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ('time_embed.0.weight', 'time_embedding.linear_1.weight'), - ('time_embed.0.bias', 'time_embedding.linear_1.bias'), - ('time_embed.2.weight', 'time_embedding.linear_2.weight'), - ('time_embed.2.bias', 'time_embedding.linear_2.bias'), - ('input_blocks.0.0.weight', 'conv_in.weight'), - ('input_blocks.0.0.bias', 'conv_in.bias'), - ('out.0.weight', 'conv_norm_out.weight'), - ('out.0.bias', 'conv_norm_out.bias'), - ('out.2.weight', 'conv_out.weight'), - ('out.2.bias', 'conv_out.bias'), - ] + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ('in_layers.0', 'norm1'), - ('in_layers.2', 'conv1'), - ('out_layers.0', 'norm2'), - ('out_layers.3', 'conv2'), - ('emb_layers.1', 'time_emb_proj'), - ('skip_connection', 'conv_shortcut'), - ] + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks - - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.' - sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.' - unet_conversion_map_layer.append( - (sd_down_res_prefix, hf_down_res_prefix) - ) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.' - sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.' - unet_conversion_map_layer.append( - (sd_down_atn_prefix, hf_down_atn_prefix) - ) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.' - sd_up_res_prefix = f'output_blocks.{3*i + j}.0.' - unet_conversion_map_layer.append( - (sd_up_res_prefix, hf_up_res_prefix) - ) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.' - sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.' - unet_conversion_map_layer.append( - (sd_up_atn_prefix, hf_up_atn_prefix) - ) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.' - sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.' - unet_conversion_map_layer.append( - (sd_downsample_prefix, hf_downsample_prefix) - ) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' - sd_upsample_prefix = ( - f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.' - ) - unet_conversion_map_layer.append( - (sd_upsample_prefix, hf_upsample_prefix) - ) - - hf_mid_atn_prefix = 'mid_block.attentions.0.' - sd_mid_atn_prefix = 'middle_block.1.' - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks for j in range(2): - hf_mid_res_prefix = f'mid_block.resnets.{j}.' - sd_mid_res_prefix = f'middle_block.{2*j}.' - unet_conversion_map_layer.append( - (sd_mid_res_prefix, hf_mid_res_prefix) - ) + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if 'resnets' in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - if v2: - conv_transformer_to_linear(new_state_dict) + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - return new_state_dict + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict # ================# # VAE Conversion # # ================# - def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) + return w.reshape(*w.shape, 1, 1) def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ('nin_shortcut', 'conv_shortcut'), - ('norm_out', 'conv_norm_out'), - ('mid.attn_1.', 'mid_block.attentions.0.'), - ] + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.' - sd_down_prefix = f'encoder.down.{i}.block.{j}.' - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - if i < 3: - hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.' - sd_downsample_prefix = f'down.{i}.downsample.' - vae_conversion_map.append( - (sd_downsample_prefix, hf_downsample_prefix) - ) + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' - sd_upsample_prefix = f'up.{3-i}.upsample.' - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.' - sd_up_prefix = f'decoder.up.{3-i}.block.{j}.' - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f'mid_block.resnets.{i}.' - sd_mid_res_prefix = f'mid.block_{i+1}.' - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ('norm.', 'group_norm.'), - ('q.', 'query.'), - ('k.', 'key.'), - ('v.', 'value.'), - ('proj_out.', 'proj_attn.'), - ] + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if 'attentions' in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ['q', 'k', 'v', 'proj_out'] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f'mid.attn_1.{weight_name}.weight' in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) - return new_state_dict + return new_state_dict # endregion # region 自作のモデル読み書きなど - def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' + return os.path.splitext(path)[1].lower() == '.safetensors' def load_checkpoint_with_text_encoder_conversion(ckpt_path): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ( - 'cond_stage_model.transformer.embeddings.', - 'cond_stage_model.transformer.text_model.embeddings.', - ), - ( - 'cond_stage_model.transformer.encoder.', - 'cond_stage_model.transformer.text_model.encoder.', - ), - ( - 'cond_stage_model.transformer.final_layer_norm.', - 'cond_stage_model.transformer.text_model.final_layer_norm.', - ), - ] + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), + ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), + ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') + ] - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path, 'cpu') + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path, "cpu") + else: + checkpoint = torch.load(ckpt_path, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: - checkpoint = torch.load(ckpt_path, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - checkpoint = None + state_dict = checkpoint + checkpoint = None - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from) :] - key_reps.append((key, new_key)) + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from):] + key_reps.append((key, new_key)) - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] - return checkpoint, state_dict + return checkpoint, state_dict # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if dtype is not None: - for k, v in state_dict.items(): - if type(v) is torch.Tensor: - state_dict[k] = v.to(dtype) + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if dtype is not None: + for k, v in state_dict.items(): + if type(v) is torch.Tensor: + state_dict[k] = v.to(dtype) - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - v2, state_dict, unet_config + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) + + unet = UNet2DConditionModel(**unet_config) + info = unet.load_state_dict(converted_unet_checkpoint) + print("loading u-net:", info) + + # Convert the VAE model. + vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + + vae = AutoencoderKL(**vae_config) + info = vae.load_state_dict(converted_vae_checkpoint) + print("loading vae:", info) + + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - unet = UNet2DConditionModel(**unet_config) - info = unet.load_state_dict(converted_unet_checkpoint) - print('loading u-net:', info) + logging.set_verbosity_error() # don't show annoying warning + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + logging.set_verbosity_warning() - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - state_dict, vae_config - ) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("loading text encoder:", info) - vae = AutoencoderKL(**vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) - print('loading vae:', info) - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2( - state_dict, 77 - ) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act='gelu', - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type='clip_text_model', - projection_dim=512, - torch_dtype='float32', - transformers_version='4.25.0.dev0', - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1( - state_dict - ) - text_model = CLIPTextModel.from_pretrained( - 'openai/clip-vit-large-patch14' - ) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print('loading text encoder:', info) - - return text_model, vae, unet + return text_model, vae, unet -def convert_text_encoder_state_dict_to_sd_v2( - checkpoint, make_dummy_weights=False -): - def convert_key(key): - # position_idsの除去 - if '.position_ids' in key: - return None +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None - # common - key = key.replace('text_model.encoder.', 'transformer.') - key = key.replace('text_model.', '') - if 'layers' in key: - # resblocks conversion - key = key.replace('.layers.', '.resblocks.') - if '.layer_norm' in key: - key = key.replace('.layer_norm', '.ln_') - elif '.mlp.' in key: - key = key.replace('.fc1.', '.c_fc.') - key = key.replace('.fc2.', '.c_proj.') - elif '.self_attn.out_proj' in key: - key = key.replace('.self_attn.out_proj.', '.attn.out_proj.') - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f'unexpected key in DiffUsers model: {key}') - elif '.position_embedding' in key: - key = key.replace( - 'embeddings.position_embedding.weight', 'positional_embedding' - ) - elif '.token_embedding' in key: - key = key.replace( - 'embeddings.token_embedding.weight', 'token_embedding.weight' - ) - elif 'final_layer_norm' in key: - key = key.replace('final_layer_norm', 'ln_final') - return key + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif '.self_attn.out_proj' in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif '.self_attn.' in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif '.position_embedding' in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif '.token_embedding' in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif 'final_layer_norm' in key: + key = key.replace("final_layer_norm", "ln_final") + return key - keys = list(checkpoint.keys()) - new_sd = {} + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if 'layers' in key and 'q_proj' in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace('q_proj', 'k_proj') - key_v = key.replace('q_proj', 'v_proj') + # Diffusersに含まれない重みを作っておく + new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd['logit_scale'] = torch.tensor(1) - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace( - 'text_model.encoder.layers.', 'transformer.resblocks.' - ) - new_key = new_key.replace('.self_attn.q_proj.', '.attn.in_proj_') - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print( - 'make dummy weights for resblock.23, text_projection and logit scale.' - ) - keys = list(new_sd.keys()) - for key in keys: - if key.startswith('transformer.resblocks.22.'): - new_sd[key.replace('.22.', '.23.')] = new_sd[ - key - ].clone() # copyしないとsafetensorsの保存で落ちる - - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones( - (1024, 1024), - dtype=new_sd[keys[0]].dtype, - device=new_sd[keys[0]].device, - ) - new_sd['logit_scale'] = torch.tensor(1) - - return new_sd + return new_sd -def save_stable_diffusion_checkpoint( - v2, - output_file, - text_encoder, - unet, - ckpt_path, - epochs, - steps, - save_dtype=None, - vae=None, -): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion( - ckpt_path - ) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False - else: - strict = True - if 'state_dict' in state_dict: - del state_dict['state_dict'] +def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False else: - # 新しく作る - assert ( - vae is not None - ), 'VAE is required to save a checkpoint without a given checkpoint' - checkpoint = {} - state_dict = {} - strict = False + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert ( - not strict or key in state_dict - ), f'Illegal key in save SD: {key}' - if save_dtype is not None: - v = v.detach().clone().to('cpu').to(save_dtype) - state_dict[key] = v + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd('model.diffusion_model.', unet_state_dict) + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) - # Convert the text encoder model + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {'state_dict': state_dict} + + if 'epoch' in checkpoint: + epochs += checkpoint['epoch'] + if 'global_step' in checkpoint: + steps += checkpoint['global_step'] + + new_ckpt['epoch'] = epochs + new_ckpt['global_step'] = steps + + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 if v2: - make_dummy = ( - ckpt_path is None - ) # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2( - text_encoder.state_dict(), make_dummy - ) - update_sd('cond_stage_model.model.', text_enc_dict) + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 else: - text_enc_dict = text_encoder.state_dict() - update_sd('cond_stage_model.transformer.', text_enc_dict) + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd('first_stage_model.', vae_dict) + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} - - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) - - return key_count + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) -def save_diffusers_checkpoint( - v2, - output_dir, - text_encoder, - unet, - pretrained_model_name_or_path, - vae=None, - use_safetensors=False, -): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 - if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 - else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - - scheduler = DDIMScheduler.from_pretrained( - pretrained_model_name_or_path, subfolder='scheduler' - ) - tokenizer = CLIPTokenizer.from_pretrained( - pretrained_model_name_or_path, subfolder='tokenizer' - ) - if vae is None: - vae = AutoencoderKL.from_pretrained( - pretrained_model_name_or_path, subfolder='vae' - ) - - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) - - -VAE_PREFIX = 'first_stage_model.' +VAE_PREFIX = "first_stage_model." def load_vae(vae_id, dtype): - print(f'load VAE: {vae_id}') - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained( - vae_id, subfolder=None, torch_dtype=dtype - ) - except EnvironmentError as e: - print(f'exception occurs in loading vae: {e}') - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained( - vae_id, subfolder='vae', torch_dtype=dtype - ) - return vae - - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith('.bin'): - # SD 1.5 VAE on Huggingface - converted_vae_checkpoint = torch.load(vae_id, map_location='cpu') - else: - # StableDiffusion - vae_model = ( - load_file(vae_id, 'cpu') - if is_safetensors(vae_id) - else torch.load(vae_id, map_location='cpu') - ) - vae_sd = ( - vae_model['state_dict'] if 'state_dict' in vae_model else vae_model - ) - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - vae_sd, vae_config - ) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) return vae + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) + else torch.load(vae_id, map_location="cpu")) + vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae # endregion -def make_bucket_resolutions( - max_reso, min_size=256, max_size=1024, divisible=64 -): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) +def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) - resos = set() + resos = set() - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) - size += divisible + size += divisible - resos = list(resos) - resos.sort() - return resos + resos = list(resos) + resos.sort() + return resos if __name__ == '__main__': - resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) + resos = make_bucket_resolutions((512, 768)) + print(len(resos)) + print(resos) + aspect_ratios = [w / h for w, h in resos] + print(aspect_ratios) - ars = set() - for ar in aspect_ratios: - if ar in ars: - print('error! duplicate ar:', ar) - ars.add(ar) + ars = set() + for ar in aspect_ratios: + if ar in ars: + print("error! duplicate ar:", ar) + ars.add(ar) diff --git a/library/train_util.py b/library/train_util.py index 351d222..6af1abe 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -924,7 +924,9 @@ class FineTuningDataset(BaseDataset): elif tags is not None and len(tags) > 0: caption = caption + ', ' + tags tags_list.append(tags) - assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" + + if caption is None: + caption = "" image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) image_info.image_size = img_md.get('train_resolution') @@ -2207,7 +2209,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v if epoch is None or epoch % args.sample_every_n_epochs != 0: return else: - if steps % args.sample_every_n_steps != 0: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") @@ -2351,6 +2353,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v if negative_prompt is not None: negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) @@ -2394,4 +2398,4 @@ class ImageLoadingDataset(torch.utils.data.Dataset): return (tensor_pil, img_path) -# endregion \ No newline at end of file +# endregion diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 4ee3f57..6bd9ccd 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -21,7 +21,7 @@ def main(file): for key, value in values: value = value.to(torch.float32) - print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") if __name__ == '__main__': diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 84d705c..9f40978 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -45,8 +45,13 @@ def svd(args): text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) # create LoRA network to extract weights: Use dim (rank) as alpha - lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o) - lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t) + if args.conv_dim is None: + kwargs = {} + else: + kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} + + lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -85,13 +90,27 @@ def svd(args): # make LoRA with svd print("calculating by svd") - rank = args.dim lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): + # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = (len(mat.size()) == 4) + kernel_size = None if not conv2d else mat.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + + rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim + out_dim, in_dim = mat.size()[0:2] + + if args.device: + mat = mat.to(args.device) + # print(mat.size(), mat.device, rank, in_dim, out_dim) + rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim + if conv2d: - mat = mat.squeeze() + if conv2d_3x3: + mat = mat.flatten(start_dim=1) + else: + mat = mat.squeeze() U, S, Vh = torch.linalg.svd(mat) @@ -102,11 +121,18 @@ def svd(args): Vh = Vh[:rank, :] dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val + # hi_val = torch.quantile(dist, CLAMP_QUANTILE) + # low_val = -hi_val - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) + # U = U.clamp(low_val, hi_val) + # Vh = Vh.clamp(low_val, hi_val) + + if conv2d: + U = U.reshape(out_dim, rank, 1, 1) + Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) + + U = U.to("cuda").contiguous() + Vh = Vh.to("cuda").contiguous() lora_weights[lora_name] = (U, Vh) @@ -124,8 +150,8 @@ def svd(args): weights = lora_weights[lora_name][i] # print(key, i, weights.size(), lora_sd[key].size()) - if len(lora_sd[key].size()) == 4: - weights = weights.unsqueeze(2).unsqueeze(3) + # if len(lora_sd[key].size()) == 4: + # weights = weights.unsqueeze(2).unsqueeze(3) assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" lora_sd[key] = weights @@ -139,7 +165,7 @@ def svd(args): os.makedirs(dir_name, exist_ok=True) # minimum metadata - metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} + metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} lora_network_o.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") @@ -158,6 +184,8 @@ if __name__ == '__main__': parser.add_argument("--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") + parser.add_argument("--conv_dim", type=int, default=None, + help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") args = parser.parse_args() diff --git a/networks/lora.py b/networks/lora.py index 7179baf..c0181c0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -70,7 +70,7 @@ class LoRAModule(torch.nn.Module): if self.region is None: return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - # reginal LoRA + # regional LoRA FIXME same as additional-network extension if x.size()[1] % 77 == 0: # print(f"LoRA for context: {self.lora_name}") self.region = None @@ -107,10 +107,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un network_dim = 4 # default # extract dim/alpha for conv2d, and block dim - conv_dim = int(kwargs.get('conv_dim', network_dim)) - conv_alpha = kwargs.get('conv_alpha', network_alpha) - if conv_alpha is not None: - conv_alpha = float(conv_alpha) + conv_dim = kwargs.get('conv_dim', None) + conv_alpha = kwargs.get('conv_alpha', None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) """ block_dims = kwargs.get("block_dims") @@ -165,7 +169,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa elif 'lora_down' in key: dim = value.size()[0] modules_dim[lora_name] = dim - print(lora_name, value.size(), dim) + # print(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -380,4 +384,4 @@ class LoRANetwork(torch.nn.Module): def set_region(self, region): for lora in self.unet_loras: - lora.set_region(region) \ No newline at end of file + lora.set_region(region) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 09aea7b..09dee4d 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": lora_name = prefix + '.' + name + '.' + child_name lora_name = lora_name.replace('.', '_') name_to_module[lora_name] = child_module @@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # W <- W + U * D weight = module.weight + # print(module_name, down_weight.size(), up_weight.size()) if len(weight.size()) == 2: # linear weight = weight + ratio * (up_weight @ down_weight) * scale - else: - # conv2d + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) ).unsqueeze(2).unsqueeze(3) * scale + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype): alphas[lora_module_name] = alpha if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - + print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge @@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype): merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: merged_sd[key] = lora_sd[key] * scale - + # set alpha to sd for lora_module_name, alpha in base_alphas.items(): key = lora_module_name + ".alpha" diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 7d127ad..271de8e 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -3,11 +3,11 @@ # Thanks to cloneofsimo and kohya import argparse +import os import torch from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm from library import train_util, model_util -import numpy as np def load_state_dict(file_name, dtype): @@ -38,34 +38,12 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) -def index_sv_cumulative(S, target): - original_sum = float(torch.sum(S)) - cumulative_sums = torch.cumsum(S, dim=0)/original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - if index >= len(S): - index = len(S) - 1 - - return index - - -def index_sv_fro(S, target): - S_squared = S.pow(2) - s_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - if index >= len(S): - index = len(S) - 1 - - return index - - -def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): +def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): network_alpha = None network_dim = None verbose_str = "\n" - fro_list = [] - CLAMP_QUANTILE = 1 # 0.99 + CLAMP_QUANTILE = 0.99 # Extract loaded lora dim and alpha for key, value in lora_sd.items(): @@ -79,12 +57,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn network_alpha = network_dim scale = network_alpha/network_dim + new_alpha = float(scale*new_rank) # calculate new alpha from scale - if dynamic_method: - print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}") - else: - new_alpha = float(scale*new_rank) # calculate new alpha from scale - print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}") + print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}") lora_down_weight = None lora_up_weight = None @@ -122,45 +97,11 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn U, S, Vh = torch.linalg.svd(full_weight_matrix) - if dynamic_method=="sv_ratio": - # Calculate new dim and alpha based off ratio - max_sv = S[0] - min_sv = max_sv/dynamic_param - new_rank = torch.sum(S > min_sv).item() - new_rank = max(new_rank, 1) - new_alpha = float(scale*new_rank) - - elif dynamic_method=="sv_cumulative": - # Calculate new dim and alpha based off cumulative sum - new_rank = index_sv_cumulative(S, dynamic_param) - new_rank = max(new_rank, 1) - new_alpha = float(scale*new_rank) - - elif dynamic_method=="sv_fro": - # Calculate new dim and alpha based off sqrt sum of squares - new_rank = index_sv_fro(S, dynamic_param) - new_rank = max(new_rank, 1) - new_alpha = float(scale*new_rank) - if verbose: s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) - - S_squared = S.pow(2) - s_fro = torch.sqrt(torch.sum(S_squared)) - s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) - fro_percent = float(s_red_fro/s_fro) - if not np.isnan(fro_percent): - fro_list.append(float(fro_percent)) - - verbose_str+=f"{block_down_name:75} | " - verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, fro retained: {fro_percent:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}" - - - if verbose and dynamic_method: - verbose_str+=f", dynamic | dim: {new_rank}, alpha: {new_alpha}\n" - else: - verbose_str+=f"\n" + verbose_str+=f"{block_down_name:76} | " + verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" U = U[:, :new_rank] S = S[:new_rank] @@ -195,8 +136,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn if verbose: print(verbose_str) - - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") print("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -212,9 +151,6 @@ def resize(args): return torch.bfloat16 return None - if args.dynamic_method and not args.dynamic_param: - raise Exception("If using dynamic_method, then dynamic_param is required") - merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: @@ -224,22 +160,16 @@ def resize(args): lora_sd, metadata = load_state_dict(args.model, merge_dtype) print("resizing rank...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) # update metadata if metadata is None: metadata = {} comment = metadata.get("ss_training_comment", "") - - if not args.dynamic_method: - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) - else: - metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" - metadata["ss_network_dim"] = 'Dynamic' - metadata["ss_network_alpha"] = 'Dynamic' + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -263,11 +193,6 @@ if __name__ == '__main__': parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する") - parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], - help="Specify dynamic resizing method, will override --new_rank") - parser.add_argument("--dynamic_param", type=float, default=None, - help="Specify target for dynamic reduction") - args = parser.parse_args() - resize(args) \ No newline at end of file + resize(args) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index c0448fc..c8e39b8 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -35,7 +35,8 @@ def save_to_file(file_name, model, state_dict, dtype): torch.save(model, file_name) -def merge_lora_models(models, ratios, new_rank, device, merge_dtype): +def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): + print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} for model, ratio in zip(models, ratios): print(f"loading: {model}") @@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype): in_dim = down_weight.size()[1] out_dim = up_weight.size()[0] conv2d = len(down_weight.size()) == 4 - print(lora_module_name, network_dim, alpha, in_dim, out_dim) + kernel_size = None if not conv2d else down_weight.size()[2:4] + # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) # make original weight if not exist if lora_module_name not in merged_sd: - weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype) + weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) if device: weight = weight.to(device) else: @@ -77,9 +79,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype): scale = (alpha / network_dim) if not conv2d: # linear weight = weight + ratio * (up_weight @ down_weight) * scale - else: + elif kernel_size == (1, 1): weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) ).unsqueeze(2).unsqueeze(3) * scale + else: + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale merged_sd[lora_module_name] = weight @@ -89,16 +94,25 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype): with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): conv2d = (len(mat.size()) == 4) + kernel_size = None if not conv2d else mat.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + out_dim, in_dim = mat.size()[0:2] + if conv2d: - mat = mat.squeeze() + if conv2d_3x3: + mat = mat.flatten(start_dim=1) + else: + mat = mat.squeeze() + + module_new_rank = new_conv_rank if conv2d_3x3 else new_rank U, S, Vh = torch.linalg.svd(mat) - U = U[:, :new_rank] - S = S[:new_rank] + U = U[:, :module_new_rank] + S = S[:module_new_rank] U = U @ torch.diag(S) - Vh = Vh[:new_rank, :] + Vh = Vh[:module_new_rank, :] dist = torch.cat([U.flatten(), Vh.flatten()]) hi_val = torch.quantile(dist, CLAMP_QUANTILE) @@ -107,16 +121,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype): U = U.clamp(low_val, hi_val) Vh = Vh.clamp(low_val, hi_val) + if conv2d: + U = U.reshape(out_dim, module_new_rank, 1, 1) + Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) + up_weight = U down_weight = Vh - if conv2d: - up_weight = up_weight.unsqueeze(2).unsqueeze(3) - down_weight = down_weight.unsqueeze(2).unsqueeze(3) - merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous() merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank) + merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank) return merged_lora_sd @@ -138,7 +152,8 @@ def merge(args): if save_dtype is None: save_dtype = merge_dtype - state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype) + new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank + state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) print(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) @@ -158,6 +173,8 @@ if __name__ == '__main__': help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--new_conv_rank", type=int, default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") args = parser.parse_args() diff --git a/tools/lycoris_locon_extract.py b/tools/lycoris_locon_extract.py new file mode 100644 index 0000000..308dcf1 --- /dev/null +++ b/tools/lycoris_locon_extract.py @@ -0,0 +1,115 @@ +import os, sys +sys.path.insert(0, os.getcwd()) +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "base_model", help="The model which use it to train the dreambooth model", + default='', type=str + ) + parser.add_argument( + "db_model", help="the dreambooth model you want to extract the locon", + default='', type=str + ) + parser.add_argument( + "output_name", help="the output model", + default='./out.pt', type=str + ) + parser.add_argument( + "--is_v2", help="Your base/db model is sd v2 or not", + default=False, action="store_true" + ) + parser.add_argument( + "--device", help="Which device you want to use to extract the locon", + default='cpu', type=str + ) + parser.add_argument( + "--mode", + help=( + 'extraction mode, can be "fixed", "threshold", "ratio", "quantile". ' + 'If not "fixed", network_dim and conv_dim will be ignored' + ), + default='fixed', type=str + ) + parser.add_argument( + "--safetensors", help='use safetensors to save locon model', + default=False, action="store_true" + ) + parser.add_argument( + "--linear_dim", help="network dim for linear layer in fixed mode", + default=1, type=int + ) + parser.add_argument( + "--conv_dim", help="network dim for conv layer in fixed mode", + default=1, type=int + ) + parser.add_argument( + "--linear_threshold", help="singular value threshold for linear layer in threshold mode", + default=0., type=float + ) + parser.add_argument( + "--conv_threshold", help="singular value threshold for conv layer in threshold mode", + default=0., type=float + ) + parser.add_argument( + "--linear_ratio", help="singular ratio for linear layer in ratio mode", + default=0., type=float + ) + parser.add_argument( + "--conv_ratio", help="singular ratio for conv layer in ratio mode", + default=0., type=float + ) + parser.add_argument( + "--linear_quantile", help="singular value quantile for linear layer quantile mode", + default=1., type=float + ) + parser.add_argument( + "--conv_quantile", help="singular value quantile for conv layer quantile mode", + default=1., type=float + ) + return parser.parse_args() +ARGS = get_args() + + +from lycoris.utils import extract_diff +from lycoris.kohya_model_utils import load_models_from_stable_diffusion_checkpoint + +import torch +from safetensors.torch import save_file + + +def main(): + args = ARGS + base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) + db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model) + + linear_mode_param = { + 'fixed': args.linear_dim, + 'threshold': args.linear_threshold, + 'ratio': args.linear_ratio, + 'quantile': args.linear_quantile, + }[args.mode] + conv_mode_param = { + 'fixed': args.conv_dim, + 'threshold': args.conv_threshold, + 'ratio': args.conv_ratio, + 'quantile': args.conv_quantile, + }[args.mode] + + state_dict = extract_diff( + base, db, + args.mode, + linear_mode_param, conv_mode_param, + args.device + ) + + if args.safetensors: + save_file(state_dict, args.output_name) + else: + torch.save(state_dict, args.output_name) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/train_README-ja.md b/train_README-ja.md index bf0d9f9..479f960 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -1,4 +1,8 @@ -当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やスクリプトオプションについて説明します。 +__ドキュメント更新中のため記述に誤りがあるかもしれません。__ + +# 学習について、共通編 + +当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。 # 概要 @@ -8,15 +12,14 @@ 以下について説明します。 1. 学習データの準備について(設定ファイルを用いる新形式) -1. Aspect Ratio Bucketingについて +1. 学習で使われる用語のごく簡単な解説 1. 以前の指定形式(設定ファイルを用いずコマンドラインから指定) +1. 学習途中のサンプル画像生成 +1. 各スクリプトで共通の、よく使われるオプション 1. fine tuning 方式のメタデータ準備:キャプションニングなど 1.だけ実行すればとりあえず学習は可能です(学習については各スクリプトのドキュメントを参照)。2.以降は必要に応じて参照してください。 - # 学習データの準備について @@ -36,7 +39,7 @@ 1. fine tuning方式(正則化画像使用不可) - あらかじめキャプションをメタデータファイルにまとめます。タグとキャプションを分けて管理したり、学習を高速化するためlatentsを事前キャッシュしたりなどの機能をサポートします(いずれも別文書で説明しています)。 + あらかじめキャプションをメタデータファイルにまとめます。タグとキャプションを分けて管理したり、学習を高速化するためlatentsを事前キャッシュしたりなどの機能をサポートします(いずれも別文書で説明しています)。(fine tuning方式という名前ですが fine tuning 以外でも使えます。) 学習したいものと使用できる指定方法の組み合わせは以下の通りです。 @@ -124,7 +127,7 @@ batch_size = 4 # バッチサイズ num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい ``` -基本的には以下を場所のみ書き換えれば学習できます。 +基本的には以下の場所のみ書き換えれば学習できます。 1. 学習解像度 @@ -132,7 +135,7 @@ batch_size = 4 # バッチサイズ 1. バッチサイズ - 同時に何件のデータを学習するかを指定します。GPUのVRAMサイズ、学習解像度によって変わってきます。またfine tuning/DreamBooth/LoRA等でも変わってきますので、詳しくは各スクリプトの説明をご覧ください。 + 同時に何件のデータを学習するかを指定します。GPUのVRAMサイズ、学習解像度によって変わってきます。詳しくは後述します。またfine tuning/DreamBooth/LoRA等でも変わってきますので各スクリプトの説明もご覧ください。 1. フォルダ指定 @@ -248,7 +251,45 @@ batch_size = 4 # バッチサイズ それぞれのドキュメントを参考に学習を行ってください。 -# Aspect Ratio Bucketing について +# 学習で使われる用語のごく簡単な解説 + +細かいことは省略していますし私も完全には理解していないため、詳しくは各自お調べください。 + +## fine tuning(ファインチューニング) + +モデルを学習して微調整することを指します。使われ方によって意味が異なってきますが、狭義のfine tuningはStable Diffusionの場合、モデルを画像とキャプションで学習することです。DreamBoothは狭義のfine tuningのひとつの特殊なやり方と言えます。広義のfine tuningは、LoRAやTextual Inversion、Hypernetworksなどを含み、モデルを学習することすべてを含みます。 + +## ステップ + +ざっくりいうと学習データで1回計算すると1ステップです。「学習データのキャプションを今のモデルに流してみて、出てくる画像を学習データの画像と比較し、学習データに近づくようにモデルをわずかに変更する」のが1ステップです。 + +## バッチサイズ + +バッチサイズは1ステップで何件のデータをまとめて計算するかを指定する値です。まとめて計算するため速度は相対的に向上します。また一般的には精度も高くなるといわれています。 + +`バッチサイズ×ステップ数` が学習に使われるデータの件数になります。そのため、バッチサイズを増やした分だけステップ数を減らすとよいでしょう。 + +(ただし、たとえば「バッチサイズ1で1600ステップ」と「バッチサイズ4で400ステップ」は同じ結果にはなりません。同じ学習率の場合、一般的には後者のほうが学習不足になります。学習率を多少大きくするか(たとえば `2e-6` など)、ステップ数をたとえば500ステップにするなどして工夫してください。) + +バッチサイズを大きくするとその分だけGPUメモリを消費します。メモリが足りなくなるとエラーになりますし、エラーにならないギリギリでは学習速度が低下します。タスクマネージャーや `nvidia-smi` コマンドで使用メモリ量を確認しながら調整するとよいでしょう。 + +なお、バッチは「一塊のデータ」位の意味です。 + +## 学習率 + +ざっくりいうと1ステップごとにどのくらい変化させるかを表します。大きな値を指定するとそれだけ速く学習が進みますが、変化しすぎてモデルが壊れたり、最適な状態にまで至れない場合があります。小さい値を指定すると学習速度は遅くなり、また最適な状態にやはり至れない場合があります。 + +fine tuning、DreamBoooth、LoRAそれぞれで大きく異なり、また学習データや学習させたいモデル、バッチサイズやステップ数によっても変わってきます。一般的な値から初めて学習状態を見ながら増減してください。 + +デフォルトでは学習全体を通して学習率は固定です。スケジューラの指定で学習率をどう変化させるか決められますので、それらによっても結果は変わってきます。 + +## エポック(epoch) + +学習データが一通り学習されると(データが一周すると)1 epochです。繰り返し回数を指定した場合は、その繰り返し後のデータが一周すると1 epochです。 + +1 epochのステップ数は、基本的には `データ件数÷バッチサイズ` ですが、Aspect Ratio Bucketing を使うと微妙に増えます(異なるbucketのデータは同じバッチにできないため、ステップ数が増えます)。 + +## Aspect Ratio Bucketing Stable Diffusion のv1は512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくキャプションと画像の関係が学習されることが期待されます。 @@ -260,11 +301,15 @@ Stable Diffusion のv1は512\*512で学習されていますが、それに加 機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。 -# 以前のデータ指定方法 +# 以前の指定形式(設定ファイルを用いずコマンドラインから指定) -フォルダ名で繰り返し回数を指定する方法です。 +`.toml` ファイルを指定せずコマンドラインオプションで指定する方法です。DreamBooth class+identifier方式、DreamBooth キャプション方式、fine tuning方式があります。 -## step 1. 学習用画像の準備 +## DreamBooth、class+identifier方式 + +フォルダ名で繰り返し回数を指定します。また `train_data_dir` オプションと `reg_data_dir` オプションを用います。 + +### step 1. 学習用画像の準備 学習用画像を格納するフォルダを作成します。 __さらにその中に__ 、以下の名前でディレクトリを作成します。 @@ -294,15 +339,7 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ - reg_girls - 1_1girl -### DreamBoothでキャプションを使う - -学習用画像、正則化画像のフォルダに、画像と同じファイル名で、拡張子.caption(オプションで変えられます)のファイルを置くと、そのファイルからキャプションを読み込みプロンプトとして学習します。 - -※それらの画像の学習に、フォルダ名(identifier class)は使用されなくなります。 - -キャプションファイルの拡張子はデフォルトで.captionです。学習スクリプトの `--caption_extension` オプションで変更できます。`--shuffle_caption` オプションで学習時のキャプションについて、カンマ区切りの各部分をシャッフルしながら学習します。 - -## step 2. 正則化画像の準備 +### step 2. 正則化画像の準備 正則化画像を使う場合の手順です。 @@ -313,16 +350,288 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ ![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) -## step 3. 学習の実行 +### step 3. 学習の実行 各学習スクリプトを実行します。 `--train_data_dir` オプションで前述の学習用データのフォルダを(__画像を含むフォルダではなく、その親フォルダ__)、`--reg_data_dir` オプションで正則化画像のフォルダ(__画像を含むフォルダではなく、その親フォルダ__)を指定してください。 - # メタデータファイルの作成 diff --git a/train_db_README-ja.md b/train_db_README-ja.md index 85ae35a..0d0747b 100644 --- a/train_db_README-ja.md +++ b/train_db_README-ja.md @@ -1,75 +1,104 @@ -DreamBoothのガイドです。LoRA等の追加ネットワークの学習にも同じ手順を使います。 +DreamBoothのガイドです。 + +[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 # 概要 +DreamBoothとは、画像生成モデルに特定の主題を追加学習し、それを特定の識別子で生成する技術です。[論文はこちら](https://arxiv.org/abs/2208.12242)。 + +具体的には、Stable Diffusionのモデルにキャラや画風などを学ばせ、それを `shs` のような特定の単語で呼び出せる(生成画像に出現させる)ことができます。 + +スクリプトは[DiffusersのDreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)を元にしていますが、以下のような機能追加を行っています(いくつかの機能は元のスクリプト側もその後対応しています)。 + スクリプトの主な機能は以下の通りです。 -- 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化(ShivamShrirao氏版と同様)。 +- 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化([Shivam Shrirao氏版](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)と同様)。 - xformersによる省メモリ化。 - 512x512だけではなく任意サイズでの学習。 - augmentationによる品質の向上。 - DreamBoothだけではなくText Encoder+U-Netのfine tuningに対応。 -- StableDiffusion形式でのモデルの読み書き。 +- Stable Diffusion形式でのモデルの読み書き。 - Aspect Ratio Bucketing。 - Stable Diffusion v2.0対応。 # 学習の手順 -## step 1. 環境整備 +あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 -このリポジトリのREADMEを参照してください。 +## データの準備 +[学習データの準備について](./train_README-ja.md) を参照してください。 -## step 2. identifierとclassを決める +## 学習の実行 -学ばせたい対象を結びつける単語identifierと、対象の属するclassを決めます。 - -(instanceなどいろいろな呼び方がありますが、とりあえず元の論文に合わせます。) - -以下ごく簡単に説明します(詳しくは調べてください)。 - -classは学習対象の一般的な種別です。たとえば特定の犬種を学ばせる場合には、classはdogになります。アニメキャラならモデルによりboyやgirl、1boyや1girlになるでしょう。 - -identifierは学習対象を識別して学習するためのものです。任意の単語で構いませんが、元論文によると「tokinizerで1トークンになる3文字以下でレアな単語」が良いとのことです。 - -identifierとclassを使い、たとえば「shs dog」などでモデルを学習することで、学習させたい対象をclassから識別して学習できます。 - -画像生成時には「shs dog」とすれば学ばせた犬種の画像が生成されます。 - -(identifierとして私が最近使っているものを参考までに挙げると、``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny`` などです。) - -## step 3. 学習用画像の準備 -学習用画像を格納するフォルダを作成します。 __さらにその中に__ 、以下の名前でディレクトリを作成します。 +スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。それぞれの行を必要に応じて書き換えてください。12GB程度のVRAMで動作するようです。 ``` -<繰り返し回数>_ +accelerate launch --num_cpu_threads_per_process 1 train_db.py + --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> + --dataset_config=<データ準備で作成した.tomlファイル> + --output_dir=<学習したモデルの出力先フォルダ> + --output_name=<学習したモデル出力時のファイル名> + --save_model_as=safetensors + --prior_loss_weight=1.0 + --max_train_steps=1600 + --learning_rate=1e-6 + --optimizer_type="AdamW8bit" + --xformers + --mixed_precision="fp16" + --cache_latents + --gradient_checkpointing ``` -間の``_``を忘れないでください。 +`num_cpu_threads_per_process` には通常は1を指定するとよいようです。 -繰り返し回数は、正則化画像と枚数を合わせるために指定します(後述します)。 +`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 -たとえば「sls frog」というプロンプトで、データを20回繰り返す場合、「20_sls frog」となります。以下のようになります。 +`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 -![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png) +`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 -## step 4. 正則化画像の準備 -正則化画像を使う場合の手順です。使わずに学習することもできます(正則化画像を使わないと区別ができなくなるので対象class全体が影響を受けます)。 +`prior_loss_weight` は正則化画像のlossの重みです。通常は1.0を指定します。 -正則化画像を格納するフォルダを作成します。 __さらにその中に__ ``<繰り返し回数>_`` という名前でディレクトリを作成します。 +学習させるステップ数 `max_train_steps` を1600とします。学習率 `learning_rate` はここでは1e-6を指定しています。 -たとえば「frog」というプロンプトで、データを繰り返さない(1回だけ)場合、以下のようになります。 +省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 -![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) +オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 -繰り返し回数は「 __学習用画像の繰り返し回数×学習用画像の枚数≧正則化画像の繰り返し回数×正則化画像の枚数__ 」となるように指定してください。 +`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 -(1 epochのデータ数が「学習用画像の繰り返し回数×学習用画像の枚数」となります。正則化画像の枚数がそれより多いと、余った部分の正則化画像は使用されません。) +省メモリ化のため `cache_latents` オプションを指定してVAEの出力をキャッシュします。 -## step 5. 学習の実行 -スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。 +ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。また `cache_latents` を外すことで augmentation が可能になります。 -※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。 +### よく使われるオプションについて + +以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。 + +- Stable Diffusion 2.xまたはそこからの派生モデルを学習する +- clip skipを2以上を前提としたモデルを学習する +- 75トークンを超えたキャプションで学習する + +### DreamBoothでのステップ数について + +当スクリプトでは省メモリ化のため、ステップ当たりの学習回数が元のスクリプトの半分になっています(対象の画像と正則化画像を同一のバッチではなく別のバッチに分割して学習するため)。 + +元のDiffusers版やXavierXiao氏のStable Diffusion版とほぼ同じ学習を行うには、ステップ数を倍にしてください。 + +(学習画像と正則化画像をまとめてから shuffle するため厳密にはデータの順番が変わってしまいますが、学習には大きな影響はないと思います。) + +### DreamBoothでのバッチサイズについて + +モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(fine tuningと同じ)。 + +### 学習率について + +Diffusers版では5e-6ですがStable Diffusion版は1e-6ですので、上のサンプルでは1e-6を指定しています。 + +### 以前の形式のデータセット指定をした場合のコマンドライン + +解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。 ``` accelerate launch --num_cpu_threads_per_process 1 train_db.py @@ -77,6 +106,7 @@ accelerate launch --num_cpu_threads_per_process 1 train_db.py --train_data_dir=<学習用データのディレクトリ> --reg_data_dir=<正則化画像のディレクトリ> --output_dir=<学習したモデルの出力先ディレクトリ> + --output_name=<学習したモデル出力時のファイル名> --prior_loss_weight=1.0 --resolution=512 --train_batch_size=1 @@ -89,43 +119,33 @@ accelerate launch --num_cpu_threads_per_process 1 train_db.py --gradient_checkpointing ``` -num_cpu_threads_per_processには通常は1を指定するとよいようです。 +## 学習したモデルで画像生成する -pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになります(save_model_asオプションで変更できます)。 +学習が終わると指定したフォルダに指定した名前でsafetensorsファイルが出力されます。 -prior_loss_weightは正則化画像のlossの重みです。通常は1.0を指定します。 +v1.4/1.5およびその他の派生モデルの場合、このモデルでAutomatic1111氏のWebUIなどで推論できます。models\Stable-diffusionフォルダに置いてください。 -resolutionは画像のサイズ(解像度、幅と高さ)になります。bucketing(後述)を用いない場合、学習用画像、正則化画像はこのサイズとしてください。 +v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述された.yamlファイルが別途必要になります。v2.x baseの場合はv2-inference.yamlを、768/vの場合はv2-inference-v.yamlを、同じフォルダに置き、拡張子の前の部分をモデルと同じ名前にしてください。 -train_batch_sizeは学習時のバッチサイズです。max_train_stepsを1600とします。学習率learning_rateは、diffusers版では5e-6ですがStableDiffusion版は1e-6ですのでここでは1e-6を指定しています。 +![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) -省メモリ化のためmixed_precision="bf16"(または"fp16")、およびgradient_checkpointing を指定します。 +各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。 -xformersオプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合、エラーとなる場合(mixed_precisionなしの場合、私の環境ではエラーとなりました)、代わりにmem_eff_attnオプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 +# DreamBooth特有のその他の主なオプション -省メモリ化のためcache_latentsオプションを指定してVAEの出力をキャッシュします。 +すべてのオプションについては別文書を参照してください。 -ある程度メモリがある場合はたとえば以下のように指定します。 +## Text Encoderの学習を途中から行わない --stop_text_encoder_training -``` -accelerate launch --num_cpu_threads_per_process 8 train_db.py - --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> - --train_data_dir=<学習用データのディレクトリ> - --reg_data_dir=<正則化画像のディレクトリ> - --output_dir=<学習したモデルの出力先ディレクトリ> - --prior_loss_weight=1.0 - --resolution=512 - --train_batch_size=4 - --learning_rate=1e-6 - --max_train_steps=400 - --use_8bit_adam - --xformers - --mixed_precision="bf16" - --cache_latents -``` +stop_text_encoder_trainingオプションに数値を指定すると、そのステップ数以降はText Encoderの学習を行わずU-Netだけ学習します。場合によっては精度の向上が期待できるかもしれません。 -gradient_checkpointingを外し高速化します(メモリ使用量は増えます)。バッチサイズを増やし、高速化と精度向上を図ります。 +(恐らくText Encoderだけ先に過学習することがあり、それを防げるのではないかと推測していますが、詳細な影響は不明です。) +## Tokenizerのパディングをしない --no_token_padding +no_token_paddingオプションを指定するとTokenizerの出力をpaddingしません(Diffusers版の旧DreamBoothと同じ動きになります)。 + + + diff --git a/train_network.py b/train_network.py index 4d5ffd3..cf64c89 100644 --- a/train_network.py +++ b/train_network.py @@ -427,10 +427,13 @@ def train(args): "ss_bucket_info": json.dumps(dataset.bucket_info), }) + # add extra args if args.network_args: - for key, value in net_kwargs.items(): - metadata["ss_arg_" + key] = value + metadata["ss_network_args"] = json.dumps(net_kwargs) + # for key, value in net_kwargs.items(): + # metadata["ss_arg_" + key] = value + # model name and hash if args.pretrained_model_name_or_path is not None: sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): @@ -449,6 +452,13 @@ def train(args): metadata = {k: str(v) for k, v in metadata.items()} + # make minimum metadata for filtering + minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] + minimum_metadata = {} + for key in minimum_keys: + if key in metadata: + minimum_metadata[key] = metadata[key] + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -564,7 +574,7 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) metadata["ss_training_finished_at"] = str(time.time()) print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) + unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as @@ -603,7 +613,7 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) + network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) print("model saved.") @@ -639,4 +649,4 @@ if __name__ == '__main__': help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() - train(args) \ No newline at end of file + train(args) diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 4a507ff..4a79a6f 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -1,118 +1,99 @@ -## LoRAの学習について +# LoRAの学習について [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)をStable Diffusionに適用したものです。 [cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を大いに参考にさせていただきました。ありがとうございます。 +通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。 + +Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。 + 8GB VRAMでもぎりぎり動作するようです。 +[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 + ## 学習したモデルに関する注意 cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。 WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。 -## 学習方法 +# 学習の手順 -train_network.pyを用います。 +あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 -DreamBoothの手法(identifier(sksなど)とclass、オプションで正則化画像を用いる)と、キャプションを用いるfine tuningの手法の両方で学習できます。 +## データの準備 -どちらの方法も既存のスクリプトとほぼ同じ方法で学習できます。異なる点については後述します。 +[学習データの準備について](./train_README-ja.md) を参照してください。 -### DreamBoothの手法を用いる場合 -[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。 +## 学習の実行 -学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。 +`train_network.py`を用います。 -ほぼすべてのオプション(Stable Diffusionのモデル保存関係を除く)が使えますが、stop_text_encoder_trainingはサポートしていません。 - -### キャプションを用いる場合 - -[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。 - -学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプション(モデル保存関係を除く)がそのまま使えます。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。 - -なお「latentsの事前取得」は行わなくても動作します。VAEから学習時(またはキャッシュ時)にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。 - -### LoRAの学習のためのオプション - -train_network.pyでは--network_moduleオプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。 +`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。 なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。 -以下はコマンドラインの例です(DreamBooth手法)。 +以下はコマンドラインの例です。 ``` accelerate launch --num_cpu_threads_per_process 1 train_network.py - --pretrained_model_name_or_path=..\models\model.ckpt - --train_data_dir=..\data\db\char1 --output_dir=..\lora_train1 - --reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0 - --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4 - --max_train_steps=400 --optimizer_type=AdamW8bit --xformers --mixed_precision=fp16 - --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug + --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> + --dataset_config=<データ準備で作成した.tomlファイル> + --output_dir=<学習したモデルの出力先フォルダ> + --output_name=<学習したモデル出力時のファイル名> + --save_model_as=safetensors + --prior_loss_weight=1.0 + --max_train_steps=400 + --learning_rate=1e-4 + --optimizer_type="AdamW8bit" + --xformers + --mixed_precision="fp16" + --cache_latents + --gradient_checkpointing + --save_every_n_epochs=1 --network_module=networks.lora ``` -(2023/2/22:オプティマイザの指定方法が変わりました。[こちら](#オプティマイザの指定について)をご覧ください。) - ---output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。 +`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。 その他、以下のオプションが指定できます。 -* --network_dim +* `--network_dim` * LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。 -* --network_alpha +* `--network_alpha` * アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。 -* --network_weights +* `--network_weights` * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。 -* --network_train_unet_only +* `--network_train_unet_only` * U-Netに関連するLoRAモジュールのみ有効とします。fine tuning的な学習で指定するとよいかもしれません。 -* --network_train_text_encoder_only +* `--network_train_text_encoder_only` * Text Encoderに関連するLoRAモジュールのみ有効とします。Textual Inversion的な効果が期待できるかもしれません。 -* --unet_lr +* `--unet_lr` * U-Netに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。 -* --text_encoder_lr +* `--text_encoder_lr` * Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。 +* `--network_args` + * 複数の引数を指定できます。後述します。 ---network_train_unet_onlyと--network_train_text_encoder_onlyの両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 +`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 -## オプティマイザの指定について +## LoRA を Conv2d に拡大して適用する ---optimizer_type オプションでオプティマイザの種類を指定します。以下が指定できます。 +通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。 -- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) - - 過去のバージョンのオプション未指定時と同じ -- AdamW8bit : 引数は同上 - - 過去のバージョンの--use_8bit_adam指定時と同じ -- Lion : https://github.com/lucidrains/lion-pytorch - - 過去のバージョンの--use_lion_optimizer指定時と同じ -- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True -- SGDNesterov8bit : 引数は同上 -- DAdaptation : https://github.com/facebookresearch/dadaptation -- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) -- 任意のオプティマイザ +`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。 -オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。 +``` +--network_args "conv_dim=1" "conv_alpha=1" +``` -オプション引数を指定する場合は、それぞれのオプティマイザの仕様をご確認ください。 +以下のように alpha 省略時は1になります。 -一部のオプティマイザでは必須の引数があり、省略すると自動的に追加されます(SGDNesterovのmomentumなど)。コンソールの出力を確認してください。 - -D-Adaptationオプティマイザは学習率を自動調整します。学習率のオプションに指定した値は学習率そのものではなくD-Adaptationが決定した学習率の適用率になりますので、通常は1.0を指定してください。Text EncoderにU-Netの半分の学習率を指定したい場合は、``--text_encoder_lr=0.5 --unet_lr=1.0``と指定します。 - -AdaFactorオプティマイザはrelative_step=Trueを指定すると学習率を自動調整できます(省略時はデフォルトで追加されます)。自動調整する場合は学習率のスケジューラにはadafactor_schedulerが強制的に使用されます。またscale_parameterとwarmup_initを指定するとよいようです。 - -自動調整する場合のオプション指定はたとえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになります。 - -学習率を自動調整しない場合はオプション引数 ``relative_step=False`` を追加してください。その場合、学習率のスケジューラにはconstant_with_warmupが、また勾配のclip normをしないことが推奨されているようです。そのため引数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになります。 - -### 任意のオプティマイザを使う - -``torch.optim`` のオプティマイザを使う場合にはクラス名のみを(``--optimizer_type=RMSprop``など)、他のモジュールのオプティマイザを使う時は「モジュール名.クラス名」を指定してください(``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など)。 - -(内部でimportlibしているだけで動作は未確認です。必要ならパッケージをインストールしてください。) +``` +--network_args "conv_dim=1" +``` ## マージスクリプトについて @@ -176,6 +157,27 @@ v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha`` * save_precision * モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。 + +## 複数のrankが異なるLoRAのモデルをマージする + +複数のLoRAをひとつのLoRAで近似します(完全な再現はできません)。`svd_merge_lora.py`を用います。たとえば以下のようなコマンドラインになります。 + +``` +python networks\svd_merge_lora.py + --save_to ..\lora_train1\model-char1-style1-merged.safetensors + --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors + --ratios 0.6 0.4 --new_rank 32 --device cuda +``` + +`merge_lora.py` と主なオプションは同一です。以下のオプションが追加されています。 + +- `--new_rank` + - 作成するLoRAのrankを指定します。 +- `--new_conv_rank` + - 作成する Conv2d 3x3 LoRA の rank を指定します。省略時は `new_rank` と同じになります。 +- `--device` + - `--device cuda`としてcudaを指定すると計算をGPU上で行います。処理が速くなります。 + ## 当リポジトリ内の画像生成スクリプトで生成する gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。 @@ -209,12 +211,14 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA ### その他のオプション -- --v2 +- `--v2` - v2.xのStable Diffusionモデルを使う場合に指定してください。 -- --device +- `--device` - ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。 -- --save_precision +- `--save_precision` - LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。 +- `--conv_dim` + - 指定するとLoRAの適用範囲を Conv2d 3x3 へ拡大します。Conv2d 3x3 の rank を指定します。 ## 画像リサイズスクリプト @@ -252,7 +256,7 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256 ### cloneofsimo氏のリポジトリとの違い -12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。 +2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。 またモジュール入れ替え機構は全く異なります。 diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 810e850..34b7f09 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -172,8 +172,6 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - prompt_replacement = None - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: print("use template for training captions. is object: {args.use_object_template}") @@ -183,6 +181,11 @@ def train(args): for tmpl in templates: captions.append(tmpl.format(replace_to)) train_dataset_group.add_replacement("", captions) + + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None else: if args.num_vectors_per_token > 1: replace_to = " ".join(token_strings) diff --git a/train_ti_README-ja.md b/train_ti_README-ja.md index 90989ec..9087369 100644 --- a/train_ti_README-ja.md +++ b/train_ti_README-ja.md @@ -1,32 +1,41 @@ -## Textual Inversionの学習について +[Textual Inversion](https://textual-inversion.github.io/) の学習についての説明です。 -[Textual Inversion](https://textual-inversion.github.io/)です。実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。 +[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 -学習したモデルはWeb UIでもそのまま使えます。 +実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。 -なお恐らくSD2.xにも対応していますが現時点では未テストです。 +学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。 -## 学習方法 +# 学習の手順 -``train_textual_inversion.py`` を用います。 +あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 -データの準備については ``train_network.py`` と全く同じですので、[そちらのドキュメント](./train_network_README-ja.md)を参照してください。 +## データの準備 -## オプション +[学習データの準備について](./train_README-ja.md) を参照してください。 -以下はコマンドラインの例です(DreamBooth手法)。 +## 学習の実行 + +``train_textual_inversion.py`` を用います。以下はコマンドラインの例です(DreamBooth手法)。 ``` accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py - --pretrained_model_name_or_path=..\models\model.ckpt - --train_data_dir=..\data\db\char1 --output_dir=..\ti_train1 - --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4 - --max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16 - --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug + --dataset_config=<データ準備で作成した.tomlファイル> + --output_dir=<学習したモデルの出力先フォルダ> + --output_name=<学習したモデル出力時のファイル名> + --save_model_as=safetensors + --prior_loss_weight=1.0 + --max_train_steps=1600 + --learning_rate=1e-6 + --optimizer_type="AdamW8bit" + --xformers + --mixed_precision="fp16" + --cache_latents + --gradient_checkpointing --token_string=mychar4 --init_word=cute --num_vectors_per_token=4 ``` -``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。 +``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。DreamBooth, class+identifier形式のデータセットとして、`token_string` をトークン文字列にするのが最も簡単で確実です。 プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。 @@ -47,14 +56,47 @@ tokenizerがすでに持っている単語(一般的な単語)は使用で ``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。 +以上がTextual Inversionのための主なオプションです。以降は他の学習スクリプトと同様です。 -その他、以下のオプションが指定できます。 +`num_cpu_threads_per_process` には通常は1を指定するとよいようです。 -* --weights +`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 + +`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 + +`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 + +学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。 + +省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 + +オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 + +`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 + +ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `8` くらいに増やしてください(高速化と精度向上の可能性があります)。 + +### よく使われるオプションについて + +以下の場合にはオプションに関するドキュメントを参照してください。 + +- Stable Diffusion 2.xまたはそこからの派生モデルを学習する +- clip skipを2以上を前提としたモデルを学習する +- 75トークンを超えたキャプションで学習する + +### Textual Inversionでのバッチサイズについて + +モデル全体を学習するDreamBoothやfine tuningに比べてメモリ使用量が少ないため、バッチサイズは大きめにできます。 + +# Textual Inversionのその他の主なオプション + +すべてのオプションについては別文書を参照してください。 + +* `--weights` * 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。 -* --use_object_template +* `--use_object_template` * キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。 -* --use_style_template +* `--use_style_template` * キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。 ## 当リポジトリ内の画像生成スクリプトで生成する From fc5d2b2c31c7170c91e82e66b2b6b3024fce3918 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Fri, 10 Mar 2023 11:44:52 -0500 Subject: [PATCH 2/4] Update to sd-script dev code base --- README.md | 3 + library/resize_lora_gui.py | 6 +- library/svd_merge_lora_gui.py | 187 +++++++++++++++++++++ library/train_util.py | 35 +++- lora_gui.py | 2 + networks/extract_lora_from_models.py | 35 ++-- networks/lora.py | 49 +++--- networks/resize_lora.py | 237 +++++++++++++++++++++------ networks/svd_merge_lora.py | 21 +-- train_README-ja.md | 8 + train_network.py | 41 +++-- train_network_README-ja.md | 4 + 12 files changed, 499 insertions(+), 129 deletions(-) create mode 100644 library/svd_merge_lora_gui.py diff --git a/README.md b/README.md index a468a4e..40a7b4a 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/10 (v21.2.1): + - Update to latest sd-script code + - Add support for SVD based LoRA merge * 2023/03/09 (v21.2.0): - Fix issue https://github.com/bmaltais/kohya_ss/issues/335 - Add option to print LoRA trainer command without executing it diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index a94fdb7..6b4396b 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -148,11 +148,11 @@ def gradio_resize_lora_tab(): value='fp16', interactive=True, ) - device = gr.Textbox( + device = gr.Dropdown( label='Device', - placeholder='{Optional) device to use, cuda for GPU. Default: cuda', - interactive=True, + choices=['cpu', 'cuda',], value='cuda', + interactive=True, ) convert_button = gr.Button('Resize model') diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py new file mode 100644 index 0000000..b34b503 --- /dev/null +++ b/library/svd_merge_lora_gui.py @@ -0,0 +1,187 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def svd_merge_lora( + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, + new_rank, + new_conv_rank, + device, +): + # Check for caption_text_input + if lora_a_model == '': + msgbox('Invalid model A file') + return + + if lora_b_model == '': + msgbox('Invalid model B file') + return + + # Check if source model exist + if not os.path.isfile(lora_a_model): + msgbox('The provided model A is not a file') + return + + if not os.path.isfile(lora_b_model): + msgbox('The provided model B is not a file') + return + + ratio_a = ratio + ratio_b = 1 - ratio + + run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"' + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --precision {precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"' + run_cmd += f' --ratios {ratio_a} {ratio_b}' + run_cmd += f' --device {device}' + run_cmd += f' --new_rank "{new_rank}"' + run_cmd += f' --new_conv_rank "{new_conv_rank}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_svd_merge_lora_tab(): + with gr.Tab('Merge LoRA (SVD)'): + gr.Markdown('This utility can merge two LoRA networks together.') + + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + lora_a_model = gr.Textbox( + label='LoRA model "A"', + placeholder='Path to the LoRA A model', + interactive=True, + ) + button_lora_a_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_a_model_file.click( + get_file_path, + inputs=[lora_a_model, lora_ext, lora_ext_name], + outputs=lora_a_model, + show_progress=False, + ) + + lora_b_model = gr.Textbox( + label='LoRA model "B"', + placeholder='Path to the LoRA B model', + interactive=True, + ) + button_lora_b_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_b_model_file.click( + get_file_path, + inputs=[lora_b_model, lora_ext, lora_ext_name], + outputs=lora_b_model, + show_progress=False, + ) + with gr.Row(): + ratio = gr.Slider( + label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B', + minimum=0, + maximum=1, + step=0.01, + value=0.5, + interactive=True, + ) + new_rank = gr.Slider( + label='New Rank', + minimum=1, + maximum=1024, + step=1, + value=128, + interactive=True, + ) + new_conv_rank = gr.Slider( + label='New Conv Rank', + minimum=1, + maximum=1024, + step=1, + value=128, + interactive=True, + ) + + with gr.Row(): + save_to = gr.Textbox( + label='Save to', + placeholder='path for the file to save...', + interactive=True, + ) + button_save_to = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_save_to.click( + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, + show_progress=False, + ) + precision = gr.Dropdown( + label='Merge precision', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + save_precision = gr.Dropdown( + label='Save precision', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + device = gr.Dropdown( + label='Device', + choices=['cpu', 'cuda',], + value='cuda', + interactive=True, + ) + + convert_button = gr.Button('Merge model') + + convert_button.click( + svd_merge_lora, + inputs=[ + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, + new_rank, + new_conv_rank, + device, + ], + show_progress=False, + ) diff --git a/library/train_util.py b/library/train_util.py index 6af1abe..718fe36 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -912,10 +912,14 @@ class FineTuningDataset(BaseDataset): if os.path.exists(image_key): abs_path = image_key else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + else: + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" + abs_path = abs_path[0] caption = img_md.get('caption') tags = img_md.get('tags') @@ -1757,15 +1761,22 @@ def get_optimizer(args, trainable_params): raise ImportError("No dadaptation / dadaptation がインストールされていないようです") print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - min_lr = lr + actual_lr = lr + lr_count = 1 if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) for group in trainable_params: - min_lr = min(min_lr, group.get("lr", lr)) + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) - if min_lr <= 0.1: + if actual_lr <= 0.1: print( - f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}') + f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}') print('recommend option: lr=1.0 / 推奨は1.0です') + if lr_count > 1: + print( + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}") optimizer_class = dadaptation.DAdaptAdam optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -2296,6 +2307,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v with torch.no_grad(): with accelerator.autocast(): for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue prompt = prompt.strip() if len(prompt) == 0 or prompt[0] == '#': continue @@ -2355,6 +2368,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v height = max(64, height - height % 8) # round to divisible by 8 width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) diff --git a/lora_gui.py b/lora_gui.py index d175b30..49918de 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -38,6 +38,7 @@ from library.tensorboard_gui import ( from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab from library.merge_lora_gui import gradio_merge_lora_tab +from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab from library.sampler_gui import sample_gradio_config, run_cmd_sample @@ -879,6 +880,7 @@ def lora_tab( ) gradio_dataset_balancing_tab() gradio_merge_lora_tab() + gradio_svd_merge_lora_tab() gradio_resize_lora_tab() gradio_verify_lora_tab() diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 9f40978..28b905f 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -103,7 +103,8 @@ def svd(args): if args.device: mat = mat.to(args.device) - # print(mat.size(), mat.device, rank, in_dim, out_dim) + + # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -112,7 +113,7 @@ def svd(args): else: mat = mat.squeeze() - U, S, Vh = torch.linalg.svd(mat) + U, S, Vh = torch.linalg.svd(mat.to("cuda")) U = U[:, :rank] S = S[:rank] @@ -137,27 +138,17 @@ def svd(args): lora_weights[lora_name] = (U, Vh) # make state dict for LoRA - lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict - lora_sd = lora_network_o.state_dict() - print(f"LoRA has {len(lora_sd)} weights.") - - for key in list(lora_sd.keys()): - if "alpha" in key: - continue - - lora_name = key.split('.')[0] - i = 0 if "lora_up" in key else 1 - - weights = lora_weights[lora_name][i] - # print(key, i, weights.size(), lora_sd[key].size()) - # if len(lora_sd[key].size()) == 4: - # weights = weights.unsqueeze(2).unsqueeze(3) - - assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" - lora_sd[key] = weights + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + '.lora_up.weight'] = up_weight + lora_sd[lora_name + '.lora_down.weight'] = down_weight + lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) # load state dict to LoRA and save it - info = lora_network_o.load_state_dict(lora_sd) + lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + + info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(args.save_to) @@ -167,7 +158,7 @@ def svd(args): # minimum metadata metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} - lora_network_o.save_weights(args.save_to, save_dtype, metadata) + lora_network_save.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") diff --git a/networks/lora.py b/networks/lora.py index c0181c0..6d3875d 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module): """ if alpha == 0 or None, alpha is rank (no scaling). """ super().__init__() self.lora_name = lora_name - self.lora_dim = lora_dim if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features - self.lora_dim = min(self.lora_dim, in_dim, out_dim) - if self.lora_dim != lora_dim: - print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + if org_module.__class__.__name__ == 'Conv2d': kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: - in_dim = org_module.in_features - out_dim = org_module.out_features - self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) - self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = lora_dim if alpha is None or alpha == 0 else alpha + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える @@ -149,12 +153,13 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un return network -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location='cpu') +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file, safe_open + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location='cpu') # get dim/alpha mapping modules_dim = {} @@ -174,7 +179,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa # support old LoRA without alpha for key in modules_dim.keys(): if key not in modules_alpha: - modules_alpha = modules_dim[key] + modules_alpha = modules_dim[key] network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) network.weights_sd = weights_sd @@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa class LoRANetwork(torch.nn.Module): # is it possible to apply conv_in and conv_out? - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' @@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module): text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") self.weights_sd = None @@ -371,7 +382,7 @@ class LoRANetwork(torch.nn.Module): else: torch.save(state_dict, file) - @staticmethod + @ staticmethod def set_regions(networks, image): image = image.astype(np.float32) / 255.0 for i, network in enumerate(networks[:3]): diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 271de8e..09a19c1 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -1,14 +1,15 @@ # Convert LoRA to different rank approximation (should only be used to go to lower rank) # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py -# Thanks to cloneofsimo and kohya +# Thanks to cloneofsimo import argparse -import os import torch from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm from library import train_util, model_util +import numpy as np +MIN_SV = 1e-6 def load_state_dict(file_name, dtype): if model_util.is_safetensors(file_name): @@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) -def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0)/original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + if index >= len(S): + index = len(S) - 1 + + return index + + +def index_sv_fro(S, target): + S_squared = S.pow(2) + s_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + if index >= len(S): + index = len(S) - 1 + + return index + + +# Modified from Kohaku-blueleaf's extract/merge functions +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size, kernel_size, _ = weight.size() + U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() + del U, S, Vh, weight + return param_dict + + +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size = weight.size() + + U, S, Vh = torch.linalg.svd(weight.to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() + del U, S, Vh, weight + return param_dict + + +def merge_conv(lora_down, lora_up, device): + in_rank, in_size, kernel_size, k_ = lora_down.shape + out_size, out_rank, _, _ = lora_up.shape + assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) + weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) + del lora_up, lora_down + return weight + + +def merge_linear(lora_down, lora_up, device): + in_rank, in_size = lora_down.shape + out_size, out_rank = lora_up.shape + assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + weight = lora_up @ lora_down + del lora_up, lora_down + return weight + + +def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): + param_dict = {} + + if dynamic_method=="sv_ratio": + # Calculate new dim and alpha based off ratio + max_sv = S[0] + min_sv = max_sv/dynamic_param + new_rank = max(torch.sum(S > min_sv).item(),1) + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_cumulative": + # Calculate new dim and alpha based off cumulative sum + new_rank = index_sv_cumulative(S, dynamic_param) + new_rank = max(new_rank, 1) + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_fro": + # Calculate new dim and alpha based off sqrt sum of squares + new_rank = index_sv_fro(S, dynamic_param) + new_rank = min(max(new_rank, 1), len(S)-1) + new_alpha = float(scale*new_rank) + else: + new_rank = rank + new_alpha = float(scale*new_rank) + + + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + new_rank = 1 + new_alpha = float(scale*new_rank) + elif new_rank > rank: # cap max rank at rank + new_rank = rank + new_alpha = float(scale*new_rank) + + + # Calculate resize info + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) + fro_percent = float(s_red_fro/s_fro) + + param_dict["new_rank"] = new_rank + param_dict["new_alpha"] = new_alpha + param_dict["sum_retained"] = (s_rank)/s_sum + param_dict["fro_retained"] = fro_percent + param_dict["max_ratio"] = S[0]/S[new_rank] + + return param_dict + + +def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): network_alpha = None network_dim = None verbose_str = "\n" - - CLAMP_QUANTILE = 0.99 + fro_list = [] # Extract loaded lora dim and alpha for key, value in lora_sd.items(): @@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): network_alpha = network_dim scale = network_alpha/network_dim - new_alpha = float(scale*new_rank) # calculate new alpha from scale - print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}") + if dynamic_method: + print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") lora_down_weight = None lora_up_weight = None @@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): block_down_name = None block_up_name = None - print("resizing lora...") with torch.no_grad(): for key, value in tqdm(lora_sd.items()): if 'lora_down' in key: @@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): conv2d = (len(lora_down_weight.size()) == 4) if conv2d: - lora_down_weight = lora_down_weight.squeeze() - lora_up_weight = lora_up_weight.squeeze() - - if device: - org_device = lora_up_weight.device - lora_up_weight = lora_up_weight.to(args.device) - lora_down_weight = lora_down_weight.to(args.device) - - full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) - - U, S, Vh = torch.linalg.svd(full_weight_matrix) + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) if verbose: - s_sum = torch.sum(torch.abs(S)) - s_rank = torch.sum(torch.abs(S[:new_rank])) - verbose_str+=f"{block_down_name:76} | " - verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" + max_ratio = param_dict['max_ratio'] + sum_retained = param_dict['sum_retained'] + fro_retained = param_dict['fro_retained'] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) - U = U[:, :new_rank] - S = S[:new_rank] - U = U @ torch.diag(S) + verbose_str+=f"{block_down_name:75} | " + verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" - Vh = Vh[:new_rank, :] + if verbose and dynamic_method: + verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str+=f"\n" - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val - - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.unsqueeze(2).unsqueeze(3) - Vh = Vh.unsqueeze(2).unsqueeze(3) - - if device: - U = U.to(org_device) - Vh = Vh.to(org_device) - - o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype) + new_alpha = param_dict['new_alpha'] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) block_down_name = None block_up_name = None lora_down_weight = None lora_up_weight = None weights_loaded = False + del param_dict if verbose: print(verbose_str) + + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") print("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -151,6 +274,9 @@ def resize(args): return torch.bfloat16 return None + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: @@ -159,17 +285,23 @@ def resize(args): print("loading Model...") lora_sd, metadata = load_state_dict(args.model, merge_dtype) - print("resizing rank...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) + print("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) # update metadata if metadata is None: metadata = {} comment = metadata.get("ss_training_comment", "") - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) + + if not args.dynamic_method: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + metadata["ss_network_dim"] = 'Dynamic' + metadata["ss_network_alpha"] = 'Dynamic' model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -193,6 +325,11 @@ if __name__ == '__main__': parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する") + parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") + parser.add_argument("--dynamic_param", type=float, default=None, + help="Specify target for dynamic reduction") + args = parser.parse_args() resize(args) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index c8e39b8..d907b43 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype): return sd -def save_to_file(file_name, model, state_dict, dtype): +def save_to_file(file_name, state_dict, dtype): if dtype is not None: for key in list(state_dict.keys()): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) + save_file(state_dict, file_name) else: - torch.save(model, file_name) + torch.save(state_dict, file_name) def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): @@ -76,7 +76,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty down_weight = down_weight.to(device) # W <- W + U * D - scale = (alpha / network_dim) + scale = (alpha / network_dim).to(device) if not conv2d: # linear weight = weight + ratio * (up_weight @ down_weight) * scale elif kernel_size == (1, 1): @@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty mat = mat.squeeze() module_new_rank = new_conv_rank if conv2d_3x3 else new_rank + module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim U, S, Vh = torch.linalg.svd(mat) @@ -114,12 +115,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty Vh = Vh[:module_new_rank, :] - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val + # dist = torch.cat([U.flatten(), Vh.flatten()]) + # hi_val = torch.quantile(dist, CLAMP_QUANTILE) + # low_val = -hi_val - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) + # U = U.clamp(low_val, hi_val) + # Vh = Vh.clamp(low_val, hi_val) if conv2d: U = U.reshape(out_dim, module_new_rank, 1, 1) @@ -156,7 +157,7 @@ def merge(args): state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) + save_to_file(args.save_to, state_dict, save_dtype) if __name__ == '__main__': diff --git a/train_README-ja.md b/train_README-ja.md index 479f960..d5f1b5f 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -502,6 +502,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。 +- `--persistent_data_loader_workers` + + Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。 + +- `--max_data_loader_n_workers` + + データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。 + - `--logging_dir` / `--log_prefix` 学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 diff --git a/train_network.py b/train_network.py index cf64c89..5aa8af4 100644 --- a/train_network.py +++ b/train_network.py @@ -106,6 +106,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -134,6 +135,8 @@ def train(args): gc.collect() # prepare network + import sys + sys.path.append(os.path.dirname(__file__)) print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) @@ -175,12 +178,13 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) + if is_main_process: + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps, num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする @@ -251,15 +255,17 @@ def train(args): # 学習する # TODO: find a way to handle total batch size when there are multiple datasets total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + if is_main_process: + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") # TODO refactor metadata creation and move to util metadata = { @@ -471,7 +477,8 @@ def train(args): loss_list = [] loss_total = 0.0 for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") + if is_main_process: + print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) metadata["ss_epoch"] = str(epoch+1) @@ -583,9 +590,10 @@ def train(args): print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + if is_main_process: + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -594,7 +602,6 @@ def train(args): metadata["ss_epoch"] = str(num_train_epochs) metadata["ss_training_finished_at"] = str(time.time()) - is_main_process = accelerator.is_main_process if is_main_process: network = unwrap_model(network) diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 4a79a6f..79d1709 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -64,6 +64,10 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py * LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。 * `--network_alpha` * アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。 +* `--persistent_data_loader_workers` + * Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。 +* `--max_data_loader_n_workers` + * データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。 * `--network_weights` * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。 * `--network_train_unet_only` From d1962d72400237b42d69312a989e59ba9b5e7508 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Fri, 10 Mar 2023 11:49:05 -0500 Subject: [PATCH 3/4] Switch to networks version of resize lora --- library/resize_lora_gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index 6b4396b..527ff67 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -48,7 +48,7 @@ def resize_lora( if device == '': device = 'cuda' - run_cmd = f'{PYTHON} "{os.path.join("tools","resize_lora.py")}"' + run_cmd = f'{PYTHON} "{os.path.join("networks","resize_lora.py")}"' run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_to {save_to}' run_cmd += f' --model {model}' From a65555ea67c8e1519977cb91bfd9ba648350ee51 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Fri, 10 Mar 2023 20:05:38 -0500 Subject: [PATCH 4/4] Add support to load a config without opening the UI to get the file name --- dreambooth_gui.py | 19 ++++++++++++++++--- finetune_gui.py | 31 ++++++++++++++++++++++--------- library/common_gui.py | 3 +++ lora_gui.py | 19 ++++++++++++++++--- textual_inversion_gui.py | 19 ++++++++++++++++--- 5 files changed, 73 insertions(+), 18 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index df40784..dee017c 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -152,6 +152,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -213,9 +214,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -231,7 +236,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -506,6 +511,7 @@ def dreambooth_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -775,7 +781,14 @@ def dreambooth_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) diff --git a/finetune_gui.py b/finetune_gui.py index 3ef1cbd..59dffd8 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -149,7 +149,8 @@ def save_configuration( return file_path -def open_config_file( +def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -217,9 +218,13 @@ def open_config_file( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -235,7 +240,7 @@ def open_config_file( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -492,8 +497,8 @@ def remove_doublequote(file_path): def finetune_tab(): - dummy_ft_true = gr.Label(value=True, visible=False) - dummy_ft_false = gr.Label(value=False, visible=False) + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) gr.Markdown('Train a custom model using kohya finetune python code...') ( @@ -501,6 +506,7 @@ def finetune_tab(): button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -770,22 +776,29 @@ def finetune_tab(): button_run.click(train_model, inputs=settings_list) button_open_config.click( - open_config_file, - inputs=[config_file_name] + settings_list, + open_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_save_config.click( save_configuration, - inputs=[dummy_ft_false, config_file_name] + settings_list, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name], show_progress=False, ) button_save_as_config.click( save_configuration, - inputs=[dummy_ft_true, config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name], show_progress=False, ) diff --git a/library/common_gui.py b/library/common_gui.py index b22594f..e200141 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -405,11 +405,14 @@ def gradio_config(): placeholder="type the configuration file path or use the 'Open' button above to select it...", interactive=True, ) + button_load_config = gr.Button('Load 💾', elem_id='open_folder') + config_file_name.change(remove_doublequote, inputs=[config_file_name], outputs=[config_file_name]) return ( button_open_config, button_save_config, button_save_as_config, config_file_name, + button_load_config, ) diff --git a/lora_gui.py b/lora_gui.py index 49918de..23da712 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -168,6 +168,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -239,9 +240,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -257,7 +262,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) # This next section is about making the LoCon parameters visible if LoRA_type = 'Standard' @@ -610,6 +615,7 @@ def lora_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -974,7 +980,14 @@ def lora_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list + [LoCon_row], + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, ) diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index c92bdc0..ed3c33a 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -158,6 +158,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -225,9 +226,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -243,7 +248,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -548,6 +553,7 @@ def ti_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -865,7 +871,14 @@ def ti_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, )