diff --git a/README.md b/README.md index c64f3b0..93b517e 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i ## Change history +* 2023/01/06 (v19.4): + - Add new Utility to Extract a LoRA from a finetuned model * 2023/01/06 (v19.3.1): - Emergency fix for dreambooth_ui no longer working, sorry - Add LoRA network merge too GUI. Run `pip install -U -r requirements.txt` after pulling this new release. diff --git a/README_train_network-ja.md b/README_train_network-ja.md index bba4293..1ad1b7a 100644 --- a/README_train_network-ja.md +++ b/README_train_network-ja.md @@ -10,9 +10,7 @@ cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。 -WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルに、このリポジトリ内のスクリプトであらかじめマージしておく必要があります。マージ後のモデルファイルはLoRAの学習結果が反映されたものになります。 - -なお当リポジトリ内の画像生成スクリプトで生成する場合はマージ不要です。 +WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。 ## 学習方法 @@ -24,7 +22,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正 ### DreamBoothの手法を用いる場合 -note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594) を参照してデータを用意してください。 +[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。 学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。 @@ -110,7 +108,7 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt ### 複数のLoRAのモデルをマージする -結局のところSDモデルにマージしないと推論できないのであまり使い道はないかもしれません。ただ、複数のLoRAモデルをひとつずつSDモデルにマージしていく場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。 +複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。 たとえば以下のようなコマンドラインになります。 @@ -144,6 +142,40 @@ gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim --network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。 +## 二つのモデルの差分からLoRAモデルを作成する + +[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。 + +二つのモデル(たとえばfine tuningの元モデルとfine tuning後のモデル)の差分を、LoRAで近似します。 + +### スクリプトの実行方法 + +以下のように指定してください。 +``` +python networks\extract_lora_from_models.py --model_org base-model.ckpt + --model_tuned fine-tuned-model.ckpt + --save_to lora-weights.safetensors --dim 4 +``` + +--model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。 + +--model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。 + +--save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。 + +生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。 + +Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。 + +### その他のオプション + +- --v2 + - v2.xのStable Diffusionモデルを使う場合に指定してください。 +- --device + - ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。 +- --save_precision + - LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。 + ## 追加情報 ### cloneofsimo氏のリポジトリとの違い @@ -154,4 +186,4 @@ gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim ### 将来拡張について -LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。 +LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。 \ No newline at end of file diff --git a/dreambooth_gui.py b/dreambooth_gui.py index fdff172..e2166a7 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -22,7 +22,6 @@ from library.common_gui import ( from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab from easygui import msgbox @@ -398,13 +397,13 @@ def train_model( if flip_aug: run_cmd += ' --flip_aug' run_cmd += ( - f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' ) run_cmd += f' --train_data_dir="{train_data_dir}"' if len(reg_data_dir): run_cmd += f' --reg_data_dir="{reg_data_dir}"' run_cmd += f' --resolution={max_resolution}' - run_cmd += f' --output_dir={output_dir}' + run_cmd += f' --output_dir="{output_dir}"' run_cmd += f' --train_batch_size={train_batch_size}' run_cmd += f' --learning_rate={learning_rate}' run_cmd += f' --lr_scheduler={lr_scheduler}' @@ -416,7 +415,7 @@ def train_model( run_cmd += f' --save_every_n_epochs={save_every_n_epochs}' run_cmd += f' --seed={seed}' run_cmd += f' --save_precision={save_precision}' - run_cmd += f' --logging_dir={logging_dir}' + run_cmd += f' --logging_dir="{logging_dir}"' if not caption_extension == '': run_cmd += f' --caption_extension={caption_extension}' if not stop_text_encoder_training == 0: @@ -817,7 +816,6 @@ def dreambooth_tab( output_dir_input=output_dir_input, logging_dir_input=logging_dir_input, ) - gradio_dataset_balancing_tab() button_run = gr.Button('Train model') diff --git a/finetune_gui.py b/finetune_gui.py index 5bb3eeb..9c8e8a5 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -276,8 +276,8 @@ def train_model( run_cmd += f' --caption_extension=".txt"' else: run_cmd += f' --caption_extension={caption_extension}' - run_cmd += f' {image_folder}' - run_cmd += f' {train_dir}/{caption_metadata_filename}' + run_cmd += f' "{image_folder}"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' if full_path: run_cmd += f' --full_path' @@ -291,10 +291,10 @@ def train_model( run_cmd = ( f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py' ) - run_cmd += f' {image_folder}' - run_cmd += f' {train_dir}/{caption_metadata_filename}' - run_cmd += f' {train_dir}/{latent_metadata_filename}' - run_cmd += f' {pretrained_model_name_or_path}' + run_cmd += f' "{image_folder}"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + run_cmd += f' "{train_dir}/{latent_metadata_filename}"' + run_cmd += f' "{pretrained_model_name_or_path}"' run_cmd += f' --batch_size={batch_size}' run_cmd += f' --max_resolution={max_resolution}' run_cmd += f' --min_bucket_reso={min_bucket_reso}' @@ -344,13 +344,13 @@ def train_model( if xformers: run_cmd += f' --xformers' run_cmd += ( - f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' ) - run_cmd += f' --in_json={train_dir}/{latent_metadata_filename}' - run_cmd += f' --train_data_dir={image_folder}' - run_cmd += f' --output_dir={output_dir}' + run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"' + run_cmd += f' --train_data_dir="{image_folder}"' + run_cmd += f' --output_dir="{output_dir}"' if not logging_dir == '': - run_cmd += f' --logging_dir={logging_dir}' + run_cmd += f' --logging_dir="{logging_dir}"' run_cmd += f' --train_batch_size={train_batch_size}' run_cmd += f' --dataset_repeats={dataset_repeats}' run_cmd += f' --learning_rate={learning_rate}' diff --git a/kohya_gui.py b/kohya_gui.py index 8bb1aa6..5224be2 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -4,6 +4,8 @@ import argparse from dreambooth_gui import dreambooth_tab from finetune_gui import finetune_tab from library.utilities import utilities_tab +from library.extract_lora_gui import gradio_extract_lora_tab +from library.merge_lora_gui import gradio_merge_lora_tab from lora_gui import lora_tab @@ -38,6 +40,8 @@ def UI(username, password): logging_dir_input=logging_dir_input, enable_copy_info_button=True, ) + gradio_extract_lora_tab() + gradio_merge_lora_tab() # Show the interface if not username == '': diff --git a/library/common_gui.py b/library/common_gui.py index c30c0d3..f54267b 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -3,17 +3,22 @@ import os import gradio as gr from easygui import msgbox +def get_dir_and_file(file_path): + dir_path, file_name = os.path.split(file_path) + return (dir_path, file_name) -def get_file_path(file_path='', defaultextension='.json'): +def get_file_path(file_path='', defaultextension='.json', extension_name='Config files'): current_file_path = file_path # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() file_path = filedialog.askopenfilename( - filetypes=(('Config files', '*.json'), ('All files', '*')), - defaultextension=defaultextension, + filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')), + defaultextension=defaultextension, initialfile=initial_file, initialdir=initial_dir ) root.destroy() @@ -25,11 +30,14 @@ def get_file_path(file_path='', defaultextension='.json'): def get_any_file_path(file_path=''): current_file_path = file_path # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() - file_path = filedialog.askopenfilename() + file_path = filedialog.askopenfilename(initialdir=initial_dir, + initialfile=initial_file,) root.destroy() if file_path == '': @@ -47,11 +55,13 @@ def remove_doublequote(file_path): def get_folder_path(folder_path=''): current_folder_path = folder_path + + initial_dir, initial_file = get_dir_and_file(folder_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() - folder_path = filedialog.askdirectory() + folder_path = filedialog.askdirectory(initialdir=initial_dir) root.destroy() if folder_path == '': @@ -60,16 +70,20 @@ def get_folder_path(folder_path=''): return folder_path -def get_saveasfile_path(file_path='', defaultextension='.json'): +def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='Config files'): current_file_path = file_path # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() save_file_path = filedialog.asksaveasfile( - filetypes=(('Config files', '*.json'), ('All files', '*')), + filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')), defaultextension=defaultextension, + initialdir=initial_dir, + initialfile=initial_file, ) root.destroy() @@ -85,6 +99,30 @@ def get_saveasfile_path(file_path='', defaultextension='.json'): return file_path +def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config files'): + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + initial_dir, initial_file = get_dir_and_file(file_path) + + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + save_file_path = filedialog.asksaveasfilename(filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')), + defaultextension=extensions, + initialdir=initial_dir, + initialfile=initial_file, + ) + root.destroy() + + if save_file_path == '': + file_path = current_file_path + else: + # print(save_file_path) + file_path = save_file_path + + return file_path + def add_pre_postfix( folder='', prefix='', postfix='', caption_file_ext='.caption' diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py new file mode 100644 index 0000000..b058c0f --- /dev/null +++ b/library/extract_lora_gui.py @@ -0,0 +1,127 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 + + +def extract_lora( + model_tuned, model_org, save_to, save_precision, dim, v2, +): + # Check for caption_text_input + if model_tuned == '': + msgbox('Invalid finetuned model file') + return + + if model_org == '': + msgbox('Invalid base model file') + return + + # Check if source model exist + if not os.path.isfile(model_tuned): + msgbox('The provided finetuned model is not a file') + return + + if not os.path.isfile(model_org): + msgbox('The provided base model is not a file') + return + + run_cmd = f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"' + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --model_org "{model_org}"' + run_cmd += f' --model_tuned "{model_tuned}"' + run_cmd += f' --dim {dim}' + if v2: + run_cmd += f' --v2' + + print(run_cmd) + + # Run the command + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_extract_lora_tab(): + with gr.Tab('Extract LoRA'): + gr.Markdown( + 'This utility can extract a LoRA network from a finetuned model.' + ) + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False) + model_ext_name = gr.Textbox(value='Model types', visible=False) + + with gr.Row(): + model_tuned = gr.Textbox( + label='Finetuned model', + placeholder='Path to the finetuned model to extract', + interactive=True, + ) + button_model_tuned_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_model_tuned_file.click( + get_file_path, + inputs=[model_tuned, model_ext, model_ext_name], + outputs=model_tuned, + ) + + model_org = gr.Textbox( + label='Stable Diffusion base model', + placeholder='Stable Diffusion original model: ckpt or safetensors file', + interactive=True, + ) + button_model_org_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_model_org_file.click( + get_file_path, + inputs=[model_org, model_ext, model_ext_name], + outputs=model_org, + ) + with gr.Row(): + save_to = gr.Textbox( + label='Save to', + placeholder='path where to save the extracted LoRA model...', + interactive=True, + ) + button_save_to = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_save_to.click( + get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to + ) + save_precision = gr.Dropdown( + label='Save precison', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + with gr.Row(): + dim = gr.Slider( + minimum=1, + maximum=128, + label='Network Dimension', + value=8, + step=1, + interactive=True, + ) + v2 = gr.Checkbox(label='v2', value=False, interactive=True) + + extract_button = gr.Button('Extract LoRA model') + + extract_button.click( + extract_lora, + inputs=[model_tuned, model_org, save_to, save_precision, dim, v2 + ], + ) diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index 381d411..0271963 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -2,7 +2,7 @@ import gradio as gr from easygui import msgbox import subprocess import os -from .common_gui import get_folder_path, get_any_file_path +from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -55,29 +55,11 @@ def merge_lora( def gradio_merge_lora_tab(): with gr.Tab('Merge LoRA'): gr.Markdown( - 'This utility can merge LoRA networks.' + 'This utility can merge two LoRA networks together.' ) - # with gr.Row(): - # sd_model = gr.Textbox( - # label='Stable Diffusion model', - # placeholder='(Optional) only select if mergind a LoRA into a ckpt or tensorflow model', - # interactive=True, - # ) - # button_sd_model_dir = gr.Button( - # folder_symbol, elem_id='open_folder_small' - # ) - # button_sd_model_dir.click( - # get_folder_path, outputs=sd_model - # ) - - # button_sd_model_file = gr.Button( - # document_symbol, elem_id='open_folder_small' - # ) - # button_sd_model_file.click( - # get_any_file_path, - # inputs=[sd_model], - # outputs=sd_model, - # ) + + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) with gr.Row(): lora_a_model = gr.Textbox( @@ -86,11 +68,11 @@ def gradio_merge_lora_tab(): interactive=True, ) button_lora_a_model_file = gr.Button( - document_symbol, elem_id='open_folder_small' + folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - get_any_file_path, - inputs=[lora_a_model], + get_file_path, + inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, ) @@ -100,11 +82,11 @@ def gradio_merge_lora_tab(): interactive=True, ) button_lora_b_model_file = gr.Button( - document_symbol, elem_id='open_folder_small' + folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - get_any_file_path, - inputs=[lora_b_model], + get_file_path, + inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, ) with gr.Row(): @@ -121,7 +103,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_any_file_path, inputs=save_to, outputs=save_to + get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to ) precision = gr.Dropdown( label='Merge precison', diff --git a/lora_gui.py b/lora_gui.py index 2e0d865..ea4e133 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -426,13 +426,13 @@ def train_model( if flip_aug: run_cmd += ' --flip_aug' run_cmd += ( - f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' ) run_cmd += f' --train_data_dir="{train_data_dir}"' if len(reg_data_dir): run_cmd += f' --reg_data_dir="{reg_data_dir}"' run_cmd += f' --resolution={max_resolution}' - run_cmd += f' --output_dir={output_dir}' + run_cmd += f' --output_dir="{output_dir}"' run_cmd += f' --train_batch_size={train_batch_size}' # run_cmd += f' --learning_rate={learning_rate}' run_cmd += f' --lr_scheduler={lr_scheduler}' @@ -444,7 +444,7 @@ def train_model( run_cmd += f' --save_every_n_epochs={save_every_n_epochs}' run_cmd += f' --seed={seed}' run_cmd += f' --save_precision={save_precision}' - run_cmd += f' --logging_dir={logging_dir}' + run_cmd += f' --logging_dir="{logging_dir}"' if not caption_extension == '': run_cmd += f' --caption_extension={caption_extension}' if not stop_text_encoder_training == 0: @@ -454,7 +454,7 @@ def train_model( if not save_model_as == 'same as source model': run_cmd += f' --save_model_as={save_model_as}' if not resume == '': - run_cmd += f' --resume={resume}' + run_cmd += f' --resume="{resume}"' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --network_module=networks.lora' @@ -472,7 +472,7 @@ def train_model( # run_cmd += f' --network_train_unet_only' run_cmd += f' --network_dim={network_dim}' if not lora_network_weights == '': - run_cmd += f' --network_weights={lora_network_weights}' + run_cmd += f' --network_weights="{lora_network_weights}"' if int(clip_skip) > 1: run_cmd += f' --clip_skip={str(clip_skip)}' @@ -756,33 +756,23 @@ def lora_tab( 'linear', 'polynomial', ], - value='constant', + value='cosine', ) - lr_warmup_input = gr.Textbox(label='LR warmup', value=0) + lr_warmup_input = gr.Textbox(label='LR warmup (% of steps)', value=10) with gr.Row(): text_encoder_lr = gr.Textbox( label='Text Encoder learning rate', - value=1e-6, + value="5e-5", placeholder='Optional', ) unet_lr = gr.Textbox( - label='Unet learning rate', value=1e-4, placeholder='Optional' + label='Unet learning rate', value="1e-3", placeholder='Optional' ) - # network_train = gr.Dropdown( - # label='Network to train', - # choices=[ - # 'Text encoder and Unet', - # 'Text encoder only', - # 'Unet only', - # ], - # value='Text encoder and Unet', - # interactive=True - # ) network_dim = gr.Slider( minimum=1, maximum=128, label='Network Dimension', - value=4, + value=8, step=1, interactive=True, ) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py new file mode 100644 index 0000000..ae586f1 --- /dev/null +++ b/networks/extract_lora_from_models.py @@ -0,0 +1,158 @@ +# extract approximating LoRA by svd from two SD models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import library.model_util as model_util +import lora + + +CLAMP_QUANTILE = 0.99 +MIN_DIFF = 1e-6 + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def svd(args): + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + save_dtype = str_to_dtype(args.save_precision) + + print(f"loading SD model : {args.model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + print(f"loading SD model : {args.model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + + # create LoRA network to extract weights + lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o) + lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t) + assert len(lora_network_o.text_encoder_loras) == len( + lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " + + # get diffs + diffs = {} + text_encoder_different = False + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + + # Text Encoder might be same + if torch.max(torch.abs(diff)) > MIN_DIFF: + text_encoder_different = True + + diff = diff.float() + diffs[lora_name] = diff + + if not text_encoder_different: + print("Text encoder is same. Extract U-Net only.") + lora_network_o.text_encoder_loras = [] + diffs = {} + + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + diff = diff.float() + + if args.device: + diff = diff.to(args.device) + + diffs[lora_name] = diff + + # make LoRA with svd + print("calculating by svd") + rank = args.dim + lora_weights = {} + with torch.no_grad(): + for lora_name, mat in tqdm(list(diffs.items())): + conv2d = (len(mat.size()) == 4) + if conv2d: + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + lora_weights[lora_name] = (U, Vh) + + # make state dict for LoRA + lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict + lora_sd = lora_network_o.state_dict() + print(f"LoRA has {len(lora_sd)} weights.") + + for key in list(lora_sd.keys()): + lora_name = key.split('.')[0] + i = 0 if "lora_up" in key else 1 + + weights = lora_weights[lora_name][i] + # print(key, i, weights.size(), lora_sd[key].size()) + if len(lora_sd[key].size()) == 4: + weights = weights.unsqueeze(2).unsqueeze(3) + + assert weights.size() == lora_sd[key].size() + lora_sd[key] = weights + + # load state dict to LoRA and save it + info = lora_network_o.load_state_dict(lora_sd) + print(f"Loading extracted LoRA weights: {info}") + + dir_name = os.path.dirname(args.save_to) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + lora_network_o.save_weights(args.save_to, save_dtype) + print(f"LoRA weights are saved to: {args.save_to}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat") + parser.add_argument("--model_org", type=str, default=None, + help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors") + parser.add_argument("--model_tuned", type=str, default=None, + help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)") + parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う") + + args = parser.parse_args() + svd(args) \ No newline at end of file