Add new Utility to Extract a LoRA from a finetuned model

This commit is contained in:
bmaltais 2023-01-06 18:25:55 -05:00
parent c20a10d7fd
commit 34f7cd8e57
10 changed files with 410 additions and 79 deletions

View File

@ -30,6 +30,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
## Change history
* 2023/01/06 (v19.4):
- Add new Utility to Extract a LoRA from a finetuned model
* 2023/01/06 (v19.3.1):
- Emergency fix for dreambooth_ui no longer working, sorry
- Add LoRA network merge too GUI. Run `pip install -U -r requirements.txt` after pulling this new release.

View File

@ -10,9 +10,7 @@
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルに、このリポジトリ内のスクリプトであらかじめマージしておく必要があります。マージ後のモデルファイルはLoRAの学習結果が反映されたものになります。
なお当リポジトリ内の画像生成スクリプトで生成する場合はマージ不要です。
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
## 学習方法
@ -24,7 +22,7 @@ DreamBoothの手法identifiersksなどとclass、オプションで正
### DreamBoothの手法を用いる場合
note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594) を参照してデータを用意してください。
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。
@ -110,7 +108,7 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt
### 複数のLoRAのモデルをマージする
結局のところSDモデルにマージしないと推論できないのであまり使い道はないかもしれません。ただ、複数のLoRAモデルをひとつずつSDモデルにマージしていく場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
たとえば以下のようなコマンドラインになります。
@ -144,6 +142,40 @@ gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
## 二つのモデルの差分からLoRAモデルを作成する
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
二つのモデルたとえばfine tuningの元モデルとfine tuning後のモデルの差分を、LoRAで近似します。
### スクリプトの実行方法
以下のように指定してください。
```
python networks\extract_lora_from_models.py --model_org base-model.ckpt
--model_tuned fine-tuned-model.ckpt
--save_to lora-weights.safetensors --dim 4
```
--model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。
--model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。
--save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。
生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。
Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。
### その他のオプション
- --v2
- v2.xのStable Diffusionモデルを使う場合に指定してください。
- --device
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなりますCPUでもそこまで遅くないため、せいぜい倍数倍程度のようです
- --save_precision
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
## 追加情報
### cloneofsimo氏のリポジトリとの違い
@ -154,4 +186,4 @@ gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim
### 将来拡張について
LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。
LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。

View File

@ -22,7 +22,6 @@ from library.common_gui import (
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from easygui import msgbox
@ -398,13 +397,13 @@ def train_model(
if flip_aug:
run_cmd += ' --flip_aug'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
)
run_cmd += f' --train_data_dir="{train_data_dir}"'
if len(reg_data_dir):
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
run_cmd += f' --resolution={max_resolution}'
run_cmd += f' --output_dir={output_dir}'
run_cmd += f' --output_dir="{output_dir}"'
run_cmd += f' --train_batch_size={train_batch_size}'
run_cmd += f' --learning_rate={learning_rate}'
run_cmd += f' --lr_scheduler={lr_scheduler}'
@ -416,7 +415,7 @@ def train_model(
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
run_cmd += f' --seed={seed}'
run_cmd += f' --save_precision={save_precision}'
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --logging_dir="{logging_dir}"'
if not caption_extension == '':
run_cmd += f' --caption_extension={caption_extension}'
if not stop_text_encoder_training == 0:
@ -817,7 +816,6 @@ def dreambooth_tab(
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
)
gradio_dataset_balancing_tab()
button_run = gr.Button('Train model')

View File

@ -276,8 +276,8 @@ def train_model(
run_cmd += f' --caption_extension=".txt"'
else:
run_cmd += f' --caption_extension={caption_extension}'
run_cmd += f' {image_folder}'
run_cmd += f' {train_dir}/{caption_metadata_filename}'
run_cmd += f' "{image_folder}"'
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
if full_path:
run_cmd += f' --full_path'
@ -291,10 +291,10 @@ def train_model(
run_cmd = (
f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
)
run_cmd += f' {image_folder}'
run_cmd += f' {train_dir}/{caption_metadata_filename}'
run_cmd += f' {train_dir}/{latent_metadata_filename}'
run_cmd += f' {pretrained_model_name_or_path}'
run_cmd += f' "{image_folder}"'
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
run_cmd += f' "{pretrained_model_name_or_path}"'
run_cmd += f' --batch_size={batch_size}'
run_cmd += f' --max_resolution={max_resolution}'
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
@ -344,13 +344,13 @@ def train_model(
if xformers:
run_cmd += f' --xformers'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
)
run_cmd += f' --in_json={train_dir}/{latent_metadata_filename}'
run_cmd += f' --train_data_dir={image_folder}'
run_cmd += f' --output_dir={output_dir}'
run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
run_cmd += f' --train_data_dir="{image_folder}"'
run_cmd += f' --output_dir="{output_dir}"'
if not logging_dir == '':
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --logging_dir="{logging_dir}"'
run_cmd += f' --train_batch_size={train_batch_size}'
run_cmd += f' --dataset_repeats={dataset_repeats}'
run_cmd += f' --learning_rate={learning_rate}'

View File

@ -4,6 +4,8 @@ import argparse
from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from library.utilities import utilities_tab
from library.extract_lora_gui import gradio_extract_lora_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from lora_gui import lora_tab
@ -38,6 +40,8 @@ def UI(username, password):
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
gradio_extract_lora_tab()
gradio_merge_lora_tab()
# Show the interface
if not username == '':

View File

@ -3,17 +3,22 @@ import os
import gradio as gr
from easygui import msgbox
def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name)
def get_file_path(file_path='', defaultextension='.json'):
def get_file_path(file_path='', defaultextension='.json', extension_name='Config files'):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.askopenfilename(
filetypes=(('Config files', '*.json'), ('All files', '*')),
defaultextension=defaultextension,
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
defaultextension=defaultextension, initialfile=initial_file, initialdir=initial_dir
)
root.destroy()
@ -25,11 +30,14 @@ def get_file_path(file_path='', defaultextension='.json'):
def get_any_file_path(file_path=''):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.askopenfilename()
file_path = filedialog.askopenfilename(initialdir=initial_dir,
initialfile=initial_file,)
root.destroy()
if file_path == '':
@ -47,11 +55,13 @@ def remove_doublequote(file_path):
def get_folder_path(folder_path=''):
current_folder_path = folder_path
initial_dir, initial_file = get_dir_and_file(folder_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
folder_path = filedialog.askdirectory()
folder_path = filedialog.askdirectory(initialdir=initial_dir)
root.destroy()
if folder_path == '':
@ -60,16 +70,20 @@ def get_folder_path(folder_path=''):
return folder_path
def get_saveasfile_path(file_path='', defaultextension='.json'):
def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='Config files'):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
save_file_path = filedialog.asksaveasfile(
filetypes=(('Config files', '*.json'), ('All files', '*')),
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
defaultextension=defaultextension,
initialdir=initial_dir,
initialfile=initial_file,
)
root.destroy()
@ -85,6 +99,30 @@ def get_saveasfile_path(file_path='', defaultextension='.json'):
return file_path
def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config files'):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
save_file_path = filedialog.asksaveasfilename(filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
defaultextension=extensions,
initialdir=initial_dir,
initialfile=initial_file,
)
root.destroy()
if save_file_path == '':
file_path = current_file_path
else:
# print(save_file_path)
file_path = save_file_path
return file_path
def add_pre_postfix(
folder='', prefix='', postfix='', caption_file_ext='.caption'

127
library/extract_lora_gui.py Normal file
View File

@ -0,0 +1,127 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def extract_lora(
model_tuned, model_org, save_to, save_precision, dim, v2,
):
# Check for caption_text_input
if model_tuned == '':
msgbox('Invalid finetuned model file')
return
if model_org == '':
msgbox('Invalid base model file')
return
# Check if source model exist
if not os.path.isfile(model_tuned):
msgbox('The provided finetuned model is not a file')
return
if not os.path.isfile(model_org):
msgbox('The provided base model is not a file')
return
run_cmd = f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"'
run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --save_to "{save_to}"'
run_cmd += f' --model_org "{model_org}"'
run_cmd += f' --model_tuned "{model_tuned}"'
run_cmd += f' --dim {dim}'
if v2:
run_cmd += f' --v2'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
###
# Gradio UI
###
def gradio_extract_lora_tab():
with gr.Tab('Extract LoRA'):
gr.Markdown(
'This utility can extract a LoRA network from a finetuned model.'
)
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
model_ext_name = gr.Textbox(value='Model types', visible=False)
with gr.Row():
model_tuned = gr.Textbox(
label='Finetuned model',
placeholder='Path to the finetuned model to extract',
interactive=True,
)
button_model_tuned_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_model_tuned_file.click(
get_file_path,
inputs=[model_tuned, model_ext, model_ext_name],
outputs=model_tuned,
)
model_org = gr.Textbox(
label='Stable Diffusion base model',
placeholder='Stable Diffusion original model: ckpt or safetensors file',
interactive=True,
)
button_model_org_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_model_org_file.click(
get_file_path,
inputs=[model_org, model_ext, model_ext_name],
outputs=model_org,
)
with gr.Row():
save_to = gr.Textbox(
label='Save to',
placeholder='path where to save the extracted LoRA model...',
interactive=True,
)
button_save_to = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
)
save_precision = gr.Dropdown(
label='Save precison',
choices=['fp16', 'bf16', 'float'],
value='float',
interactive=True,
)
with gr.Row():
dim = gr.Slider(
minimum=1,
maximum=128,
label='Network Dimension',
value=8,
step=1,
interactive=True,
)
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
extract_button = gr.Button('Extract LoRA model')
extract_button.click(
extract_lora,
inputs=[model_tuned, model_org, save_to, save_precision, dim, v2
],
)

View File

@ -2,7 +2,7 @@ import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_folder_path, get_any_file_path
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@ -55,29 +55,11 @@ def merge_lora(
def gradio_merge_lora_tab():
with gr.Tab('Merge LoRA'):
gr.Markdown(
'This utility can merge LoRA networks.'
'This utility can merge two LoRA networks together.'
)
# with gr.Row():
# sd_model = gr.Textbox(
# label='Stable Diffusion model',
# placeholder='(Optional) only select if mergind a LoRA into a ckpt or tensorflow model',
# interactive=True,
# )
# button_sd_model_dir = gr.Button(
# folder_symbol, elem_id='open_folder_small'
# )
# button_sd_model_dir.click(
# get_folder_path, outputs=sd_model
# )
# button_sd_model_file = gr.Button(
# document_symbol, elem_id='open_folder_small'
# )
# button_sd_model_file.click(
# get_any_file_path,
# inputs=[sd_model],
# outputs=sd_model,
# )
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Row():
lora_a_model = gr.Textbox(
@ -86,11 +68,11 @@ def gradio_merge_lora_tab():
interactive=True,
)
button_lora_a_model_file = gr.Button(
document_symbol, elem_id='open_folder_small'
folder_symbol, elem_id='open_folder_small'
)
button_lora_a_model_file.click(
get_any_file_path,
inputs=[lora_a_model],
get_file_path,
inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model,
)
@ -100,11 +82,11 @@ def gradio_merge_lora_tab():
interactive=True,
)
button_lora_b_model_file = gr.Button(
document_symbol, elem_id='open_folder_small'
folder_symbol, elem_id='open_folder_small'
)
button_lora_b_model_file.click(
get_any_file_path,
inputs=[lora_b_model],
get_file_path,
inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model,
)
with gr.Row():
@ -121,7 +103,7 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_any_file_path, inputs=save_to, outputs=save_to
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
)
precision = gr.Dropdown(
label='Merge precison',

View File

@ -426,13 +426,13 @@ def train_model(
if flip_aug:
run_cmd += ' --flip_aug'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
)
run_cmd += f' --train_data_dir="{train_data_dir}"'
if len(reg_data_dir):
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
run_cmd += f' --resolution={max_resolution}'
run_cmd += f' --output_dir={output_dir}'
run_cmd += f' --output_dir="{output_dir}"'
run_cmd += f' --train_batch_size={train_batch_size}'
# run_cmd += f' --learning_rate={learning_rate}'
run_cmd += f' --lr_scheduler={lr_scheduler}'
@ -444,7 +444,7 @@ def train_model(
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
run_cmd += f' --seed={seed}'
run_cmd += f' --save_precision={save_precision}'
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --logging_dir="{logging_dir}"'
if not caption_extension == '':
run_cmd += f' --caption_extension={caption_extension}'
if not stop_text_encoder_training == 0:
@ -454,7 +454,7 @@ def train_model(
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
if not resume == '':
run_cmd += f' --resume={resume}'
run_cmd += f' --resume="{resume}"'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora'
@ -472,7 +472,7 @@ def train_model(
# run_cmd += f' --network_train_unet_only'
run_cmd += f' --network_dim={network_dim}'
if not lora_network_weights == '':
run_cmd += f' --network_weights={lora_network_weights}'
run_cmd += f' --network_weights="{lora_network_weights}"'
if int(clip_skip) > 1:
run_cmd += f' --clip_skip={str(clip_skip)}'
@ -756,33 +756,23 @@ def lora_tab(
'linear',
'polynomial',
],
value='constant',
value='cosine',
)
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
lr_warmup_input = gr.Textbox(label='LR warmup (% of steps)', value=10)
with gr.Row():
text_encoder_lr = gr.Textbox(
label='Text Encoder learning rate',
value=1e-6,
value="5e-5",
placeholder='Optional',
)
unet_lr = gr.Textbox(
label='Unet learning rate', value=1e-4, placeholder='Optional'
label='Unet learning rate', value="1e-3", placeholder='Optional'
)
# network_train = gr.Dropdown(
# label='Network to train',
# choices=[
# 'Text encoder and Unet',
# 'Text encoder only',
# 'Unet only',
# ],
# value='Text encoder and Unet',
# interactive=True
# )
network_dim = gr.Slider(
minimum=1,
maximum=128,
label='Network Dimension',
value=4,
value=8,
step=1,
interactive=True,
)

View File

@ -0,0 +1,158 @@
# extract approximating LoRA by svd from two SD models
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo!
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def svd(args):
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
save_dtype = str_to_dtype(args.save_precision)
print(f"loading SD model : {args.model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
print(f"loading SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
# create LoRA network to extract weights
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o)
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
# get diffs
diffs = {}
text_encoder_different = False
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
# Text Encoder might be same
if torch.max(torch.abs(diff)) > MIN_DIFF:
text_encoder_different = True
diff = diff.float()
diffs[lora_name] = diff
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
if args.device:
diff = diff.to(args.device)
diffs[lora_name] = diff
# make LoRA with svd
print("calculating by svd")
rank = args.dim
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
lora_weights[lora_name] = (U, Vh)
# make state dict for LoRA
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
lora_sd = lora_network_o.state_dict()
print(f"LoRA has {len(lora_sd)} weights.")
for key in list(lora_sd.keys()):
lora_name = key.split('.')[0]
i = 0 if "lora_up" in key else 1
weights = lora_weights[lora_name][i]
# print(key, i, weights.size(), lora_sd[key].size())
if len(lora_sd[key].size()) == 4:
weights = weights.unsqueeze(2).unsqueeze(3)
assert weights.size() == lora_sd[key].size()
lora_sd[key] = weights
# load state dict to LoRA and save it
info = lora_network_o.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype)
print(f"LoRA weights are saved to: {args.save_to}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
parser.add_argument("--model_org", type=str, default=None,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
parser.add_argument("--model_tuned", type=str, default=None,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数デフォルト4")
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う")
args = parser.parse_args()
svd(args)