# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based on extract_lora_from_models.py, which in turn is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo and kohya
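#
# Example invocations (illustrative; the script filename below is assumed, substitute the actual path of this file):
#   python resize_lora.py --model input.safetensors --save_to output.safetensors --new_rank 8 --save_precision fp16 --device cuda
#   python resize_lora.py --model input.safetensors --save_to output.safetensors --dynamic_method sv_fro --dynamic_param 0.9 --verbose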
import argparse
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np


def load_state_dict(file_name, dtype):
    # Load a LoRA state dict (and safetensors metadata, if present), casting all tensors to dtype.
    if model_util.is_safetensors(file_name):
        sd = load_file(file_name)
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()
    else:
        sd = torch.load(file_name, map_location='cpu')
        metadata = None

    for key in list(sd.keys()):
        if type(sd[key]) == torch.Tensor:
            sd[key] = sd[key].to(dtype)

    return sd, metadata


def save_to_file(file_name, model, state_dict, dtype, metadata):
    # Cast tensors to the save dtype, then write as .safetensors (with metadata) or as a torch checkpoint.
    if dtype is not None:
        for key in list(state_dict.keys()):
            if type(state_dict[key]) == torch.Tensor:
                state_dict[key] = state_dict[key].to(dtype)

    if model_util.is_safetensors(file_name):
        save_file(model, file_name, metadata)
    else:
        torch.save(model, file_name)


def index_sv_cumulative(S, target):
    # Smallest rank at which the normalized cumulative sum of singular values reaches `target`.
    original_sum = float(torch.sum(S))
    cumulative_sums = torch.cumsum(S, dim=0) / original_sum
    index = int(torch.searchsorted(cumulative_sums, target)) + 1
    if index >= len(S):
        index = len(S) - 1
    return index


def index_sv_fro(S, target):
    # Smallest rank that retains a `target` fraction of the Frobenius norm, i.e. the normalized
    # cumulative sum of squared singular values reaches target**2.
    S_squared = S.pow(2)
    s_fro_sq = float(torch.sum(S_squared))
    sum_S_squared = torch.cumsum(S_squared, dim=0) / s_fro_sq
    index = int(torch.searchsorted(sum_S_squared, target ** 2)) + 1
    if index >= len(S):
        index = len(S) - 1
    return index
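

# A minimal sketch of how the two criteria above behave (illustrative numbers, not part of the
# script's logic): for S = torch.tensor([4.0, 2.0, 1.0, 1.0]), sum(S) = 8 and sum(S**2) = 22.
#   index_sv_cumulative(S, 0.75)  # -> 2, since (4 + 2) / 8 = 0.75 is reached at the second value
#   index_sv_fro(S, 0.9)          # -> 2, since (16 + 4) / 22 ≈ 0.91 >= 0.9**2 = 0.81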


def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
    network_alpha = None
    network_dim = None
    verbose_str = "\n"
    fro_list = []

    CLAMP_QUANTILE = 1  # 0.99

    # Extract loaded lora dim and alpha
    for key, value in lora_sd.items():
        if network_alpha is None and 'alpha' in key:
            network_alpha = value
        if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
            network_dim = value.size()[0]
        if network_alpha is not None and network_dim is not None:
            break
    if network_alpha is None:
        network_alpha = network_dim

    scale = network_alpha / network_dim  # alpha/dim ratio; kept constant so the effective LoRA weight is unchanged

    if dynamic_method:
        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}")
    else:
        new_alpha = float(scale * new_rank)  # calculate new alpha from scale
        print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
    lora_down_weight = None
    lora_up_weight = None

    o_lora_sd = lora_sd.copy()
    block_down_name = None
    block_up_name = None

    print("resizing lora...")
    with torch.no_grad():
        for key, value in tqdm(lora_sd.items()):
            if 'lora_down' in key:
                block_down_name = key.split(".")[0]
                lora_down_weight = value
            if 'lora_up' in key:
                block_up_name = key.split(".")[0]
                lora_up_weight = value

            weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)

            if (block_down_name == block_up_name) and weights_loaded:

                conv2d = (len(lora_down_weight.size()) == 4)

                if conv2d:
                    lora_down_weight = lora_down_weight.squeeze()
                    lora_up_weight = lora_up_weight.squeeze()

                if device:
                    org_device = lora_up_weight.device
                    lora_up_weight = lora_up_weight.to(device)
                    lora_down_weight = lora_down_weight.to(device)

                # Reconstruct the full delta-weight matrix from the up/down pair and decompose it
                full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)

                U, S, Vh = torch.linalg.svd(full_weight_matrix)
                if dynamic_method == "sv_ratio":
                    # Calculate new dim and alpha based off the ratio of the largest to smallest kept singular value
                    max_sv = S[0]
                    min_sv = max_sv / dynamic_param
                    new_rank = torch.sum(S > min_sv).item()
                    new_rank = max(new_rank, 1)
                    new_alpha = float(scale * new_rank)

                elif dynamic_method == "sv_cumulative":
                    # Calculate new dim and alpha based off the cumulative sum of singular values
                    new_rank = index_sv_cumulative(S, dynamic_param)
                    new_rank = max(new_rank, 1)
                    new_alpha = float(scale * new_rank)

                elif dynamic_method == "sv_fro":
                    # Calculate new dim and alpha based off the retained Frobenius norm (sqrt of sum of squares)
                    new_rank = index_sv_fro(S, dynamic_param)
                    new_rank = max(new_rank, 1)
                    new_alpha = float(scale * new_rank)

                if verbose:
                    s_sum = torch.sum(torch.abs(S))
                    s_rank = torch.sum(torch.abs(S[:new_rank]))

                    S_squared = S.pow(2)
                    s_fro = torch.sqrt(torch.sum(S_squared))
                    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
                    fro_percent = float(s_red_fro / s_fro)
                    if not np.isnan(fro_percent):
                        fro_list.append(float(fro_percent))

                    verbose_str += f"{block_down_name:75} | "
                    verbose_str += f"sum(S) retained: {(s_rank)/s_sum:.1%}, fro retained: {fro_percent:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"

                if verbose and dynamic_method:
                    verbose_str += f", dynamic | dim: {new_rank}, alpha: {new_alpha}\n"
                else:
                    verbose_str += f"\n"
                # Truncate the SVD factors to the new rank and fold the singular values into U
                U = U[:, :new_rank]
                S = S[:new_rank]
                U = U @ torch.diag(S)

                Vh = Vh[:new_rank, :]

                # Clamp extreme values to the chosen quantile of the combined weight distribution
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, CLAMP_QUANTILE)
                low_val = -hi_val

                U = U.clamp(low_val, hi_val)
                Vh = Vh.clamp(low_val, hi_val)

                if conv2d:
                    U = U.unsqueeze(2).unsqueeze(3)
                    Vh = Vh.unsqueeze(2).unsqueeze(3)

                if device:
                    U = U.to(org_device)
                    Vh = Vh.to(org_device)

                o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
                o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
                o_lora_sd[block_up_name + "." + "alpha"] = torch.tensor(new_alpha).to(save_dtype)

                block_down_name = None
                block_up_name = None
                lora_down_weight = None
                lora_up_weight = None
                weights_loaded = False
    if verbose:
        print(verbose_str)

        print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    print("resizing complete")
    return o_lora_sd, network_dim, new_alpha


def resize(args):

    def str_to_dtype(p):
        if p == 'float':
            return torch.float
        if p == 'fp16':
            return torch.float16
        if p == 'bf16':
            return torch.bfloat16
        return None

    if args.dynamic_method and not args.dynamic_param:
        raise Exception("If using dynamic_method, then dynamic_param is required")

    merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    print("loading Model...")
    lora_sd, metadata = load_state_dict(args.model, merge_dtype)

    print("resizing rank...")
    state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)

    # update metadata
    if metadata is None:
        metadata = {}

    comment = metadata.get("ss_training_comment", "")

    if not args.dynamic_method:
        metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
        metadata["ss_network_dim"] = str(args.new_rank)
        metadata["ss_network_alpha"] = str(new_alpha)
    else:
        metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
        metadata["ss_network_dim"] = 'Dynamic'
        metadata["ss_network_alpha"] = 'Dynamic'

    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
    metadata["sshs_model_hash"] = model_hash
    metadata["sshs_legacy_hash"] = legacy_hash

    print(f"saving model to: {args.save_to}")
    save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--save_precision", type=str, default=None,
                        choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
    parser.add_argument("--new_rank", type=int, default=4,
                        help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
    parser.add_argument("--save_to", type=str, default=None,
                        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
    parser.add_argument("--model", type=str, default=None,
                        help="LoRA model to resize to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
    parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
    parser.add_argument("--verbose", action="store_true",
                        help="Display verbose resizing information / rank変更時の詳細情報を出力する")
    parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
                        help="Specify dynamic resizing method, will override --new_rank")
    parser.add_argument("--dynamic_param", type=float, default=None,
                        help="Specify target for dynamic reduction")

    args = parser.parse_args()
    resize(args)