From c926c9d8773146a0912a698ce11ef718b21c4e09 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Thu, 2 Mar 2023 14:36:07 -0500 Subject: [PATCH] Update Readme --- README.md | 2 +- locon | 1 + networks/extract_lora_from_models copy.py | 194 ++++++++++++++++++++++ networks/resize_lora.py | 46 +++-- 4 files changed, 231 insertions(+), 12 deletions(-) create mode 160000 locon create mode 100644 networks/extract_lora_from_models copy.py diff --git a/README.md b/README.md index 0099c4c..1dd2ee9 100644 --- a/README.md +++ b/README.md @@ -164,7 +164,7 @@ This will store your a backup file with your current locally installed pip packa ## Change History * 2023/03/02 (v21.1.0): - - Add LoCon support (https://github.com/KohakuBlueleaf/LoCon.git) to the Dreambooth LoRA tab. This will allow to create a new type of LoRA that include conv layers as part of the LoRA... hence the name LoCon. + - Add LoCon support (https://github.com/KohakuBlueleaf/LoCon.git) to the Dreambooth LoRA tab. This will allow to create a new type of LoRA that include conv layers as part of the LoRA... hence the name LoCon. LoCon will work with the native Auto1111 implementation of LoRA. If you want to use it with the Kohya_ss additionalNetwork you will need to install this other extension... until Kohya_ss support it nativelly: https://github.com/KohakuBlueleaf/a1111-sd-webui-locon * 2023/03/01 (v21.0.1): - Add warning to tensorboard start if the log information is missing - Fix issue with 8bitadam on older config file load diff --git a/locon b/locon new file mode 160000 index 0000000..143b7b1 --- /dev/null +++ b/locon @@ -0,0 +1 @@ +Subproject commit 143b7b1e33a4253b13f45526de41df748b97e585 diff --git a/networks/extract_lora_from_models copy.py b/networks/extract_lora_from_models copy.py new file mode 100644 index 0000000..aacd21b --- /dev/null +++ b/networks/extract_lora_from_models copy.py @@ -0,0 +1,194 @@ +# extract approximating LoRA by svd from two SD models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import library.model_util as model_util +import lora +import numpy as np + + +CLAMP_QUANTILE = 1 # 0.99 +MIN_DIFF = 1e-6 + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def svd(args): + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + save_dtype = str_to_dtype(args.save_precision) + + print(f"loading SD model : {args.model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + print(f"loading SD model : {args.model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + + # create LoRA network to extract weights: Use dim (rank) as alpha + lora_network_o = lora.create_network(1.0, args.dim, args.dim * 1.5, None, text_encoder_o, unet_o) + lora_network_t = lora.create_network(1.0, args.dim, args.dim * 1.5, None, text_encoder_t, unet_t) + assert len(lora_network_o.text_encoder_loras) == len( + lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " + + # get diffs + diffs = {} + text_encoder_different = False + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + + # Text Encoder might be same + if torch.max(torch.abs(diff)) > MIN_DIFF: + text_encoder_different = True + + diff = diff.float() + diffs[lora_name] = diff + + if not text_encoder_different: + print("Text encoder is same. Extract U-Net only.") + lora_network_o.text_encoder_loras = [] + diffs = {} + + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + diff = diff.float() + + if args.device: + diff = diff.to(args.device) + + diffs[lora_name] = diff + + # make LoRA with SVD + print("calculating by SVD") + rank = args.dim + lora_weights = {} + with torch.no_grad(): + for lora_name, mat in tqdm(list(diffs.items())): + conv2d = (len(mat.size()) == 4) + if conv2d: + mat = mat.squeeze() + + U, S, Vt = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vt = Vt[:rank, :] + + lora_weights[lora_name] = (U, Vt) + + # # make LoRA with svd + # print("calculating by svd") + # rank = args.dim + # lora_weights = {} + # with torch.no_grad(): + # for lora_name, mat in tqdm(list(diffs.items())): + # conv2d = (len(mat.size()) == 4) + # if conv2d: + # mat = mat.squeeze() + + # U, S, Vh = torch.linalg.svd(mat) + + # U = U[:, :rank] + # S = S[:rank] + # U = U @ torch.diag(S) + + # Vh = Vh[:rank, :] + + # # create new tensors directly from the numpy arrays + # U = torch.as_tensor(U) + # Vh = torch.as_tensor(Vh) + + # # dist = torch.cat([U.flatten(), Vh.flatten()]) + # # hi_val = torch.quantile(dist, CLAMP_QUANTILE) + # # low_val = -hi_val + + # # U = U.clamp(low_val, hi_val) + # # Vh = Vh.clamp(low_val, hi_val) + + # # # soft thresholding + # # alpha = S[-1] / 1000.0 # adjust this parameter as needed + # # U = torch.sign(U) * torch.nn.functional.relu(torch.abs(U) - alpha) + # # Vh = torch.sign(Vh) * torch.nn.functional.relu(torch.abs(Vh) - alpha) + + # lora_weights[lora_name] = (U, Vh) + + # make state dict for LoRA + lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict + lora_sd = lora_network_o.state_dict() + print(f"LoRA has {len(lora_sd)} weights.") + + for key in list(lora_sd.keys()): + if "alpha" in key: + continue + + lora_name = key.split('.')[0] + i = 0 if "lora_up" in key else 1 + + weights = lora_weights[lora_name][i] + # print(key, i, weights.size(), lora_sd[key].size()) + if len(lora_sd[key].size()) == 4: + weights = weights.unsqueeze(2).unsqueeze(3) + + assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" + lora_sd[key] = weights + + # load state dict to LoRA and save it + info = lora_network_o.load_state_dict(lora_sd) + print(f"Loading extracted LoRA weights: {info}") + + dir_name = os.path.dirname(args.save_to) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + # minimum metadata + metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim * 1.5)} + + lora_network_o.save_weights(args.save_to, save_dtype, metadata) + print(f"LoRA weights are saved to: {args.save_to}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat") + parser.add_argument("--model_org", type=str, default=None, + help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors") + parser.add_argument("--model_tuned", type=str, default=None, + help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + + args = parser.parse_args() + svd(args) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 271de8e..5ab7059 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -38,10 +38,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) -def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): +def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose): network_alpha = None network_dim = None verbose_str = "\n" + ratio_flag = False CLAMP_QUANTILE = 0.99 @@ -57,9 +58,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): network_alpha = network_dim scale = network_alpha/network_dim - new_alpha = float(scale*new_rank) # calculate new alpha from scale - - print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}") + if not sv_ratio: + new_alpha = float(scale*new_rank) # calculate new alpha from scale + print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}") + else: + print(f"Dynamically determining new alphas and dims based off sv ratio: {sv_ratio}") + ratio_flag = True lora_down_weight = None lora_up_weight = None @@ -97,11 +101,24 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): U, S, Vh = torch.linalg.svd(full_weight_matrix) + if ratio_flag: + # Calculate new dim and alpha for dynamic sizing + max_sv = S[0] + min_sv = max_sv/sv_ratio + new_rank = torch.sum(S > min_sv).item() + new_rank = max(new_rank, 1) + new_alpha = float(scale*new_rank) + if verbose: s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) - verbose_str+=f"{block_down_name:76} | " - verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" + verbose_str+=f"{block_down_name:75} | " + verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}" + + if verbose and ratio_flag: + verbose_str+=f", dynamic| dim: {new_rank}, alpha: {new_alpha}\n" + else: + verbose_str+=f"\n" U = U[:, :new_rank] S = S[:new_rank] @@ -160,16 +177,21 @@ def resize(args): lora_sd, metadata = load_state_dict(args.model, merge_dtype) print("resizing rank...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.sv_ratio, args.verbose) # update metadata if metadata is None: metadata = {} comment = metadata.get("ss_training_comment", "") - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) + if not args.sv_ratio: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = f"Dynamic resize from {old_dim} with ratio {args.sv_ratio}; {comment}" + metadata["ss_network_dim"] = 'Dynamic' + metadata["ss_network_alpha"] = 'Dynamic' model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -193,6 +215,8 @@ if __name__ == '__main__': parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する") + parser.add_argument("--sv_ratio", type=float, default=None, + help="Specify svd ratio for dim calcs. Will override --new_rank") args = parser.parse_args() - resize(args) + resize(args) \ No newline at end of file