diff --git a/library/common_gui.py b/library/common_gui.py
index 497be05..331ad07 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -272,7 +272,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
     )
 
 
-def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path, v2, v_parameterization):
+def set_pretrained_model_name_or_path_input(model_list, pretrained_model_name_or_path, v2, v_parameterization):
     # define a list of substrings to search for
     substrings_v2 = [
         'stabilityai/stable-diffusion-2-1-base',
@@ -280,12 +280,12 @@ def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path
     ]
 
     # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
-    if str(value) in substrings_v2:
+    if str(model_list) in substrings_v2:
         print('SD v2 model detected. Setting --v2 parameter')
         v2 = True
         v_parameterization = False
 
-        return value, v2, v_parameterization
+        return model_list, v2, v_parameterization
 
     # define a list of substrings to search for v-objective
     substrings_v_parameterization = [
@@ -294,14 +294,14 @@ def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path
     ]
 
     # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
-    if str(value) in substrings_v_parameterization:
+    if str(model_list) in substrings_v_parameterization:
         print(
             'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
         )
         v2 = True
         v_parameterization = True
 
-        return value, v2, v_parameterization
+        return model_list, v2, v_parameterization
 
     # define a list of substrings to v1.x
     substrings_v1_model = [
@@ -309,19 +309,18 @@ def set_pretrained_model_name_or_path_input(value, pretrained_model_name_or_path
         'runwayml/stable-diffusion-v1-5',
     ]
 
-    if str(value) in substrings_v1_model:
+    if str(model_list) in substrings_v1_model:
         v2 = False
         v_parameterization = False
 
-        return value, v2, v_parameterization
+        return model_list, v2, v_parameterization
 
-    if value == 'custom':
+    if model_list == 'custom':
         if str(pretrained_model_name_or_path) in substrings_v1_model or str(pretrained_model_name_or_path) in substrings_v2 or str(pretrained_model_name_or_path) in substrings_v_parameterization:
-            value = ''
+            pretrained_model_name_or_path = ''
             v2 = False
             v_parameterization = False
-
-    return value, v2, v_parameterization
+    return pretrained_model_name_or_path, v2, v_parameterization
 
 ###
 ### Gradio common GUI section
diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index 5ab7059..7d127ad 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -3,11 +3,11 @@
 # Thanks to cloneofsimo and kohya
 
 import argparse
-import os
 import torch
 from safetensors.torch import load_file, save_file, safe_open
 from tqdm import tqdm
 from library import train_util, model_util
+import numpy as np
 
 
 def load_state_dict(file_name, dtype):
@@ -38,13 +38,34 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
   torch.save(model, file_name)
 
 
-def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
+def index_sv_cumulative(S, target):
+  original_sum = float(torch.sum(S))
+  cumulative_sums = torch.cumsum(S, dim=0)/original_sum
+  index = int(torch.searchsorted(cumulative_sums, target)) + 1
+  if index >= len(S):
+    index = len(S) - 1
+
+  return index
+
+
+def index_sv_fro(S, target):
+  S_squared = S.pow(2)
+  s_fro_sq = float(torch.sum(S_squared))
+  sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
+  index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
+  if index >= len(S):
+    index = len(S) - 1
+
+  return index
+
+
+def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
   network_alpha = None
   network_dim = None
   verbose_str = "\n"
-  ratio_flag = False
+  fro_list = []
 
-  CLAMP_QUANTILE = 0.99
+  CLAMP_QUANTILE = 1 # 0.99
 
   # Extract loaded lora dim and alpha
   for key, value in lora_sd.items():
@@ -58,12 +79,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
     network_alpha = network_dim
   scale = network_alpha/network_dim
 
-  if not sv_ratio:
+
+  if dynamic_method:
+    print(f"Dynamically determining new alphas and dims based on {dynamic_method}: {dynamic_param}")
+  else:
     new_alpha = float(scale*new_rank) # calculate new alpha from scale
     print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
-  else:
-    print(f"Dynamically determining new alphas and dims based off sv ratio: {sv_ratio}")
-    ratio_flag = True
 
   lora_down_weight = None
   lora_up_weight = None
@@ -101,22 +122,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
 
     U, S, Vh = torch.linalg.svd(full_weight_matrix)
 
-    if ratio_flag:
-      # Calculate new dim and alpha for dynamic sizing
+    if dynamic_method=="sv_ratio":
+      # Calculate new dim and alpha based on singular value ratio
       max_sv = S[0]
-      min_sv = max_sv/sv_ratio
+      min_sv = max_sv/dynamic_param
       new_rank = torch.sum(S > min_sv).item()
       new_rank = max(new_rank, 1)
       new_alpha = float(scale*new_rank)
 
+    elif dynamic_method=="sv_cumulative":
+      # Calculate new dim and alpha based on cumulative sum of singular values
+      new_rank = index_sv_cumulative(S, dynamic_param)
+      new_rank = max(new_rank, 1)
+      new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_fro":
+      # Calculate new dim and alpha based on the Frobenius norm (sqrt of the sum of squares)
+      new_rank = index_sv_fro(S, dynamic_param)
+      new_rank = max(new_rank, 1)
+      new_alpha = float(scale*new_rank)
+
     if verbose:
      s_sum = torch.sum(torch.abs(S))
       s_rank = torch.sum(torch.abs(S[:new_rank]))
-      verbose_str+=f"{block_down_name:75} | "
-      verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
+
+      S_squared = S.pow(2)
+      s_fro = torch.sqrt(torch.sum(S_squared))
+      s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
+      fro_percent = float(s_red_fro/s_fro)
+      if not np.isnan(fro_percent):
+        fro_list.append(float(fro_percent))
 
-    if verbose and ratio_flag:
-      verbose_str+=f", dynamic| dim: {new_rank}, alpha: {new_alpha}\n"
+      verbose_str+=f"{block_down_name:75} | "
+      verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, fro retained: {fro_percent:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
+
+
+    if verbose and dynamic_method:
+      verbose_str+=f", dynamic | dim: {new_rank}, alpha: {new_alpha}\n"
     else:
      verbose_str+=f"\n"
 
@@ -153,6 +195,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
 
   if verbose:
     print(verbose_str)
+
+    print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
   print("resizing complete")
   return o_lora_sd, network_dim, new_alpha
 
@@ -168,6 +212,9 @@ def resize(args):
         return torch.bfloat16
      return None
 
+  if args.dynamic_method and not args.dynamic_param:
+    raise Exception("If using dynamic_method, then dynamic_param is required")
+
   merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
   save_dtype = str_to_dtype(args.save_precision)
   if save_dtype is None:
@@ -177,19 +224,20 @@ def resize(args):
 
   lora_sd, metadata = load_state_dict(args.model, merge_dtype)
 
   print("resizing rank...")
-  state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.sv_ratio, args.verbose)
+  state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
 
   # update metadata
   if metadata is None:
     metadata = {}
 
   comment = metadata.get("ss_training_comment", "")
-  if not args.sv_ratio:
+
+  if not args.dynamic_method:
     metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
     metadata["ss_network_dim"] = str(args.new_rank)
     metadata["ss_network_alpha"] = str(new_alpha)
   else:
-    metadata["ss_training_comment"] = f"Dynamic resize from {old_dim} with ratio {args.sv_ratio}; {comment}"
+    metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
     metadata["ss_network_dim"] = 'Dynamic'
     metadata["ss_network_alpha"] = 'Dynamic'
@@ -215,8 +263,11 @@ if __name__ == '__main__':
   parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
   parser.add_argument("--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する")
 
-  parser.add_argument("--sv_ratio", type=float, default=None,
-                      help="Specify svd ratio for dim calcs. Will override --new_rank")
+  parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
+                      help="Specify dynamic resizing method; will override --new_rank")
+  parser.add_argument("--dynamic_param", type=float, default=None,
+                      help="Specify target for dynamic reduction")
+
 
   args = parser.parse_args()
   resize(args)
\ No newline at end of file
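Usage note (illustrative, not part of the patch): --dynamic_method selects how each module's new rank is chosen and --dynamic_param supplies the target for that method. For sv_ratio the parameter is the largest allowed ratio between the top singular value and the smallest one kept (new_rank counts values above max(S)/dynamic_param); for sv_cumulative it is the fraction of sum(S) to retain; for sv_fro it is the fraction of the Frobenius norm to retain. Setting --dynamic_method overrides --new_rank, and --dynamic_param then becomes mandatory (the script raises an Exception otherwise).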
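Below is a minimal standalone sketch of how the two new rank-selection helpers behave, assuming a toy singular value spectrum; the helper bodies are copied from the patch above, while the spectrum and targets are made-up numbers for illustration only.

import torch

# Toy spectrum in descending order, as returned by torch.linalg.svd.
S = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.25, 0.1])

def index_sv_cumulative(S, target):
  # Smallest rank whose leading values reach `target` of sum(S),
  # capped just below full rank.
  original_sum = float(torch.sum(S))
  cumulative_sums = torch.cumsum(S, dim=0)/original_sum
  index = int(torch.searchsorted(cumulative_sums, target)) + 1
  if index >= len(S):
    index = len(S) - 1
  return index

def index_sv_fro(S, target):
  # Smallest rank retaining `target` of the Frobenius norm,
  # i.e. target**2 of sum(S**2); same cap as above.
  S_squared = S.pow(2)
  s_fro_sq = float(torch.sum(S_squared))
  sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
  index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
  if index >= len(S):
    index = len(S) - 1
  return index

print(index_sv_cumulative(S, 0.90))  # 4: ranks 1-4 cover ~95.5% >= 90% of sum(S)
print(index_sv_fro(S, 0.95))         # 2: ranks 1-2 keep ~96.8% >= 95% of the Frobenius norm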