Update to sd-script dev code base
parent 2deddd5f3c
commit fc5d2b2c31
@@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa
 ## Change History
 
+* 2023/03/10 (v21.2.1):
+    - Update to latest sd-script code
+    - Add support for SVD based LoRA merge
 * 2023/03/09 (v21.2.0):
    - Fix issue https://github.com/bmaltais/kohya_ss/issues/335
    - Add option to print LoRA trainer command without executing it
@@ -148,11 +148,11 @@ def gradio_resize_lora_tab():
             value='fp16',
             interactive=True,
         )
-        device = gr.Textbox(
+        device = gr.Dropdown(
             label='Device',
-            placeholder='{Optional) device to use, cuda for GPU. Default: cuda',
-            interactive=True,
+            choices=['cpu', 'cuda',],
             value='cuda',
+            interactive=True,
         )
 
         convert_button = gr.Button('Resize model')
library/svd_merge_lora_gui.py (new file, 187 lines)
@@ -0,0 +1,187 @@
+import gradio as gr
+from easygui import msgbox
+import subprocess
+import os
+from .common_gui import (
+    get_saveasfilename_path,
+    get_any_file_path,
+    get_file_path,
+)
+
+folder_symbol = '\U0001f4c2'  # 📂
+refresh_symbol = '\U0001f504'  # 🔄
+save_style_symbol = '\U0001f4be'  # 💾
+document_symbol = '\U0001F4C4'  # 📄
+PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
+
+
+def svd_merge_lora(
+    lora_a_model,
+    lora_b_model,
+    ratio,
+    save_to,
+    precision,
+    save_precision,
+    new_rank,
+    new_conv_rank,
+    device,
+):
+    # Verify that both LoRA model paths were provided
+    if lora_a_model == '':
+        msgbox('Invalid model A file')
+        return
+
+    if lora_b_model == '':
+        msgbox('Invalid model B file')
+        return
+
+    # Check that the source models exist
+    if not os.path.isfile(lora_a_model):
+        msgbox('The provided model A is not a file')
+        return
+
+    if not os.path.isfile(lora_b_model):
+        msgbox('The provided model B is not a file')
+        return
+
+    ratio_a = ratio
+    ratio_b = 1 - ratio
+
+    run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"'
+    run_cmd += f' --save_precision {save_precision}'
+    run_cmd += f' --precision {precision}'
+    run_cmd += f' --save_to "{save_to}"'
+    run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"'
+    run_cmd += f' --ratios {ratio_a} {ratio_b}'
+    run_cmd += f' --device {device}'
+    run_cmd += f' --new_rank "{new_rank}"'
+    run_cmd += f' --new_conv_rank "{new_conv_rank}"'
+
+    print(run_cmd)
+
+    # Run the command
+    if os.name == 'posix':
+        os.system(run_cmd)
+    else:
+        subprocess.run(run_cmd)
+
+
+###
+# Gradio UI
+###
+
+
+def gradio_svd_merge_lora_tab():
+    with gr.Tab('Merge LoRA (SVD)'):
+        gr.Markdown('This utility can merge two LoRA networks together.')
+
+        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
+        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
+
+        with gr.Row():
+            lora_a_model = gr.Textbox(
+                label='LoRA model "A"',
+                placeholder='Path to the LoRA A model',
+                interactive=True,
+            )
+            button_lora_a_model_file = gr.Button(
+                folder_symbol, elem_id='open_folder_small'
+            )
+            button_lora_a_model_file.click(
+                get_file_path,
+                inputs=[lora_a_model, lora_ext, lora_ext_name],
+                outputs=lora_a_model,
+                show_progress=False,
+            )
+
+            lora_b_model = gr.Textbox(
+                label='LoRA model "B"',
+                placeholder='Path to the LoRA B model',
+                interactive=True,
+            )
+            button_lora_b_model_file = gr.Button(
+                folder_symbol, elem_id='open_folder_small'
+            )
+            button_lora_b_model_file.click(
+                get_file_path,
+                inputs=[lora_b_model, lora_ext, lora_ext_name],
+                outputs=lora_b_model,
+                show_progress=False,
+            )
+        with gr.Row():
+            ratio = gr.Slider(
+                label='Merge ratio (e.g. 0.7 means 70% of model A and 30% of model B)',
+                minimum=0,
+                maximum=1,
+                step=0.01,
+                value=0.5,
+                interactive=True,
+            )
+            new_rank = gr.Slider(
+                label='New Rank',
+                minimum=1,
+                maximum=1024,
+                step=1,
+                value=128,
+                interactive=True,
+            )
+            new_conv_rank = gr.Slider(
+                label='New Conv Rank',
+                minimum=1,
+                maximum=1024,
+                step=1,
+                value=128,
+                interactive=True,
+            )
+
+        with gr.Row():
+            save_to = gr.Textbox(
+                label='Save to',
+                placeholder='path for the file to save...',
+                interactive=True,
+            )
+            button_save_to = gr.Button(
+                folder_symbol, elem_id='open_folder_small'
+            )
+            button_save_to.click(
+                get_saveasfilename_path,
+                inputs=[save_to, lora_ext, lora_ext_name],
+                outputs=save_to,
+                show_progress=False,
+            )
+            precision = gr.Dropdown(
+                label='Merge precision',
+                choices=['fp16', 'bf16', 'float'],
+                value='float',
+                interactive=True,
+            )
+            save_precision = gr.Dropdown(
+                label='Save precision',
+                choices=['fp16', 'bf16', 'float'],
+                value='float',
+                interactive=True,
+            )
+            device = gr.Dropdown(
+                label='Device',
+                choices=['cpu', 'cuda',],
+                value='cuda',
+                interactive=True,
+            )
+
+        convert_button = gr.Button('Merge model')
+
+        convert_button.click(
+            svd_merge_lora,
+            inputs=[
+                lora_a_model,
+                lora_b_model,
+                ratio,
+                save_to,
+                precision,
+                save_precision,
+                new_rank,
+                new_conv_rank,
+                device,
+            ],
+            show_progress=False,
+        )
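For reference, with the tab's defaults the svd_merge_lora handler above assembles a command along these lines on Linux (the model paths here are hypothetical placeholders, not part of the commit):

    python3 "networks/svd_merge_lora.py" --save_precision float --precision float --save_to "merged.safetensors" --models "loraA.safetensors" "loraB.safetensors" --ratios 0.5 0.5 --device cuda --new_rank "128" --new_conv_rank "128"

A ratio of 0.5 weights both networks equally, since ratio_b is computed as 1 - ratio.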
@@ -912,10 +912,14 @@ class FineTuningDataset(BaseDataset):
             if os.path.exists(image_key):
                 abs_path = image_key
             else:
-                # わりといい加減だがいい方法が思いつかん
-                abs_path = glob_images(subset.image_dir, image_key)
-                assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
-                abs_path = abs_path[0]
+                npz_path = os.path.join(subset.image_dir, image_key + ".npz")
+                if os.path.exists(npz_path):
+                    abs_path = npz_path
+                else:
+                    # わりといい加減だがいい方法が思いつかん
+                    abs_path = glob_images(subset.image_dir, image_key)
+                    assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
+                    abs_path = abs_path[0]
 
             caption = img_md.get('caption')
             tags = img_md.get('tags')
@@ -1757,15 +1761,22 @@ def get_optimizer(args, trainable_params):
             raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
         print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
 
-        min_lr = lr
+        actual_lr = lr
+        lr_count = 1
         if type(trainable_params) == list and type(trainable_params[0]) == dict:
+            lrs = set()
+            actual_lr = trainable_params[0].get("lr", actual_lr)
             for group in trainable_params:
-                min_lr = min(min_lr, group.get("lr", lr))
+                lrs.add(group.get("lr", actual_lr))
+            lr_count = len(lrs)
 
-        if min_lr <= 0.1:
+        if actual_lr <= 0.1:
             print(
-                f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
+                f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
             print('recommend option: lr=1.0 / 推奨は1.0です')
+        if lr_count > 1:
+            print(
+                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
 
         optimizer_class = dadaptation.DAdaptAdam
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
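A minimal sketch of the situation the new warning targets (the param groups below are hypothetical): when trainable_params carries per-group learning rates, D-Adaptation only honors the first one, so the code now collects the distinct values instead of tracking a minimum.

    trainable_params = [
        {"params": text_encoder.parameters(), "lr": 0.5},  # actual_lr is taken from this first group
        {"params": unet.parameters(), "lr": 1.0},
    ]
    # lrs == {0.5, 1.0}, so lr_count == 2 and the multiple-learning-rates warning is printed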
@@ -2296,6 +2307,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
     with torch.no_grad():
         with accelerator.autocast():
             for i, prompt in enumerate(prompts):
+                if not accelerator.is_main_process:
+                    continue
                 prompt = prompt.strip()
                 if len(prompt) == 0 or prompt[0] == '#':
                     continue
@@ -2355,6 +2368,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
 
                 height = max(64, height - height % 8)  # round to divisible by 8
                 width = max(64, width - width % 8)  # round to divisible by 8
+                print(f"prompt: {prompt}")
+                print(f"negative_prompt: {negative_prompt}")
+                print(f"height: {height}")
+                print(f"width: {width}")
+                print(f"sample_steps: {sample_steps}")
+                print(f"scale: {scale}")
                 image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
 
                 ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
@@ -38,6 +38,7 @@ from library.tensorboard_gui import (
 from library.dataset_balancing_gui import gradio_dataset_balancing_tab
 from library.utilities import utilities_tab
 from library.merge_lora_gui import gradio_merge_lora_tab
+from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
 from library.verify_lora_gui import gradio_verify_lora_tab
 from library.resize_lora_gui import gradio_resize_lora_tab
 from library.sampler_gui import sample_gradio_config, run_cmd_sample
@@ -879,6 +880,7 @@ def lora_tab(
         )
         gradio_dataset_balancing_tab()
         gradio_merge_lora_tab()
+        gradio_svd_merge_lora_tab()
         gradio_resize_lora_tab()
         gradio_verify_lora_tab()
 
@@ -103,7 +103,8 @@ def svd(args):
 
         if args.device:
             mat = mat.to(args.device)
-        # print(mat.size(), mat.device, rank, in_dim, out_dim)
+        # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
+
         rank = min(rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
 
         if conv2d:
@@ -112,7 +113,7 @@ def svd(args):
         else:
             mat = mat.squeeze()
 
-        U, S, Vh = torch.linalg.svd(mat)
+        U, S, Vh = torch.linalg.svd(mat.to("cuda"))
 
         U = U[:, :rank]
         S = S[:rank]
@@ -137,27 +138,17 @@ def svd(args):
             lora_weights[lora_name] = (U, Vh)
 
     # make state dict for LoRA
-    lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)  # to make state dict
-    lora_sd = lora_network_o.state_dict()
-    print(f"LoRA has {len(lora_sd)} weights.")
-
-    for key in list(lora_sd.keys()):
-        if "alpha" in key:
-            continue
-
-        lora_name = key.split('.')[0]
-        i = 0 if "lora_up" in key else 1
-
-        weights = lora_weights[lora_name][i]
-        # print(key, i, weights.size(), lora_sd[key].size())
-        # if len(lora_sd[key].size()) == 4:
-        #     weights = weights.unsqueeze(2).unsqueeze(3)
-
-        assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
-        lora_sd[key] = weights
+    lora_sd = {}
+    for lora_name, (up_weight, down_weight) in lora_weights.items():
+        lora_sd[lora_name + '.lora_up.weight'] = up_weight
+        lora_sd[lora_name + '.lora_down.weight'] = down_weight
+        lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
 
     # load state dict to LoRA and save it
-    info = lora_network_o.load_state_dict(lora_sd)
+    lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
+    lora_network_save.apply_to(text_encoder_o, unet_o)  # create internal module references for state_dict
+
+    info = lora_network_save.load_state_dict(lora_sd)
     print(f"Loading extracted LoRA weights: {info}")
 
     dir_name = os.path.dirname(args.save_to)
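The rewritten loop above emits three keys per extracted module; a sketch of the resulting layout for one hypothetical module name:

    # up_weight: (out_dim, rank), down_weight: (rank, in_dim)
    lora_sd['lora_unet_mid_block_attentions_0_proj_in.lora_up.weight'] = up_weight
    lora_sd['lora_unet_mid_block_attentions_0_proj_in.lora_down.weight'] = down_weight
    # alpha is set equal to the rank (down_weight.size()[0]), i.e. no extra scaling on load
    lora_sd['lora_unet_mid_block_attentions_0_proj_in.alpha'] = torch.tensor(down_weight.size()[0])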
@@ -167,7 +158,7 @@ def svd(args):
     # minimum metadata
     metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
 
-    lora_network_o.save_weights(args.save_to, save_dtype, metadata)
+    lora_network_save.save_weights(args.save_to, save_dtype, metadata)
     print(f"LoRA weights are saved to: {args.save_to}")
@@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module):
         """ if alpha == 0 or None, alpha is rank (no scaling). """
         super().__init__()
         self.lora_name = lora_name
-        self.lora_dim = lora_dim
 
         if org_module.__class__.__name__ == 'Conv2d':
             in_dim = org_module.in_channels
             out_dim = org_module.out_channels
-        self.lora_dim = min(self.lora_dim, in_dim, out_dim)
-        if self.lora_dim != lora_dim:
-            print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+
+        # if limit_rank:
+        #   self.lora_dim = min(lora_dim, in_dim, out_dim)
+        #   if self.lora_dim != lora_dim:
+        #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+        # else:
+        self.lora_dim = lora_dim
+
+        if org_module.__class__.__name__ == 'Conv2d':
             kernel_size = org_module.kernel_size
             stride = org_module.stride
             padding = org_module.padding
             self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
             self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
         else:
-            in_dim = org_module.in_features
-            out_dim = org_module.out_features
-            self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
-            self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
+            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
 
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
-        alpha = lora_dim if alpha is None or alpha == 0 else alpha
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
         self.scale = alpha / self.lora_dim
         self.register_buffer('alpha', torch.tensor(alpha))  # 定数として扱える
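For orientation, the pair of modules built above implements the standard LoRA update; a rough sketch of the effective forward pass (the network-level multiplier applied elsewhere is omitted):

    # y = org_module(x) + lora_up(lora_down(x)) * scale, where scale = alpha / lora_dim
    def lora_forward(org_module, lora_down, lora_up, scale, x):
        return org_module(x) + lora_up(lora_down(x)) * scale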
@@ -149,12 +153,13 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
     return network
 
 
-def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
-    if os.path.splitext(file)[1] == '.safetensors':
-        from safetensors.torch import load_file, safe_open
-        weights_sd = load_file(file)
-    else:
-        weights_sd = torch.load(file, map_location='cpu')
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == '.safetensors':
+            from safetensors.torch import load_file, safe_open
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location='cpu')
 
     # get dim/alpha mapping
     modules_dim = {}
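The new weights_sd parameter lets a caller hand over an in-memory state dict instead of a file path, which is exactly what the rewritten extraction code above does. A minimal usage sketch (text_encoder, unet, and lora_sd assumed to exist):

    network = lora.create_network_from_weights(1.0, None, None, text_encoder, unet, weights_sd=lora_sd)
    network.apply_to(text_encoder, unet)  # build the internal module references
    info = network.load_state_dict(lora_sd)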
@@ -174,7 +179,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
     # support old LoRA without alpha
     for key in modules_dim.keys():
         if key not in modules_alpha:
             modules_alpha = modules_dim[key]
 
     network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
     network.weights_sd = weights_sd
@@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
 
 class LoRANetwork(torch.nn.Module):
     # is it possible to apply conv_in and conv_out?
-    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = 'lora_unet'
     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
@@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
             text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
-        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
+        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
+        if modules_dim is not None or self.conv_lora_dim is not None:
+            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
         print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
 
         self.weights_sd = None
@@ -371,7 +382,7 @@ class LoRANetwork(torch.nn.Module):
         else:
             torch.save(state_dict, file)
 
-    @staticmethod
+    @ staticmethod
     def set_regions(networks, image):
         image = image.astype(np.float32) / 255.0
         for i, network in enumerate(networks[:3]):
@@ -1,14 +1,15 @@
 # Convert LoRA to different rank approximation (should only be used to go to lower rank)
 # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
-# Thanks to cloneofsimo and kohya
+# Thanks to cloneofsimo
 
 import argparse
-import os
 import torch
 from safetensors.torch import load_file, save_file, safe_open
 from tqdm import tqdm
 from library import train_util, model_util
+import numpy as np
+
+MIN_SV = 1e-6
 
 def load_state_dict(file_name, dtype):
     if model_util.is_safetensors(file_name):
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
         torch.save(model, file_name)
 
 
-def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
+def index_sv_cumulative(S, target):
+    original_sum = float(torch.sum(S))
+    cumulative_sums = torch.cumsum(S, dim=0)/original_sum
+    index = int(torch.searchsorted(cumulative_sums, target)) + 1
+    if index >= len(S):
+        index = len(S) - 1
+
+    return index
+
+
+def index_sv_fro(S, target):
+    S_squared = S.pow(2)
+    s_fro_sq = float(torch.sum(S_squared))
+    sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
+    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
+    if index >= len(S):
+        index = len(S) - 1
+
+    return index
+
+
+# Modified from Kohaku-blueleaf's extract/merge functions
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+    out_size, in_size, kernel_size, _ = weight.size()
+    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
+
+    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+    lora_rank = param_dict["new_rank"]
+
+    U = U[:, :lora_rank]
+    S = S[:lora_rank]
+    U = U @ torch.diag(S)
+    Vh = Vh[:lora_rank, :]
+
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+    del U, S, Vh, weight
+    return param_dict
+
+
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+    out_size, in_size = weight.size()
+
+    U, S, Vh = torch.linalg.svd(weight.to(device))
+
+    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+    lora_rank = param_dict["new_rank"]
+
+    U = U[:, :lora_rank]
+    S = S[:lora_rank]
+    U = U @ torch.diag(S)
+    Vh = Vh[:lora_rank, :]
+
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
+    del U, S, Vh, weight
+    return param_dict
+
+
+def merge_conv(lora_down, lora_up, device):
+    in_rank, in_size, kernel_size, k_ = lora_down.shape
+    out_size, out_rank, _, _ = lora_up.shape
+    assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
+
+    lora_down = lora_down.to(device)
+    lora_up = lora_up.to(device)
+
+    merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
+    weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
+    del lora_up, lora_down
+    return weight
+
+
+def merge_linear(lora_down, lora_up, device):
+    in_rank, in_size = lora_down.shape
+    out_size, out_rank = lora_up.shape
+    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
+
+    lora_down = lora_down.to(device)
+    lora_up = lora_up.to(device)
+
+    weight = lora_up @ lora_down
+    del lora_up, lora_down
+    return weight
+
+
+def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
+    param_dict = {}
+
+    if dynamic_method=="sv_ratio":
+        # Calculate new dim and alpha based off ratio
+        max_sv = S[0]
+        min_sv = max_sv/dynamic_param
+        new_rank = max(torch.sum(S > min_sv).item(),1)
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_cumulative":
+        # Calculate new dim and alpha based off cumulative sum
+        new_rank = index_sv_cumulative(S, dynamic_param)
+        new_rank = max(new_rank, 1)
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_fro":
+        # Calculate new dim and alpha based off sqrt sum of squares
+        new_rank = index_sv_fro(S, dynamic_param)
+        new_rank = min(max(new_rank, 1), len(S)-1)
+        new_alpha = float(scale*new_rank)
+    else:
+        new_rank = rank
+        new_alpha = float(scale*new_rank)
+
+    if S[0] <= MIN_SV:  # Zero matrix, set dim to 1
+        new_rank = 1
+        new_alpha = float(scale*new_rank)
+    elif new_rank > rank:  # cap max rank at rank
+        new_rank = rank
+        new_alpha = float(scale*new_rank)
+
+    # Calculate resize info
+    s_sum = torch.sum(torch.abs(S))
+    s_rank = torch.sum(torch.abs(S[:new_rank]))
+
+    S_squared = S.pow(2)
+    s_fro = torch.sqrt(torch.sum(S_squared))
+    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
+    fro_percent = float(s_red_fro/s_fro)
+
+    param_dict["new_rank"] = new_rank
+    param_dict["new_alpha"] = new_alpha
+    param_dict["sum_retained"] = (s_rank)/s_sum
+    param_dict["fro_retained"] = fro_percent
+    param_dict["max_ratio"] = S[0]/S[new_rank]
+
+    return param_dict
+
+
+def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
     network_alpha = None
     network_dim = None
     verbose_str = "\n"
-    CLAMP_QUANTILE = 0.99
+    fro_list = []
 
     # Extract loaded lora dim and alpha
     for key, value in lora_sd.items():
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
         network_alpha = network_dim
 
     scale = network_alpha/network_dim
-    new_alpha = float(scale*new_rank)  # calculate new alpha from scale
 
-    print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
+    if dynamic_method:
+        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
 
     lora_down_weight = None
     lora_up_weight = None
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
     block_down_name = None
     block_up_name = None
 
-    print("resizing lora...")
    with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
             if 'lora_down' in key:
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
             conv2d = (len(lora_down_weight.size()) == 4)
 
             if conv2d:
-                lora_down_weight = lora_down_weight.squeeze()
-                lora_up_weight = lora_up_weight.squeeze()
-
-            if device:
-                org_device = lora_up_weight.device
-                lora_up_weight = lora_up_weight.to(args.device)
-                lora_down_weight = lora_down_weight.to(args.device)
-
-            full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
-
-            U, S, Vh = torch.linalg.svd(full_weight_matrix)
+                full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
+                param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+            else:
+                full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
+                param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
 
             if verbose:
-                s_sum = torch.sum(torch.abs(S))
-                s_rank = torch.sum(torch.abs(S[:new_rank]))
-                verbose_str+=f"{block_down_name:76} | "
-                verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
-
-            U = U[:, :new_rank]
-            S = S[:new_rank]
-            U = U @ torch.diag(S)
-            Vh = Vh[:new_rank, :]
-
-            dist = torch.cat([U.flatten(), Vh.flatten()])
-            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-            low_val = -hi_val
-
-            U = U.clamp(low_val, hi_val)
-            Vh = Vh.clamp(low_val, hi_val)
-
-            if conv2d:
-                U = U.unsqueeze(2).unsqueeze(3)
-                Vh = Vh.unsqueeze(2).unsqueeze(3)
-
-            if device:
-                U = U.to(org_device)
-                Vh = Vh.to(org_device)
-
-            o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
-            o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
-            o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
+                max_ratio = param_dict['max_ratio']
+                sum_retained = param_dict['sum_retained']
+                fro_retained = param_dict['fro_retained']
+                if not np.isnan(fro_retained):
+                    fro_list.append(float(fro_retained))
+
+                verbose_str+=f"{block_down_name:75} | "
+                verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
+
+            if verbose and dynamic_method:
+                verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
+            else:
+                verbose_str+=f"\n"
+
+            new_alpha = param_dict['new_alpha']
+            o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
+            o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
+            o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
 
             block_down_name = None
             block_up_name = None
             lora_down_weight = None
             lora_up_weight = None
             weights_loaded = False
+            del param_dict
 
     if verbose:
         print(verbose_str)
+
+    print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     print("resizing complete")
     return o_lora_sd, network_dim, new_alpha
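The three dynamic methods boil down to thresholds on the singular values of each merged weight matrix. A small sketch of what sv_fro with dynamic_param=0.9 (retain at least 90% of the Frobenius norm) selects, using made-up singular values:

    import torch

    S = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.1])  # singular values, descending
    frac = torch.cumsum(S.pow(2), dim=0) / float(torch.sum(S.pow(2)))
    new_rank = int(torch.searchsorted(frac, 0.9 ** 2)) + 1
    print(new_rank)  # 2: the first two values already carry ~94% of sum(S^2), above the 0.81 target

sv_ratio instead keeps every singular value within a factor dynamic_param of the largest one, and sv_cumulative applies the same cumulative cutoff to the unsquared values.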
@@ -151,6 +274,9 @@ def resize(args):
             return torch.bfloat16
         return None
 
+    if args.dynamic_method and not args.dynamic_param:
+        raise Exception("If using dynamic_method, then dynamic_param is required")
+
     merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
     save_dtype = str_to_dtype(args.save_precision)
     if save_dtype is None:
@@ -159,17 +285,23 @@ def resize(args):
     print("loading Model...")
     lora_sd, metadata = load_state_dict(args.model, merge_dtype)
 
-    print("resizing rank...")
-    state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
+    print("Resizing Lora...")
+    state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
 
     # update metadata
     if metadata is None:
         metadata = {}
 
     comment = metadata.get("ss_training_comment", "")
-    metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
-    metadata["ss_network_dim"] = str(args.new_rank)
-    metadata["ss_network_alpha"] = str(new_alpha)
+
+    if not args.dynamic_method:
+        metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
+        metadata["ss_network_dim"] = str(args.new_rank)
+        metadata["ss_network_alpha"] = str(new_alpha)
+    else:
+        metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
+        metadata["ss_network_dim"] = 'Dynamic'
+        metadata["ss_network_alpha"] = 'Dynamic'
 
     model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
     metadata["sshs_model_hash"] = model_hash
@@ -193,6 +325,11 @@ if __name__ == '__main__':
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
     parser.add_argument("--verbose", action="store_true",
                         help="Display verbose resizing information / rank変更時の詳細情報を出力する")
+    parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
+                        help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
+    parser.add_argument("--dynamic_param", type=float, default=None,
+                        help="Specify target for dynamic reduction")
+
 
     args = parser.parse_args()
     resize(args)
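A usage sketch of the new flags (file names are placeholders): a dynamic Frobenius-norm resize, with --new_rank acting as the hard cap the help text describes:

    python networks/resize_lora.py --model lora.safetensors --save_to lora_resized.safetensors --new_rank 64 --dynamic_method sv_fro --dynamic_param 0.9 --save_precision fp16 --device cuda --verbose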
@@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype):
     return sd
 
 
-def save_to_file(file_name, model, state_dict, dtype):
+def save_to_file(file_name, state_dict, dtype):
     if dtype is not None:
         for key in list(state_dict.keys()):
             if type(state_dict[key]) == torch.Tensor:
                 state_dict[key] = state_dict[key].to(dtype)
 
     if os.path.splitext(file_name)[1] == '.safetensors':
-        save_file(model, file_name)
+        save_file(state_dict, file_name)
     else:
-        torch.save(model, file_name)
+        torch.save(state_dict, file_name)
 
 
 def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
@@ -76,7 +76,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
             down_weight = down_weight.to(device)
 
             # W <- W + U * D
-            scale = (alpha / network_dim)
+            scale = (alpha / network_dim).to(device)
             if not conv2d:  # linear
                 weight = weight + ratio * (up_weight @ down_weight) * scale
             elif kernel_size == (1, 1):
@@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
             mat = mat.squeeze()
 
         module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
+        module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
 
         U, S, Vh = torch.linalg.svd(mat)
 
@@ -114,12 +115,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
 
         Vh = Vh[:module_new_rank, :]
 
-        dist = torch.cat([U.flatten(), Vh.flatten()])
-        hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-        low_val = -hi_val
+        # dist = torch.cat([U.flatten(), Vh.flatten()])
+        # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+        # low_val = -hi_val
 
-        U = U.clamp(low_val, hi_val)
-        Vh = Vh.clamp(low_val, hi_val)
+        # U = U.clamp(low_val, hi_val)
+        # Vh = Vh.clamp(low_val, hi_val)
 
         if conv2d:
             U = U.reshape(out_dim, module_new_rank, 1, 1)
@@ -156,7 +157,7 @@ def merge(args):
     state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
 
     print(f"saving model to: {args.save_to}")
-    save_to_file(args.save_to, state_dict, state_dict, save_dtype)
+    save_to_file(args.save_to, state_dict, save_dtype)
 
 
 if __name__ == '__main__':
@@ -502,6 +502,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
 
 As with clip_skip, training at a length different from the state the model was trained in will likely require a certain number of training images and a longer training time.
 
+- `--persistent_data_loader_workers`
+
+    Specifying this on Windows greatly shortens the wait time between epochs.
+
+- `--max_data_loader_n_workers`
+
+    Specifies the number of processes used for data loading. More processes load data faster and use the GPU more efficiently, but consume main memory. The default is the smaller of `8` and `number of concurrent CPU threads - 1`, so if main memory is tight or GPU utilization is already around 90% or higher, lower this to around `2` or `1` while watching those figures.
+
 - `--logging_dir` / `--log_prefix`
 
     Options for saving training logs. Specify the destination folder with the logging_dir option. Logs are saved in TensorBoard format.
@@ -106,6 +106,7 @@ def train(args):
     # acceleratorを準備する
     print("prepare accelerator")
     accelerator, unwrap_model = train_util.prepare_accelerator(args)
+    is_main_process = accelerator.is_main_process
 
     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -134,6 +135,8 @@ def train(args):
     gc.collect()
 
     # prepare network
+    import sys
+    sys.path.append(os.path.dirname(__file__))
     print("import network module:", args.network_module)
     network_module = importlib.import_module(args.network_module)
 
@@ -175,12 +178,13 @@ def train(args):
 
     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
+        if is_main_process:
+            print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # lr schedulerを用意する
     lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                                num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                                num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
                                                 num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
     # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
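A quick check of the new step arithmetic (numbers invented): with 1000 batches per epoch, 2 processes, and 10 epochs, each process now runs math.ceil(1000 / 2) = 500 steps per epoch, so max_train_steps becomes 5000, while the scheduler is built for 5000 * 2 * gradient_accumulation_steps total steps, matching the optimizer updates seen across all processes:

    import math
    max_train_steps = 10 * math.ceil(1000 / 2)  # 5000 per-process optimizer steps
    scheduler_steps = max_train_steps * 2 * 1   # num_processes=2, gradient_accumulation_steps=1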
@@ -251,15 +255,17 @@ def train(args):
     # 学習する
     # TODO: find a way to handle total batch size when there are multiple datasets
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f"  num epochs / epoch数: {num_train_epochs}")
-    print(f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
-    # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+    if is_main_process:
+        print("running training / 学習開始")
+        print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+        print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+        print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+        print(f"  num epochs / epoch数: {num_train_epochs}")
+        print(f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
+        # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+        print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+        print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     # TODO refactor metadata creation and move to util
     metadata = {
@@ -471,7 +477,8 @@ def train(args):
     loss_list = []
     loss_total = 0.0
     for epoch in range(num_train_epochs):
-        print(f"epoch {epoch+1}/{num_train_epochs}")
+        if is_main_process:
+            print(f"epoch {epoch+1}/{num_train_epochs}")
         train_dataset_group.set_current_epoch(epoch + 1)
 
         metadata["ss_epoch"] = str(epoch+1)
@@ -583,9 +590,10 @@ def train(args):
                     print(f"removing old checkpoint: {old_ckpt_file}")
                     os.remove(old_ckpt_file)
 
-            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
-            if saving and args.save_state:
-                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
+            if is_main_process:
+                saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+                if saving and args.save_state:
+                    train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
         train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
@@ -594,7 +602,6 @@ def train(args):
     metadata["ss_epoch"] = str(num_train_epochs)
     metadata["ss_training_finished_at"] = str(time.time())
 
-    is_main_process = accelerator.is_main_process
     if is_main_process:
         network = unwrap_model(network)
 
@@ -64,6 +64,10 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
 * Specifies the rank of the LoRA (e.g. ``--network_dim=4``). Defaults to 4 when omitted. Larger values increase expressive power, but also the memory and time required for training; blindly increasing it does not seem to be a good idea either.
 * `--network_alpha`
 * Specifies the ``alpha`` value for preventing underflow and stabilizing training. The default is 1. Specifying the same value as ``network_dim`` gives the same behavior as previous versions.
+* `--persistent_data_loader_workers`
+* Specifying this on Windows greatly shortens the wait time between epochs.
+* `--max_data_loader_n_workers`
+* Specifies the number of processes used for data loading. More processes load data faster and use the GPU more efficiently, but consume main memory. The default is the smaller of `8` and `number of concurrent CPU threads - 1`, so lower it to around `2` or `1` if main memory is tight or GPU utilization is around 90% or above.
 * `--network_weights`
 * Loads pretrained LoRA weights before training and trains further from them.
 * `--network_train_unet_only`