2023-02-04 16:55:06 +00:00
# Convert a LoRA to a different-rank approximation (should only be used to go to a lower rank).
# This code is based on the extract_lora_from_models.py file, which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
2023-03-10 16:44:52 +00:00
# Thanks to cloneofsimo
2023-02-04 16:55:06 +00:00
import argparse
import torch
2023-02-11 16:59:38 +00:00
from safetensors . torch import load_file , save_file , safe_open
2023-02-04 16:55:06 +00:00
from tqdm import tqdm
2023-02-11 16:59:38 +00:00
from library import train_util , model_util
2023-03-10 16:44:52 +00:00
import numpy as np
2023-02-11 16:59:38 +00:00
2023-03-10 16:44:52 +00:00
# Threshold on the largest singular value: a merged weight whose max singular value
# is at or below this is treated as a zero matrix and resized to rank 1.
MIN_SV = 1.0e-6
2023-02-04 16:55:06 +00:00
def load_state_dict(file_name, dtype):
    """Load a LoRA state dict (plus safetensors metadata, if any) from disk.

    Every tensor value is cast to ``dtype``; non-tensor values are kept as-is.

    Args:
        file_name: path to a ``.safetensors`` file or a torch-pickled checkpoint.
        dtype: torch dtype that every tensor entry is converted to.

    Returns:
        ``(state_dict, metadata)``. ``metadata`` is ``None`` for non-safetensors
        files; for safetensors it is whatever the file header stores (may be ``None``).
    """
    if model_util.is_safetensors(file_name):
        # Read tensors and header metadata through a single handle instead of
        # opening and parsing the file twice (load_file + safe_open).
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()
            sd = {key: f.get_tensor(key) for key in f.keys()}
    else:
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
        sd = torch.load(file_name, map_location='cpu')
        metadata = None

    # isinstance (not an exact type check) so tensor subclasses such as
    # nn.Parameter are cast as well.
    for key, value in list(sd.items()):
        if isinstance(value, torch.Tensor):
            sd[key] = value.to(dtype)
    return sd, metadata
2023-02-04 16:55:06 +00:00
2023-02-11 16:59:38 +00:00
def save_to_file(file_name, model, state_dict, dtype, metadata):
    """Write ``model`` to ``file_name`` as safetensors or as a torch checkpoint.

    When ``dtype`` is given, the tensor entries of ``state_dict`` are cast to it
    in place first. Note that ``model`` is what gets written; the caller in this
    script passes the same dict for both ``model`` and ``state_dict``, so the cast
    is reflected in the output. ``metadata`` is only used for safetensors output.
    """
    if dtype is not None:
        for name in list(state_dict):
            tensor = state_dict[name]
            if type(tensor) == torch.Tensor:
                state_dict[name] = tensor.to(dtype)

    if not model_util.is_safetensors(file_name):
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)
2023-03-10 16:44:52 +00:00
def index_sv_cumulative(S, target):
    """Return the rank at which the cumulative sum of ``S`` reaches ``target``.

    ``target`` is a fraction of ``sum(S)``. The rank is one past the first position
    whose normalized cumulative sum is >= ``target``. It is capped at ``len(S) - 1``.
    """
    normalized_cumsum = torch.cumsum(S, dim=0) / float(torch.sum(S))
    needed_rank = int(torch.searchsorted(normalized_cumsum, target)) + 1
    return min(needed_rank, len(S) - 1)
def index_sv_fro(S, target):
    """Return the rank retaining a ``target`` fraction of the Frobenius norm of ``S``.

    Works on squared singular values. The rank is one past the first position
    whose normalized cumulative energy is >= ``target ** 2``. It is capped at
    ``len(S) - 1``.
    """
    energy = S.pow(2)
    energy_fraction = torch.cumsum(energy, dim=0) / float(torch.sum(energy))
    needed_rank = int(torch.searchsorted(energy_fraction, target ** 2)) + 1
    return min(needed_rank, len(S) - 1)
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """Factor a dense conv weight into LoRA ``lora_down``/``lora_up`` tensors via SVD.

    Args:
        weight: 4-D conv weight ``(out_channels, in_channels, kernel_h, kernel_w)``.
        lora_rank: rank to keep (the hard maximum when a dynamic method is used).
        dynamic_method, dynamic_param: forwarded to ``rank_resize`` to choose the rank.
        device: device on which the SVD runs.
        scale: alpha/dim ratio forwarded to ``rank_resize`` for the new alpha.

    Returns:
        The ``rank_resize`` result dict, extended with ``lora_down``
        ``(rank, in_channels, kernel_h, kernel_w)`` and ``lora_up``
        ``(out_channels, rank, 1, 1)``, both moved to CPU.
    """
    # Separate kernel_h / kernel_w so non-square kernels are supported.
    out_size, in_size, kernel_h, kernel_w = weight.size()
    # Reduced SVD: only the leading singular vectors are used, so there is no need
    # to materialize the full (in_channels*kh*kw)^2 Vh matrix.
    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device), full_matrices=False)

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    # Fold the singular values into the up projection: W ~= (U @ diag(S)) @ Vh.
    U = U[:, :lora_rank] @ torch.diag(S[:lora_rank])
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_h, kernel_w).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
    del U, S, Vh, weight
    return param_dict
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """Factor a dense linear weight into LoRA ``lora_down``/``lora_up`` matrices via SVD.

    Args:
        weight: 2-D weight ``(out_features, in_features)``.
        lora_rank: rank to keep (the hard maximum when a dynamic method is used).
        dynamic_method, dynamic_param: forwarded to ``rank_resize`` to choose the rank.
        device: device on which the SVD runs.
        scale: alpha/dim ratio forwarded to ``rank_resize`` for the new alpha.

    Returns:
        The ``rank_resize`` result dict, extended with ``lora_down``
        ``(rank, in_features)`` and ``lora_up`` ``(out_features, rank)``, both on CPU.
    """
    out_size, in_size = weight.size()
    # Reduced SVD: the full U / Vh matrices are never needed beyond the leading vectors.
    U, S, Vh = torch.linalg.svd(weight.to(device), full_matrices=False)

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    # Fold the singular values into the up projection: W ~= (U @ diag(S)) @ Vh.
    U = U[:, :lora_rank] @ torch.diag(S[:lora_rank])
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
    del U, S, Vh, weight
    return param_dict
def merge_conv(lora_down, lora_up, device):
    """Multiply a conv LoRA pair back into a dense conv weight.

    Args:
        lora_down: ``(rank, in_channels, kernel_h, kernel_w)``.
        lora_up: ``(out_channels, rank, 1, 1)`` - a 1x1 up projection.
        device: device on which the matmul runs.

    Returns:
        Dense weight ``(out_channels, in_channels, kernel_h, kernel_w)`` on ``device``.

    Raises:
        ValueError: if the ranks disagree or ``lora_up`` is not a 1x1 conv.
    """
    in_rank, in_size, kernel_h, kernel_w = lora_down.shape
    out_size, out_rank, up_h, up_w = lora_up.shape
    # Explicit checks rather than `assert`, which `python -O` strips.
    if in_rank != out_rank:
        raise ValueError(f"rank {in_rank} {out_rank} mismatch")
    if (up_h, up_w) != (1, 1):
        raise ValueError(f"lora_up must be a 1x1 conv, got kernel {up_h}x{up_w}")

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)
    # Flatten both factors to matrices, multiply, then restore the conv layout.
    merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
    weight = merged.reshape(out_size, in_size, kernel_h, kernel_w)
    del lora_up, lora_down
    return weight
def merge_linear(lora_down, lora_up, device):
    """Multiply a linear LoRA pair back into a dense weight ``lora_up @ lora_down``.

    Args:
        lora_down: ``(rank, in_features)``.
        lora_up: ``(out_features, rank)``.
        device: device on which the matmul runs.

    Returns:
        Dense ``(out_features, in_features)`` weight on ``device``.

    Raises:
        ValueError: if the two ranks disagree.
    """
    in_rank, _ = lora_down.shape
    _, out_rank = lora_up.shape
    # Explicit check rather than `assert`, which `python -O` strips.
    if in_rank != out_rank:
        raise ValueError(f"rank {in_rank} {out_rank} mismatch")
    return lora_up.to(device) @ lora_down.to(device)
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
    """Choose the new rank/alpha for one module and report how much of ``S`` is kept.

    Args:
        S: 1-D tensor of singular values, largest first (as returned by ``torch.linalg.svd``).
        rank: fixed new rank, and the hard upper limit when a dynamic method is used.
        dynamic_method: ``None`` (use ``rank``), ``"sv_ratio"``, ``"sv_cumulative"`` or ``"sv_fro"``.
        dynamic_param: target value for the dynamic method.
        scale: alpha/dim ratio of the source module; the new alpha is ``scale * new_rank``.

    Returns:
        dict with ``new_rank`` (int), ``new_alpha`` (float), ``sum_retained`` (fraction of
        ``sum(S)`` kept), ``fro_retained`` (fraction of the Frobenius norm kept) and
        ``max_ratio`` (``S[0]`` over the first dropped singular value, or over the last
        singular value when none is dropped).
    """
    if dynamic_method == "sv_ratio":
        # Keep every singular value larger than max(S) / dynamic_param.
        min_sv = S[0] / dynamic_param
        new_rank = max(torch.sum(S > min_sv).item(), 1)
    elif dynamic_method == "sv_cumulative":
        # Rank whose cumulative sum reaches dynamic_param of sum(S).
        new_rank = max(index_sv_cumulative(S, dynamic_param), 1)
    elif dynamic_method == "sv_fro":
        # Rank retaining dynamic_param of the Frobenius norm; index_sv_fro already caps
        # its result at len(S) - 1 (re-capping here yielded rank 0 when len(S) == 1).
        new_rank = max(index_sv_fro(S, dynamic_param), 1)
    else:
        new_rank = rank

    if S[0] <= MIN_SV:  # Zero matrix, set dim to 1
        new_rank = 1
    elif new_rank > rank:  # cap max rank at rank
        new_rank = rank
    # A factorization cannot have more components than there are singular values;
    # without this, a rank >= len(S) crashed when slicing/reshaping the factors.
    new_rank = min(new_rank, len(S))
    new_alpha = float(scale * new_rank)

    # Retention statistics (reported in verbose mode).
    s_sum = torch.sum(torch.abs(S))
    s_rank = torch.sum(torch.abs(S[:new_rank]))
    S_squared = S.pow(2)
    s_fro = torch.sqrt(torch.sum(S_squared))
    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
    fro_percent = float(s_red_fro / s_fro)

    return {
        "new_rank": new_rank,
        "new_alpha": new_alpha,
        "sum_retained": s_rank / s_sum,
        "fro_retained": fro_percent,
        # S[new_rank] is the first dropped value; guard the index when nothing is dropped
        # (previously an IndexError when new_rank == len(S)).
        "max_ratio": S[0] / S[min(new_rank, len(S) - 1)],
    }
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
    """Re-factor every LoRA module in ``lora_sd`` to a new (usually lower) rank.

    For each ``<block>.lora_down.<name>`` key with a matching ``<block>.lora_up.<name>``,
    the pair is merged into a dense weight. The weight is re-decomposed with SVD at the
    rank chosen by ``rank_resize``, and the new factors are written back together with a
    new ``<block>.alpha``. All other keys are copied over unchanged.

    Args:
        lora_sd: loaded LoRA state dict; it is not modified (a shallow copy is returned).
        new_rank: target rank, or the maximum rank when ``dynamic_method`` is set.
        save_dtype: dtype of the written lora_down / lora_up / alpha tensors.
        device: device used for merging and SVD.
        dynamic_method, dynamic_param: optional dynamic rank selection (see ``rank_resize``).
        verbose: print per-module retention statistics.

    Returns:
        ``(resized_state_dict, network_dim, new_alpha)``. ``network_dim`` is the row count
        of the first 2-D ``lora_down`` weight (``None`` if there is none). ``new_alpha``
        is the alpha chosen for the last resized module.

    Raises:
        ValueError: if no lora_down / lora_up pair was found.
    """
    # Original network dim, reported back to the caller: rank of the first linear lora_down.
    network_dim = None
    for key, value in lora_sd.items():
        if 'lora_down' in key and len(value.size()) == 2:
            network_dim = value.size()[0]
            break

    if dynamic_method:
        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

    verbose_parts = ["\n"]
    fro_list = []
    new_alpha = None
    o_lora_sd = lora_sd.copy()

    with torch.no_grad():
        for key, value in tqdm(lora_sd.items()):
            if 'lora_down' not in key:
                continue

            # Keys are split as "<block>.lora_down.<weight_name>".
            block_name = key.split(".")[0]
            weight_name = key.split(".")[-1]
            lora_down_weight = value
            lora_up_weight = lora_sd.get(block_name + '.lora_up.' + weight_name, None)
            lora_alpha = lora_sd.get(block_name + '.alpha', None)
            if lora_up_weight is None:
                continue

            # Per-module alpha/dim scale; a module without alpha uses scale 1.0.
            if lora_alpha is None:
                scale = 1.0
            else:
                scale = lora_alpha / lora_down_weight.size()[0]

            if len(lora_down_weight.size()) == 4:  # conv2d module
                full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
                param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
            else:
                full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
                param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

            if verbose:
                max_ratio = param_dict['max_ratio']
                sum_retained = param_dict['sum_retained']
                fro_retained = param_dict['fro_retained']
                if not np.isnan(fro_retained):
                    fro_list.append(float(fro_retained))

                verbose_parts.append(f"{block_name:75} | ")
                verbose_parts.append(f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}")
                if dynamic_method:
                    verbose_parts.append(f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n")
                else:
                    verbose_parts.append("\n")

            new_alpha = param_dict['new_alpha']
            # Overwrite the exact keys that were read (weight_name, not a hard-coded "weight")
            # so no stale original-rank tensor is left behind.
            o_lora_sd[block_name + ".lora_down." + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
            o_lora_sd[block_name + ".lora_up." + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
            o_lora_sd[block_name + ".alpha"] = torch.tensor(new_alpha).to(save_dtype)
            del param_dict

    if new_alpha is None:
        raise ValueError("no lora_down/lora_up weight pairs found to resize")

    if verbose:
        print("".join(verbose_parts))
        # np.mean/np.std of an empty list would warn and print nan.
        if fro_list:
            print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    print("resizing complete")
    return o_lora_sd, network_dim, new_alpha
2023-02-04 16:55:06 +00:00
2023-02-11 16:59:38 +00:00
def resize(args):
    """CLI entry point: load ``args.model``, resize its LoRA rank and save to ``args.save_to``.

    Raises:
        Exception: if ``--dynamic_method`` is given without ``--dynamic_param``.
    """

    def str_to_dtype(p):
        # Map a --save_precision choice to a torch dtype (None when unset/unknown).
        if p == 'float':
            return torch.float
        if p == 'fp16':
            return torch.float16
        if p == 'bf16':
            return torch.bfloat16
        return None

    if args.dynamic_method and not args.dynamic_param:
        raise Exception("If using dynamic_method, then dynamic_param is required")

    merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    print("loading Model...")
    lora_sd, metadata = load_state_dict(args.model, merge_dtype)

    print("Resizing Lora...")
    state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)

    # Entries not touched by the resize are still in merge_dtype. Cast everything to
    # save_dtype *before* hashing so the precalculated hashes describe exactly the
    # tensors that are written to disk.
    for key, value in list(state_dict.items()):
        if isinstance(value, torch.Tensor):
            state_dict[key] = value.to(save_dtype)

    # update metadata
    if metadata is None:
        metadata = {}

    comment = metadata.get("ss_training_comment", "")
    if not args.dynamic_method:
        metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
        metadata["ss_network_dim"] = str(args.new_rank)
        metadata["ss_network_alpha"] = str(new_alpha)
    else:
        metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
        metadata["ss_network_dim"] = 'Dynamic'
        metadata["ss_network_alpha"] = 'Dynamic'

    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
    metadata["sshs_model_hash"] = model_hash
    metadata["sshs_legacy_hash"] = legacy_hash

    print(f"saving model to: {args.save_to}")
    save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
2023-02-04 16:55:06 +00:00
2023-03-22 00:20:57 +00:00
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the LoRA rank-resize script."""
    option_specs = [
        ("--save_precision", dict(type=str, default=None, choices=[None, "float", "fp16", "bf16"],
                                  help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")),
        ("--new_rank", dict(type=int, default=4,
                            help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")),
        ("--save_to", dict(type=str, default=None,
                           help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")),
        ("--model", dict(type=str, default=None,
                         help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")),
        ("--device", dict(type=str, default=None,
                          help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")),
        ("--verbose", dict(action="store_true",
                           help="Display verbose resizing information / rank変更時の詳細情報を出力する")),
        ("--dynamic_method", dict(type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
                                  help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")),
        ("--dynamic_param", dict(type=float, default=None,
                                 help="Specify target for dynamic reduction")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
if __name__ == '__main__':
    # Parse the command line and run the resize end to end.
    resize(setup_parser().parse_args())