Update to sd-script dev code base
This commit is contained in:
parent
2deddd5f3c
commit
fc5d2b2c31
@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa
|
||||
|
||||
## Change History
|
||||
|
||||
* 2023/03/10 (v21.2.1):
|
||||
- Update to latest sd-script code
|
||||
- Add support for SVD based LoRA merge
|
||||
* 2023/03/09 (v21.2.0):
|
||||
- Fix issue https://github.com/bmaltais/kohya_ss/issues/335
|
||||
- Add option to print LoRA trainer command without executing it
|
||||
|
@ -148,11 +148,11 @@ def gradio_resize_lora_tab():
|
||||
value='fp16',
|
||||
interactive=True,
|
||||
)
|
||||
device = gr.Textbox(
|
||||
device = gr.Dropdown(
|
||||
label='Device',
|
||||
placeholder='{Optional) device to use, cuda for GPU. Default: cuda',
|
||||
interactive=True,
|
||||
choices=['cpu', 'cuda',],
|
||||
value='cuda',
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
convert_button = gr.Button('Resize model')
|
||||
|
187
library/svd_merge_lora_gui.py
Normal file
187
library/svd_merge_lora_gui.py
Normal file
@ -0,0 +1,187 @@
|
||||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
import subprocess
|
||||
import os
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_any_file_path,
|
||||
get_file_path,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
|
||||
|
||||
def svd_merge_lora(
|
||||
lora_a_model,
|
||||
lora_b_model,
|
||||
ratio,
|
||||
save_to,
|
||||
precision,
|
||||
save_precision,
|
||||
new_rank,
|
||||
new_conv_rank,
|
||||
device,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if lora_a_model == '':
|
||||
msgbox('Invalid model A file')
|
||||
return
|
||||
|
||||
if lora_b_model == '':
|
||||
msgbox('Invalid model B file')
|
||||
return
|
||||
|
||||
# Check if source model exist
|
||||
if not os.path.isfile(lora_a_model):
|
||||
msgbox('The provided model A is not a file')
|
||||
return
|
||||
|
||||
if not os.path.isfile(lora_b_model):
|
||||
msgbox('The provided model B is not a file')
|
||||
return
|
||||
|
||||
ratio_a = ratio
|
||||
ratio_b = 1 - ratio
|
||||
|
||||
run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"'
|
||||
run_cmd += f' --save_precision {save_precision}'
|
||||
run_cmd += f' --precision {precision}'
|
||||
run_cmd += f' --save_to "{save_to}"'
|
||||
run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"'
|
||||
run_cmd += f' --ratios {ratio_a} {ratio_b}'
|
||||
run_cmd += f' --device {device}'
|
||||
run_cmd += f' --new_rank "{new_rank}"'
|
||||
run_cmd += f' --new_conv_rank "{new_conv_rank}"'
|
||||
|
||||
print(run_cmd)
|
||||
|
||||
# Run the command
|
||||
if os.name == 'posix':
|
||||
os.system(run_cmd)
|
||||
else:
|
||||
subprocess.run(run_cmd)
|
||||
|
||||
|
||||
###
|
||||
# Gradio UI
|
||||
###
|
||||
|
||||
|
||||
def gradio_svd_merge_lora_tab():
|
||||
with gr.Tab('Merge LoRA (SVD)'):
|
||||
gr.Markdown('This utility can merge two LoRA networks together.')
|
||||
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
|
||||
with gr.Row():
|
||||
lora_a_model = gr.Textbox(
|
||||
label='LoRA model "A"',
|
||||
placeholder='Path to the LoRA A model',
|
||||
interactive=True,
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_a_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
lora_b_model = gr.Textbox(
|
||||
label='LoRA model "B"',
|
||||
placeholder='Path to the LoRA B model',
|
||||
interactive=True,
|
||||
)
|
||||
button_lora_b_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_b_model,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Row():
|
||||
ratio = gr.Slider(
|
||||
label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B',
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
value=0.5,
|
||||
interactive=True,
|
||||
)
|
||||
new_rank = gr.Slider(
|
||||
label='New Rank',
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
step=1,
|
||||
value=128,
|
||||
interactive=True,
|
||||
)
|
||||
new_conv_rank = gr.Slider(
|
||||
label='New Conv Rank',
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
step=1,
|
||||
value=128,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
save_to = gr.Textbox(
|
||||
label='Save to',
|
||||
placeholder='path for the file to save...',
|
||||
interactive=True,
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
inputs=[save_to, lora_ext, lora_ext_name],
|
||||
outputs=save_to,
|
||||
show_progress=False,
|
||||
)
|
||||
precision = gr.Dropdown(
|
||||
label='Merge precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='float',
|
||||
interactive=True,
|
||||
)
|
||||
save_precision = gr.Dropdown(
|
||||
label='Save precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='float',
|
||||
interactive=True,
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label='Device',
|
||||
choices=['cpu', 'cuda',],
|
||||
value='cuda',
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
convert_button = gr.Button('Merge model')
|
||||
|
||||
convert_button.click(
|
||||
svd_merge_lora,
|
||||
inputs=[
|
||||
lora_a_model,
|
||||
lora_b_model,
|
||||
ratio,
|
||||
save_to,
|
||||
precision,
|
||||
save_precision,
|
||||
new_rank,
|
||||
new_conv_rank,
|
||||
device,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
@ -911,6 +911,10 @@ class FineTuningDataset(BaseDataset):
|
||||
# path情報を作る
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
abs_path = glob_images(subset.image_dir, image_key)
|
||||
@ -1757,15 +1761,22 @@ def get_optimizer(args, trainable_params):
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
||||
|
||||
min_lr = lr
|
||||
actual_lr = lr
|
||||
lr_count = 1
|
||||
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
||||
lrs = set()
|
||||
actual_lr = trainable_params[0].get("lr", actual_lr)
|
||||
for group in trainable_params:
|
||||
min_lr = min(min_lr, group.get("lr", lr))
|
||||
lrs.add(group.get("lr", actual_lr))
|
||||
lr_count = len(lrs)
|
||||
|
||||
if min_lr <= 0.1:
|
||||
if actual_lr <= 0.1:
|
||||
print(
|
||||
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
|
||||
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
|
||||
print('recommend option: lr=1.0 / 推奨は1.0です')
|
||||
if lr_count > 1:
|
||||
print(
|
||||
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}")
|
||||
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
@ -2296,6 +2307,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
with torch.no_grad():
|
||||
with accelerator.autocast():
|
||||
for i, prompt in enumerate(prompts):
|
||||
if not accelerator.is_main_process:
|
||||
continue
|
||||
prompt = prompt.strip()
|
||||
if len(prompt) == 0 or prompt[0] == '#':
|
||||
continue
|
||||
@ -2355,6 +2368,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
|
||||
height = max(64, height - height % 8) # round to divisible by 8
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
print(f"prompt: {prompt}")
|
||||
print(f"negative_prompt: {negative_prompt}")
|
||||
print(f"height: {height}")
|
||||
print(f"width: {width}")
|
||||
print(f"sample_steps: {sample_steps}")
|
||||
print(f"scale: {scale}")
|
||||
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
||||
|
||||
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
||||
|
@ -38,6 +38,7 @@ from library.tensorboard_gui import (
|
||||
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
||||
from library.utilities import utilities_tab
|
||||
from library.merge_lora_gui import gradio_merge_lora_tab
|
||||
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
|
||||
from library.verify_lora_gui import gradio_verify_lora_tab
|
||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||
@ -879,6 +880,7 @@ def lora_tab(
|
||||
)
|
||||
gradio_dataset_balancing_tab()
|
||||
gradio_merge_lora_tab()
|
||||
gradio_svd_merge_lora_tab()
|
||||
gradio_resize_lora_tab()
|
||||
gradio_verify_lora_tab()
|
||||
|
||||
|
@ -103,7 +103,8 @@ def svd(args):
|
||||
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
# print(mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
|
||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
if conv2d:
|
||||
@ -112,7 +113,7 @@ def svd(args):
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
U, S, Vh = torch.linalg.svd(mat.to("cuda"))
|
||||
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
@ -137,27 +138,17 @@ def svd(args):
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
# make state dict for LoRA
|
||||
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
||||
lora_sd = lora_network_o.state_dict()
|
||||
print(f"LoRA has {len(lora_sd)} weights.")
|
||||
|
||||
for key in list(lora_sd.keys()):
|
||||
if "alpha" in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split('.')[0]
|
||||
i = 0 if "lora_up" in key else 1
|
||||
|
||||
weights = lora_weights[lora_name][i]
|
||||
# print(key, i, weights.size(), lora_sd[key].size())
|
||||
# if len(lora_sd[key].size()) == 4:
|
||||
# weights = weights.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
||||
lora_sd[key] = weights
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + '.lora_up.weight'] = up_weight
|
||||
lora_sd[lora_name + '.lora_down.weight'] = down_weight
|
||||
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
info = lora_network_o.load_state_dict(lora_sd)
|
||||
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
dir_name = os.path.dirname(args.save_to)
|
||||
@ -167,7 +158,7 @@ def svd(args):
|
||||
# minimum metadata
|
||||
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||
|
||||
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
||||
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
|
||||
|
||||
|
@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module):
|
||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == 'Conv2d':
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
self.lora_dim = min(self.lora_dim, in_dim, out_dim)
|
||||
if self.lora_dim != lora_dim:
|
||||
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||
# if limit_rank:
|
||||
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
||||
# if self.lora_dim != lora_dim:
|
||||
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||
# else:
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == 'Conv2d':
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
@ -149,7 +153,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import load_file, safe_open
|
||||
weights_sd = load_file(file)
|
||||
@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
# is it possible to apply conv_in and conv_out?
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = 'lora_unet'
|
||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||
@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if modules_dim is not None or self.conv_lora_dim is not None:
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
self.weights_sd = None
|
||||
|
@ -1,14 +1,15 @@
|
||||
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo and kohya
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if model_util.is_safetensors(file_name):
|
||||
@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
s_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def merge_conv(lora_down, lora_up, device):
|
||||
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
||||
out_size, out_rank, _, _ = lora_up.shape
|
||||
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
|
||||
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
def merge_linear(lora_down, lora_up, device):
|
||||
in_rank, in_size = lora_down.shape
|
||||
out_size, out_rank = lora_up.shape
|
||||
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
weight = lora_up @ lora_down
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method=="sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/dynamic_param
|
||||
new_rank = max(torch.sum(S > min_sv).item(),1)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param)
|
||||
new_rank = max(new_rank, 1)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param)
|
||||
new_rank = min(max(new_rank, 1), len(S)-1)
|
||||
new_alpha = float(scale*new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
new_rank = 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
elif new_rank > rank: # cap max rank at rank
|
||||
new_rank = rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
|
||||
# Calculate resize info
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
|
||||
S_squared = S.pow(2)
|
||||
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||
fro_percent = float(s_red_fro/s_fro)
|
||||
|
||||
param_dict["new_rank"] = new_rank
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank]
|
||||
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
verbose_str = "\n"
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
fro_list = []
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
network_alpha = network_dim
|
||||
|
||||
scale = network_alpha/network_dim
|
||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||
|
||||
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||
if dynamic_method:
|
||||
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
||||
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
|
||||
print("resizing lora...")
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
if 'lora_down' in key:
|
||||
@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
conv2d = (len(lora_down_weight.size()) == 4)
|
||||
|
||||
if conv2d:
|
||||
lora_down_weight = lora_down_weight.squeeze()
|
||||
lora_up_weight = lora_up_weight.squeeze()
|
||||
|
||||
if device:
|
||||
org_device = lora_up_weight.device
|
||||
lora_up_weight = lora_up_weight.to(args.device)
|
||||
lora_down_weight = lora_down_weight.to(args.device)
|
||||
|
||||
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
|
||||
|
||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
|
||||
if verbose:
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
verbose_str+=f"{block_down_name:76} | "
|
||||
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
|
||||
max_ratio = param_dict['max_ratio']
|
||||
sum_retained = param_dict['sum_retained']
|
||||
fro_retained = param_dict['fro_retained']
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
|
||||
U = U[:, :new_rank]
|
||||
S = S[:new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
verbose_str+=f"{block_down_name:75} | "
|
||||
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
|
||||
Vh = Vh[:new_rank, :]
|
||||
if verbose and dynamic_method:
|
||||
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str+=f"\n"
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.unsqueeze(2).unsqueeze(3)
|
||||
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
if device:
|
||||
U = U.to(org_device)
|
||||
Vh = Vh.to(org_device)
|
||||
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
|
||||
new_alpha = param_dict['new_alpha']
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
|
||||
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
print("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
@ -151,6 +274,9 @@ def resize(args):
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise Exception("If using dynamic_method, then dynamic_param is required")
|
||||
|
||||
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
@ -159,17 +285,23 @@ def resize(args):
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
print("resizing rank...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
||||
print("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
|
||||
if not args.dynamic_method:
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
else:
|
||||
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
||||
metadata["ss_network_dim"] = 'Dynamic'
|
||||
metadata["ss_network_alpha"] = 'Dynamic'
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
@ -193,6 +325,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
||||
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
||||
parser.add_argument("--dynamic_param", type=float, default=None,
|
||||
help="Specify target for dynamic reduction")
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
|
@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype):
|
||||
return sd
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
def save_to_file(file_name, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
save_file(state_dict, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
torch.save(state_dict, file_name)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||
@ -76,7 +76,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
down_weight = down_weight.to(device)
|
||||
|
||||
# W <- W + U * D
|
||||
scale = (alpha / network_dim)
|
||||
scale = (alpha / network_dim).to(device)
|
||||
if not conv2d: # linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif kernel_size == (1, 1):
|
||||
@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
mat = mat.squeeze()
|
||||
|
||||
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
||||
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
@ -114,12 +115,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
|
||||
Vh = Vh[:module_new_rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
# dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
# hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
# low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
# U = U.clamp(low_val, hi_val)
|
||||
# Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
||||
@ -156,7 +157,7 @@ def merge(args):
|
||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
save_to_file(args.save_to, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -502,6 +502,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。
|
||||
|
||||
- `--persistent_data_loader_workers`
|
||||
|
||||
Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
|
||||
|
||||
- `--max_data_loader_n_workers`
|
||||
|
||||
データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
|
||||
|
||||
- `--logging_dir` / `--log_prefix`
|
||||
|
||||
学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
|
||||
|
@ -106,6 +106,7 @@ def train(args):
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
@ -134,6 +135,8 @@ def train(args):
|
||||
gc.collect()
|
||||
|
||||
# prepare network
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
@ -175,12 +178,13 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
if is_main_process:
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
|
||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
@ -251,6 +255,8 @@ def train(args):
|
||||
# 学習する
|
||||
# TODO: find a way to handle total batch size when there are multiple datasets
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
if is_main_process:
|
||||
print("running training / 学習開始")
|
||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
@ -471,6 +477,7 @@ def train(args):
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset_group.set_current_epoch(epoch + 1)
|
||||
|
||||
@ -583,6 +590,7 @@ def train(args):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
if is_main_process:
|
||||
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
@ -594,7 +602,6 @@ def train(args):
|
||||
metadata["ss_epoch"] = str(num_train_epochs)
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
network = unwrap_model(network)
|
||||
|
||||
|
@ -64,6 +64,10 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
* LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
||||
* `--network_alpha`
|
||||
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
||||
* `--persistent_data_loader_workers`
|
||||
* Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
|
||||
* `--max_data_loader_n_workers`
|
||||
* データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
|
||||
* `--network_weights`
|
||||
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
||||
* `--network_train_unet_only`
|
||||
|
Loading…
Reference in New Issue
Block a user