Fix again the custom model config load
Update resize lora
This commit is contained in:
parent
c61ad5f8f9
commit
4c1448be72
@ -272,7 +272,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path, v2, v_parameterization):
|
def set_pretrained_model_name_or_path_input(model_list, pretrained_model_name_or_path, v2, v_parameterization):
|
||||||
# define a list of substrings to search for
|
# define a list of substrings to search for
|
||||||
substrings_v2 = [
|
substrings_v2 = [
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
@ -280,12 +280,12 @@ def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path
|
|||||||
]
|
]
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||||
if str(value) in substrings_v2:
|
if str(model_list) in substrings_v2:
|
||||||
print('SD v2 model detected. Setting --v2 parameter')
|
print('SD v2 model detected. Setting --v2 parameter')
|
||||||
v2 = True
|
v2 = True
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
return model_list, v2, v_parameterization
|
||||||
|
|
||||||
# define a list of substrings to search for v-objective
|
# define a list of substrings to search for v-objective
|
||||||
substrings_v_parameterization = [
|
substrings_v_parameterization = [
|
||||||
@ -294,14 +294,14 @@ def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path
|
|||||||
]
|
]
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
||||||
if str(value) in substrings_v_parameterization:
|
if str(model_list) in substrings_v_parameterization:
|
||||||
print(
|
print(
|
||||||
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
|
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
|
||||||
)
|
)
|
||||||
v2 = True
|
v2 = True
|
||||||
v_parameterization = True
|
v_parameterization = True
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
return model_list, v2, v_parameterization
|
||||||
|
|
||||||
# define a list of substrings to v1.x
|
# define a list of substrings to v1.x
|
||||||
substrings_v1_model = [
|
substrings_v1_model = [
|
||||||
@ -309,19 +309,18 @@ def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path
|
|||||||
'runwayml/stable-diffusion-v1-5',
|
'runwayml/stable-diffusion-v1-5',
|
||||||
]
|
]
|
||||||
|
|
||||||
if str(value) in substrings_v1_model:
|
if str(model_list) in substrings_v1_model:
|
||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
return model_list, v2, v_parameterization
|
||||||
|
|
||||||
if value == 'custom':
|
if model_list == 'custom':
|
||||||
if str(pretrained_model_name_or_path) in substrings_v1_model or str(pretrained_model_name_or_path) in substrings_v2 or str(pretrained_model_name_or_path) in substrings_v_parameterization:
|
if str(pretrained_model_name_or_path) in substrings_v1_model or str(pretrained_model_name_or_path) in substrings_v2 or str(pretrained_model_name_or_path) in substrings_v_parameterization:
|
||||||
value = ''
|
pretrained_model_name_or_path = ''
|
||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
return pretrained_model_name_or_path, v2, v_parameterization
|
||||||
return value, v2, v_parameterization
|
|
||||||
|
|
||||||
###
|
###
|
||||||
### Gradio common GUI section
|
### Gradio common GUI section
|
||||||
|
@ -3,11 +3,11 @@
|
|||||||
# Thanks to cloneofsimo and kohya
|
# Thanks to cloneofsimo and kohya
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file, safe_open
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from library import train_util, model_util
|
from library import train_util, model_util
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
@ -38,13 +38,34 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
|
def index_sv_cumulative(S, target):
|
||||||
|
original_sum = float(torch.sum(S))
|
||||||
|
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||||
|
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||||
|
if index >= len(S):
|
||||||
|
index = len(S) - 1
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def index_sv_fro(S, target):
|
||||||
|
S_squared = S.pow(2)
|
||||||
|
s_fro_sq = float(torch.sum(S_squared))
|
||||||
|
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
||||||
|
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||||
|
if index >= len(S):
|
||||||
|
index = len(S) - 1
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||||
network_alpha = None
|
network_alpha = None
|
||||||
network_dim = None
|
network_dim = None
|
||||||
verbose_str = "\n"
|
verbose_str = "\n"
|
||||||
ratio_flag = False
|
fro_list = []
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 1 # 0.99
|
||||||
|
|
||||||
# Extract loaded lora dim and alpha
|
# Extract loaded lora dim and alpha
|
||||||
for key, value in lora_sd.items():
|
for key, value in lora_sd.items():
|
||||||
@ -58,12 +79,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
|
|||||||
network_alpha = network_dim
|
network_alpha = network_dim
|
||||||
|
|
||||||
scale = network_alpha/network_dim
|
scale = network_alpha/network_dim
|
||||||
if not sv_ratio:
|
|
||||||
|
if dynamic_method:
|
||||||
|
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}")
|
||||||
|
else:
|
||||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||||
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
|
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
|
||||||
else:
|
|
||||||
print(f"Dynamically determining new alphas and dims based off sv ratio: {sv_ratio}")
|
|
||||||
ratio_flag = True
|
|
||||||
|
|
||||||
lora_down_weight = None
|
lora_down_weight = None
|
||||||
lora_up_weight = None
|
lora_up_weight = None
|
||||||
@ -101,22 +122,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
|
|||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||||
|
|
||||||
if ratio_flag:
|
if dynamic_method=="sv_ratio":
|
||||||
# Calculate new dim and alpha for dynamic sizing
|
# Calculate new dim and alpha based off ratio
|
||||||
max_sv = S[0]
|
max_sv = S[0]
|
||||||
min_sv = max_sv/sv_ratio
|
min_sv = max_sv/dynamic_param
|
||||||
new_rank = torch.sum(S > min_sv).item()
|
new_rank = torch.sum(S > min_sv).item()
|
||||||
new_rank = max(new_rank, 1)
|
new_rank = max(new_rank, 1)
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
elif dynamic_method=="sv_cumulative":
|
||||||
|
# Calculate new dim and alpha based off cumulative sum
|
||||||
|
new_rank = index_sv_cumulative(S, dynamic_param)
|
||||||
|
new_rank = max(new_rank, 1)
|
||||||
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
elif dynamic_method=="sv_fro":
|
||||||
|
# Calculate new dim and alpha based off sqrt sum of squares
|
||||||
|
new_rank = index_sv_fro(S, dynamic_param)
|
||||||
|
new_rank = max(new_rank, 1)
|
||||||
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
s_sum = torch.sum(torch.abs(S))
|
s_sum = torch.sum(torch.abs(S))
|
||||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||||
verbose_str+=f"{block_down_name:75} | "
|
|
||||||
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
|
|
||||||
|
|
||||||
if verbose and ratio_flag:
|
S_squared = S.pow(2)
|
||||||
verbose_str+=f", dynamic| dim: {new_rank}, alpha: {new_alpha}\n"
|
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||||
|
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||||
|
fro_percent = float(s_red_fro/s_fro)
|
||||||
|
if not np.isnan(fro_percent):
|
||||||
|
fro_list.append(float(fro_percent))
|
||||||
|
|
||||||
|
verbose_str+=f"{block_down_name:75} | "
|
||||||
|
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, fro retained: {fro_percent:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
|
||||||
|
|
||||||
|
|
||||||
|
if verbose and dynamic_method:
|
||||||
|
verbose_str+=f", dynamic | dim: {new_rank}, alpha: {new_alpha}\n"
|
||||||
else:
|
else:
|
||||||
verbose_str+=f"\n"
|
verbose_str+=f"\n"
|
||||||
|
|
||||||
@ -153,6 +195,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
|
|||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(verbose_str)
|
print(verbose_str)
|
||||||
|
|
||||||
|
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||||
print("resizing complete")
|
print("resizing complete")
|
||||||
return o_lora_sd, network_dim, new_alpha
|
return o_lora_sd, network_dim, new_alpha
|
||||||
|
|
||||||
@ -168,6 +212,9 @@ def resize(args):
|
|||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if args.dynamic_method and not args.dynamic_param:
|
||||||
|
raise Exception("If using dynamic_method, then dynamic_param is required")
|
||||||
|
|
||||||
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
||||||
save_dtype = str_to_dtype(args.save_precision)
|
save_dtype = str_to_dtype(args.save_precision)
|
||||||
if save_dtype is None:
|
if save_dtype is None:
|
||||||
@ -177,19 +224,20 @@ def resize(args):
|
|||||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||||
|
|
||||||
print("resizing rank...")
|
print("resizing rank...")
|
||||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.sv_ratio, args.verbose)
|
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||||
|
|
||||||
# update metadata
|
# update metadata
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
comment = metadata.get("ss_training_comment", "")
|
comment = metadata.get("ss_training_comment", "")
|
||||||
if not args.sv_ratio:
|
|
||||||
|
if not args.dynamic_method:
|
||||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||||
metadata["ss_network_dim"] = str(args.new_rank)
|
metadata["ss_network_dim"] = str(args.new_rank)
|
||||||
metadata["ss_network_alpha"] = str(new_alpha)
|
metadata["ss_network_alpha"] = str(new_alpha)
|
||||||
else:
|
else:
|
||||||
metadata["ss_training_comment"] = f"Dynamic resize from {old_dim} with ratio {args.sv_ratio}; {comment}"
|
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
||||||
metadata["ss_network_dim"] = 'Dynamic'
|
metadata["ss_network_dim"] = 'Dynamic'
|
||||||
metadata["ss_network_alpha"] = 'Dynamic'
|
metadata["ss_network_alpha"] = 'Dynamic'
|
||||||
|
|
||||||
@ -215,8 +263,11 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
parser.add_argument("--verbose", action="store_true",
|
parser.add_argument("--verbose", action="store_true",
|
||||||
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
||||||
parser.add_argument("--sv_ratio", type=float, default=None,
|
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||||
help="Specify svd ratio for dim calcs. Will override --new_rank")
|
help="Specify dynamic resizing method, will override --new_rank")
|
||||||
|
parser.add_argument("--dynamic_param", type=float, default=None,
|
||||||
|
help="Specify target for dynamic reduction")
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
resize(args)
|
resize(args)
|
Loading…
Reference in New Issue
Block a user