Update to sd-script dev code base

This commit is contained in:
bmaltais 2023-03-10 11:44:52 -05:00
parent 2deddd5f3c
commit fc5d2b2c31
12 changed files with 499 additions and 129 deletions

View File

@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa
## Change History ## Change History
* 2023/03/10 (v21.2.1):
- Update to latest sd-script code
- Add support for SVD based LoRA merge
* 2023/03/09 (v21.2.0): * 2023/03/09 (v21.2.0):
- Fix issue https://github.com/bmaltais/kohya_ss/issues/335 - Fix issue https://github.com/bmaltais/kohya_ss/issues/335
- Add option to print LoRA trainer command without executing it - Add option to print LoRA trainer command without executing it

View File

@ -148,11 +148,11 @@ def gradio_resize_lora_tab():
value='fp16', value='fp16',
interactive=True, interactive=True,
) )
device = gr.Textbox( device = gr.Dropdown(
label='Device', label='Device',
placeholder='{Optional) device to use, cuda for GPU. Default: cuda', choices=['cpu', 'cuda',],
interactive=True,
value='cuda', value='cuda',
interactive=True,
) )
convert_button = gr.Button('Resize model') convert_button = gr.Button('Resize model')

View File

@ -0,0 +1,187 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
def svd_merge_lora(
lora_a_model,
lora_b_model,
ratio,
save_to,
precision,
save_precision,
new_rank,
new_conv_rank,
device,
):
# Check for caption_text_input
if lora_a_model == '':
msgbox('Invalid model A file')
return
if lora_b_model == '':
msgbox('Invalid model B file')
return
# Check if source model exist
if not os.path.isfile(lora_a_model):
msgbox('The provided model A is not a file')
return
if not os.path.isfile(lora_b_model):
msgbox('The provided model B is not a file')
return
ratio_a = ratio
ratio_b = 1 - ratio
run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"'
run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --precision {precision}'
run_cmd += f' --save_to "{save_to}"'
run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"'
run_cmd += f' --ratios {ratio_a} {ratio_b}'
run_cmd += f' --device {device}'
run_cmd += f' --new_rank "{new_rank}"'
run_cmd += f' --new_conv_rank "{new_conv_rank}"'
print(run_cmd)
# Run the command
if os.name == 'posix':
os.system(run_cmd)
else:
subprocess.run(run_cmd)
###
# Gradio UI
###
def gradio_svd_merge_lora_tab():
with gr.Tab('Merge LoRA (SVD)'):
gr.Markdown('This utility can merge two LoRA networks together.')
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Row():
lora_a_model = gr.Textbox(
label='LoRA model "A"',
placeholder='Path to the LoRA A model',
interactive=True,
)
button_lora_a_model_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_lora_a_model_file.click(
get_file_path,
inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model,
show_progress=False,
)
lora_b_model = gr.Textbox(
label='LoRA model "B"',
placeholder='Path to the LoRA B model',
interactive=True,
)
button_lora_b_model_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_lora_b_model_file.click(
get_file_path,
inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model,
show_progress=False,
)
with gr.Row():
ratio = gr.Slider(
label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B',
minimum=0,
maximum=1,
step=0.01,
value=0.5,
interactive=True,
)
new_rank = gr.Slider(
label='New Rank',
minimum=1,
maximum=1024,
step=1,
value=128,
interactive=True,
)
new_conv_rank = gr.Slider(
label='New Conv Rank',
minimum=1,
maximum=1024,
step=1,
value=128,
interactive=True,
)
with gr.Row():
save_to = gr.Textbox(
label='Save to',
placeholder='path for the file to save...',
interactive=True,
)
button_save_to = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
show_progress=False,
)
precision = gr.Dropdown(
label='Merge precision',
choices=['fp16', 'bf16', 'float'],
value='float',
interactive=True,
)
save_precision = gr.Dropdown(
label='Save precision',
choices=['fp16', 'bf16', 'float'],
value='float',
interactive=True,
)
device = gr.Dropdown(
label='Device',
choices=['cpu', 'cuda',],
value='cuda',
interactive=True,
)
convert_button = gr.Button('Merge model')
convert_button.click(
svd_merge_lora,
inputs=[
lora_a_model,
lora_b_model,
ratio,
save_to,
precision,
save_precision,
new_rank,
new_conv_rank,
device,
],
show_progress=False,
)

View File

@ -911,6 +911,10 @@ class FineTuningDataset(BaseDataset):
# path情報を作る # path情報を作る
if os.path.exists(image_key): if os.path.exists(image_key):
abs_path = image_key abs_path = image_key
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
else: else:
# わりといい加減だがいい方法が思いつかん # わりといい加減だがいい方法が思いつかん
abs_path = glob_images(subset.image_dir, image_key) abs_path = glob_images(subset.image_dir, image_key)
@ -1757,15 +1761,22 @@ def get_optimizer(args, trainable_params):
raise ImportError("No dadaptation / dadaptation がインストールされていないようです") raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
min_lr = lr actual_lr = lr
lr_count = 1
if type(trainable_params) == list and type(trainable_params[0]) == dict: if type(trainable_params) == list and type(trainable_params[0]) == dict:
lrs = set()
actual_lr = trainable_params[0].get("lr", actual_lr)
for group in trainable_params: for group in trainable_params:
min_lr = min(min_lr, group.get("lr", lr)) lrs.add(group.get("lr", actual_lr))
lr_count = len(lrs)
if min_lr <= 0.1: if actual_lr <= 0.1:
print( print(
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}') f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}')
print('recommend option: lr=1.0 / 推奨は1.0です') print('recommend option: lr=1.0 / 推奨は1.0です')
if lr_count > 1:
print(
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合Text EncoderとU-Netなど、最初の学習率のみが有効になります: lr={actual_lr}")
optimizer_class = dadaptation.DAdaptAdam optimizer_class = dadaptation.DAdaptAdam
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
@ -2296,6 +2307,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
with torch.no_grad(): with torch.no_grad():
with accelerator.autocast(): with accelerator.autocast():
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if not accelerator.is_main_process:
continue
prompt = prompt.strip() prompt = prompt.strip()
if len(prompt) == 0 or prompt[0] == '#': if len(prompt) == 0 or prompt[0] == '#':
continue continue
@ -2355,6 +2368,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
height = max(64, height - height % 8) # round to divisible by 8 height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8 width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())

View File

@ -38,6 +38,7 @@ from library.tensorboard_gui import (
from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab from library.utilities import utilities_tab
from library.merge_lora_gui import gradio_merge_lora_tab from library.merge_lora_gui import gradio_merge_lora_tab
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
from library.verify_lora_gui import gradio_verify_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.sampler_gui import sample_gradio_config, run_cmd_sample
@ -879,6 +880,7 @@ def lora_tab(
) )
gradio_dataset_balancing_tab() gradio_dataset_balancing_tab()
gradio_merge_lora_tab() gradio_merge_lora_tab()
gradio_svd_merge_lora_tab()
gradio_resize_lora_tab() gradio_resize_lora_tab()
gradio_verify_lora_tab() gradio_verify_lora_tab()

View File

@ -103,7 +103,8 @@ def svd(args):
if args.device: if args.device:
mat = mat.to(args.device) mat = mat.to(args.device)
# print(mat.size(), mat.device, rank, in_dim, out_dim)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
if conv2d: if conv2d:
@ -112,7 +113,7 @@ def svd(args):
else: else:
mat = mat.squeeze() mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat) U, S, Vh = torch.linalg.svd(mat.to("cuda"))
U = U[:, :rank] U = U[:, :rank]
S = S[:rank] S = S[:rank]
@ -137,27 +138,17 @@ def svd(args):
lora_weights[lora_name] = (U, Vh) lora_weights[lora_name] = (U, Vh)
# make state dict for LoRA # make state dict for LoRA
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict lora_sd = {}
lora_sd = lora_network_o.state_dict() for lora_name, (up_weight, down_weight) in lora_weights.items():
print(f"LoRA has {len(lora_sd)} weights.") lora_sd[lora_name + '.lora_up.weight'] = up_weight
lora_sd[lora_name + '.lora_down.weight'] = down_weight
for key in list(lora_sd.keys()): lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
if "alpha" in key:
continue
lora_name = key.split('.')[0]
i = 0 if "lora_up" in key else 1
weights = lora_weights[lora_name][i]
# print(key, i, weights.size(), lora_sd[key].size())
# if len(lora_sd[key].size()) == 4:
# weights = weights.unsqueeze(2).unsqueeze(3)
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
lora_sd[key] = weights
# load state dict to LoRA and save it # load state dict to LoRA and save it
info = lora_network_o.load_state_dict(lora_sd) lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}") print(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to) dir_name = os.path.dirname(args.save_to)
@ -167,7 +158,7 @@ def svd(args):
# minimum metadata # minimum metadata
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
lora_network_o.save_weights(args.save_to, save_dtype, metadata) lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {args.save_to}")

View File

@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module):
""" if alpha == 0 or None, alpha is rank (no scaling). """ """ if alpha == 0 or None, alpha is rank (no scaling). """
super().__init__() super().__init__()
self.lora_name = lora_name self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels in_dim = org_module.in_channels
out_dim = org_module.out_channels out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_dim = min(self.lora_dim, in_dim, out_dim) # if limit_rank:
if self.lora_dim != lora_dim: # self.lora_dim = min(lora_dim, in_dim, out_dim)
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d':
kernel_size = org_module.kernel_size kernel_size = org_module.kernel_size
stride = org_module.stride stride = org_module.stride
padding = org_module.padding padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else: else:
in_dim = org_module.in_features self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
out_dim = org_module.out_features self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor: if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
@ -149,7 +153,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
return network return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open from safetensors.torch import load_file, safe_open
weights_sd = load_file(file) weights_sd = load_file(file)
@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
class LoRANetwork(torch.nn.Module): class LoRANetwork(torch.nn.Module):
# is it possible to apply conv_in and conv_out? # is it possible to apply conv_in and conv_out?
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"] UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = 'lora_te'
@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
if modules_dim is not None or self.conv_lora_dim is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
self.weights_sd = None self.weights_sd = None

View File

@ -1,14 +1,15 @@
# Convert LoRA to different rank approximation (should only be used to go to lower rank) # Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo and kohya # Thanks to cloneofsimo
import argparse import argparse
import os
import torch import torch
from safetensors.torch import load_file, save_file, safe_open from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm from tqdm import tqdm
from library import train_util, model_util from library import train_util, model_util
import numpy as np
MIN_SV = 1e-6
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name): if model_util.is_safetensors(file_name):
@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
torch.save(model, file_name) torch.save(model, file_name)
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
if index >= len(S):
index = len(S) - 1
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
if index >= len(S):
index = len(S) - 1
return index
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size, kernel_size, _ = weight.size()
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
del U, S, Vh, weight
return param_dict
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size = weight.size()
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
del U, S, Vh, weight
return param_dict
def merge_conv(lora_down, lora_up, device):
in_rank, in_size, kernel_size, k_ = lora_down.shape
out_size, out_rank, _, _ = lora_up.shape
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
del lora_up, lora_down
return weight
def merge_linear(lora_down, lora_up, device):
in_rank, in_size = lora_down.shape
out_size, out_rank = lora_up.shape
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
weight = lora_up @ lora_down
del lora_up, lora_down
return weight
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}
if dynamic_method=="sv_ratio":
# Calculate new dim and alpha based off ratio
max_sv = S[0]
min_sv = max_sv/dynamic_param
new_rank = max(torch.sum(S > min_sv).item(),1)
new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param)
new_rank = max(new_rank, 1)
new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param)
new_rank = min(max(new_rank, 1), len(S)-1)
new_alpha = float(scale*new_rank)
else:
new_rank = rank
new_alpha = float(scale*new_rank)
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale*new_rank)
elif new_rank > rank: # cap max rank at rank
new_rank = rank
new_alpha = float(scale*new_rank)
# Calculate resize info
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
fro_percent = float(s_red_fro/s_fro)
param_dict["new_rank"] = new_rank
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank]
return param_dict
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None network_alpha = None
network_dim = None network_dim = None
verbose_str = "\n" verbose_str = "\n"
fro_list = []
CLAMP_QUANTILE = 0.99
# Extract loaded lora dim and alpha # Extract loaded lora dim and alpha
for key, value in lora_sd.items(): for key, value in lora_sd.items():
@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
network_alpha = network_dim network_alpha = network_dim
scale = network_alpha/network_dim scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}") if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
lora_down_weight = None lora_down_weight = None
lora_up_weight = None lora_up_weight = None
@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
block_down_name = None block_down_name = None
block_up_name = None block_up_name = None
print("resizing lora...")
with torch.no_grad(): with torch.no_grad():
for key, value in tqdm(lora_sd.items()): for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key: if 'lora_down' in key:
@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
conv2d = (len(lora_down_weight.size()) == 4) conv2d = (len(lora_down_weight.size()) == 4)
if conv2d: if conv2d:
lora_down_weight = lora_down_weight.squeeze() full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
lora_up_weight = lora_up_weight.squeeze() param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
else:
if device: full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
org_device = lora_up_weight.device param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
U, S, Vh = torch.linalg.svd(full_weight_matrix)
if verbose: if verbose:
s_sum = torch.sum(torch.abs(S)) max_ratio = param_dict['max_ratio']
s_rank = torch.sum(torch.abs(S[:new_rank])) sum_retained = param_dict['sum_retained']
verbose_str+=f"{block_down_name:76} | " fro_retained = param_dict['fro_retained']
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
U = U[:, :new_rank] verbose_str+=f"{block_down_name:75} | "
S = S[:new_rank] verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :] if verbose and dynamic_method:
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str+=f"\n"
dist = torch.cat([U.flatten(), Vh.flatten()]) new_alpha = param_dict['new_alpha']
hi_val = torch.quantile(dist, CLAMP_QUANTILE) o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
low_val = -hi_val o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
if device:
U = U.to(org_device)
Vh = Vh.to(org_device)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
block_down_name = None block_down_name = None
block_up_name = None block_up_name = None
lora_down_weight = None lora_down_weight = None
lora_up_weight = None lora_up_weight = None
weights_loaded = False weights_loaded = False
del param_dict
if verbose: if verbose:
print(verbose_str) print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete") print("resizing complete")
return o_lora_sd, network_dim, new_alpha return o_lora_sd, network_dim, new_alpha
@ -151,6 +274,9 @@ def resize(args):
return torch.bfloat16 return torch.bfloat16
return None return None
if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision) save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None: if save_dtype is None:
@ -159,17 +285,23 @@ def resize(args):
print("loading Model...") print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype) lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("resizing rank...") print("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
# update metadata # update metadata
if metadata is None: if metadata is None:
metadata = {} metadata = {}
comment = metadata.get("ss_training_comment", "") comment = metadata.get("ss_training_comment", "")
if not args.dynamic_method:
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank) metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha) metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
metadata["ss_network_dim"] = 'Dynamic'
metadata["ss_network_alpha"] = 'Dynamic'
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash metadata["sshs_model_hash"] = model_hash
@ -193,6 +325,11 @@ if __name__ == '__main__':
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true", parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する") help="Display verbose resizing information / rank変更時の詳細情報を出力する")
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction")
args = parser.parse_args() args = parser.parse_args()
resize(args) resize(args)

View File

@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype):
return sd return sd
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, state_dict, dtype):
if dtype is not None: if dtype is not None:
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name) save_file(state_dict, file_name)
else: else:
torch.save(model, file_name) torch.save(state_dict, file_name)
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
@ -76,7 +76,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
down_weight = down_weight.to(device) down_weight = down_weight.to(device)
# W <- W + U * D # W <- W + U * D
scale = (alpha / network_dim) scale = (alpha / network_dim).to(device)
if not conv2d: # linear if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale weight = weight + ratio * (up_weight @ down_weight) * scale
elif kernel_size == (1, 1): elif kernel_size == (1, 1):
@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
mat = mat.squeeze() mat = mat.squeeze()
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
U, S, Vh = torch.linalg.svd(mat) U, S, Vh = torch.linalg.svd(mat)
@ -114,12 +115,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
Vh = Vh[:module_new_rank, :] Vh = Vh[:module_new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()]) # dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE) # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val # low_val = -hi_val
U = U.clamp(low_val, hi_val) # U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val) # Vh = Vh.clamp(low_val, hi_val)
if conv2d: if conv2d:
U = U.reshape(out_dim, module_new_rank, 1, 1) U = U.reshape(out_dim, module_new_rank, 1, 1)
@ -156,7 +157,7 @@ def merge(args):
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
print(f"saving model to: {args.save_to}") print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, save_dtype)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -502,6 +502,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。 clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。
- `--persistent_data_loader_workers`
Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
- `--max_data_loader_n_workers`
データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
- `--logging_dir` / `--log_prefix` - `--logging_dir` / `--log_prefix`
学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。

View File

@ -106,6 +106,7 @@ def train(args):
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator, unwrap_model = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
@ -134,6 +135,8 @@ def train(args):
gc.collect() gc.collect()
# prepare network # prepare network
import sys
sys.path.append(os.path.dirname(__file__))
print("import network module:", args.network_module) print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module) network_module = importlib.import_module(args.network_module)
@ -175,12 +178,13 @@ def train(args):
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする # 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
@ -251,6 +255,8 @@ def train(args):
# 学習する # 学習する
# TODO: find a way to handle total batch size when there are multiple datasets # TODO: find a way to handle total batch size when there are multiple datasets
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
if is_main_process:
print("running training / 学習開始") print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
@ -471,6 +477,7 @@ def train(args):
loss_list = [] loss_list = []
loss_total = 0.0 loss_total = 0.0
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) train_dataset_group.set_current_epoch(epoch + 1)
@ -583,6 +590,7 @@ def train(args):
print(f"removing old checkpoint: {old_ckpt_file}") print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file) os.remove(old_ckpt_file)
if is_main_process:
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state: if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
@ -594,7 +602,6 @@ def train(args):
metadata["ss_epoch"] = str(num_train_epochs) metadata["ss_epoch"] = str(num_train_epochs)
metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_training_finished_at"] = str(time.time())
is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
network = unwrap_model(network) network = unwrap_model(network)

View File

@ -64,6 +64,10 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
* LoRAのRANKを指定します``--networkdim=4``など。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。 * LoRAのRANKを指定します``--networkdim=4``など。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
* `--network_alpha` * `--network_alpha`
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。 * アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
* `--persistent_data_loader_workers`
* Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
* `--max_data_loader_n_workers`
* データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
* `--network_weights` * `--network_weights`
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。 * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
* `--network_train_unet_only` * `--network_train_unet_only`