From fc5d2b2c31c7170c91e82e66b2b6b3024fce3918 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Fri, 10 Mar 2023 11:44:52 -0500 Subject: [PATCH 1/3] Update to sd-script dev code base --- README.md | 3 + library/resize_lora_gui.py | 6 +- library/svd_merge_lora_gui.py | 187 +++++++++++++++++++++ library/train_util.py | 35 +++- lora_gui.py | 2 + networks/extract_lora_from_models.py | 35 ++-- networks/lora.py | 49 +++--- networks/resize_lora.py | 237 +++++++++++++++++++++------ networks/svd_merge_lora.py | 21 +-- train_README-ja.md | 8 + train_network.py | 41 +++-- train_network_README-ja.md | 4 + 12 files changed, 499 insertions(+), 129 deletions(-) create mode 100644 library/svd_merge_lora_gui.py diff --git a/README.md b/README.md index a468a4e..40a7b4a 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/10 (v21.2.1): + - Update to latest sd-script code + - Add support for SVD based LoRA merge * 2023/03/09 (v21.2.0): - Fix issue https://github.com/bmaltais/kohya_ss/issues/335 - Add option to print LoRA trainer command without executing it diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index a94fdb7..6b4396b 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -148,11 +148,11 @@ def gradio_resize_lora_tab(): value='fp16', interactive=True, ) - device = gr.Textbox( + device = gr.Dropdown( label='Device', - placeholder='{Optional) device to use, cuda for GPU. Default: cuda', - interactive=True, + choices=['cpu', 'cuda',], value='cuda', + interactive=True, ) convert_button = gr.Button('Resize model') diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py new file mode 100644 index 0000000..b34b503 --- /dev/null +++ b/library/svd_merge_lora_gui.py @@ -0,0 +1,187 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def svd_merge_lora( + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, + new_rank, + new_conv_rank, + device, +): + # Check for caption_text_input + if lora_a_model == '': + msgbox('Invalid model A file') + return + + if lora_b_model == '': + msgbox('Invalid model B file') + return + + # Check if source model exist + if not os.path.isfile(lora_a_model): + msgbox('The provided model A is not a file') + return + + if not os.path.isfile(lora_b_model): + msgbox('The provided model B is not a file') + return + + ratio_a = ratio + ratio_b = 1 - ratio + + run_cmd = f'{PYTHON} "{os.path.join("networks","svd_merge_lora.py")}"' + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --precision {precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --models "{lora_a_model}" "{lora_b_model}"' + run_cmd += f' --ratios {ratio_a} {ratio_b}' + run_cmd += f' --device {device}' + run_cmd += f' --new_rank "{new_rank}"' + run_cmd += f' --new_conv_rank "{new_conv_rank}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_svd_merge_lora_tab(): + with gr.Tab('Merge LoRA (SVD)'): + gr.Markdown('This utility can merge two LoRA networks together.') + + lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + lora_a_model = gr.Textbox( + label='LoRA model "A"', + placeholder='Path to the LoRA A model', + interactive=True, + ) + button_lora_a_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_a_model_file.click( + get_file_path, + inputs=[lora_a_model, lora_ext, lora_ext_name], + outputs=lora_a_model, + show_progress=False, + ) + + lora_b_model = gr.Textbox( + label='LoRA model "B"', + placeholder='Path to the LoRA B model', + interactive=True, + ) + button_lora_b_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_b_model_file.click( + get_file_path, + inputs=[lora_b_model, lora_ext, lora_ext_name], + outputs=lora_b_model, + show_progress=False, + ) + with gr.Row(): + ratio = gr.Slider( + label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B', + minimum=0, + maximum=1, + step=0.01, + value=0.5, + interactive=True, + ) + new_rank = gr.Slider( + label='New Rank', + minimum=1, + maximum=1024, + step=1, + value=128, + interactive=True, + ) + new_conv_rank = gr.Slider( + label='New Conv Rank', + minimum=1, + maximum=1024, + step=1, + value=128, + interactive=True, + ) + + with gr.Row(): + save_to = gr.Textbox( + label='Save to', + placeholder='path for the file to save...', + interactive=True, + ) + button_save_to = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_save_to.click( + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, + show_progress=False, + ) + precision = gr.Dropdown( + label='Merge precision', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + save_precision = gr.Dropdown( + label='Save precision', + choices=['fp16', 'bf16', 'float'], + value='float', + interactive=True, + ) + device = gr.Dropdown( + label='Device', + choices=['cpu', 'cuda',], + value='cuda', + interactive=True, + ) + + convert_button = gr.Button('Merge model') + + convert_button.click( + svd_merge_lora, + inputs=[ + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, + new_rank, + new_conv_rank, + device, + ], + show_progress=False, + ) diff --git a/library/train_util.py b/library/train_util.py index 6af1abe..718fe36 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -912,10 +912,14 @@ class FineTuningDataset(BaseDataset): if os.path.exists(image_key): abs_path = image_key else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + else: + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" + abs_path = abs_path[0] caption = img_md.get('caption') tags = img_md.get('tags') @@ -1757,15 +1761,22 @@ def get_optimizer(args, trainable_params): raise ImportError("No dadaptation / dadaptation がインストールされていないようです") print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - min_lr = lr + actual_lr = lr + lr_count = 1 if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) for group in trainable_params: - min_lr = min(min_lr, group.get("lr", lr)) + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) - if min_lr <= 0.1: + if actual_lr <= 0.1: print( - f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}') + f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}') print('recommend option: lr=1.0 / 推奨は1.0です') + if lr_count > 1: + print( + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}") optimizer_class = dadaptation.DAdaptAdam optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -2296,6 +2307,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v with torch.no_grad(): with accelerator.autocast(): for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue prompt = prompt.strip() if len(prompt) == 0 or prompt[0] == '#': continue @@ -2355,6 +2368,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v height = max(64, height - height % 8) # round to divisible by 8 width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) diff --git a/lora_gui.py b/lora_gui.py index d175b30..49918de 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -38,6 +38,7 @@ from library.tensorboard_gui import ( from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab from library.merge_lora_gui import gradio_merge_lora_tab +from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab from library.sampler_gui import sample_gradio_config, run_cmd_sample @@ -879,6 +880,7 @@ def lora_tab( ) gradio_dataset_balancing_tab() gradio_merge_lora_tab() + gradio_svd_merge_lora_tab() gradio_resize_lora_tab() gradio_verify_lora_tab() diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 9f40978..28b905f 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -103,7 +103,8 @@ def svd(args): if args.device: mat = mat.to(args.device) - # print(mat.size(), mat.device, rank, in_dim, out_dim) + + # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -112,7 +113,7 @@ def svd(args): else: mat = mat.squeeze() - U, S, Vh = torch.linalg.svd(mat) + U, S, Vh = torch.linalg.svd(mat.to("cuda")) U = U[:, :rank] S = S[:rank] @@ -137,27 +138,17 @@ def svd(args): lora_weights[lora_name] = (U, Vh) # make state dict for LoRA - lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict - lora_sd = lora_network_o.state_dict() - print(f"LoRA has {len(lora_sd)} weights.") - - for key in list(lora_sd.keys()): - if "alpha" in key: - continue - - lora_name = key.split('.')[0] - i = 0 if "lora_up" in key else 1 - - weights = lora_weights[lora_name][i] - # print(key, i, weights.size(), lora_sd[key].size()) - # if len(lora_sd[key].size()) == 4: - # weights = weights.unsqueeze(2).unsqueeze(3) - - assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" - lora_sd[key] = weights + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + '.lora_up.weight'] = up_weight + lora_sd[lora_name + '.lora_down.weight'] = down_weight + lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) # load state dict to LoRA and save it - info = lora_network_o.load_state_dict(lora_sd) + lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + + info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(args.save_to) @@ -167,7 +158,7 @@ def svd(args): # minimum metadata metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} - lora_network_o.save_weights(args.save_to, save_dtype, metadata) + lora_network_save.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") diff --git a/networks/lora.py b/networks/lora.py index c0181c0..6d3875d 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module): """ if alpha == 0 or None, alpha is rank (no scaling). """ super().__init__() self.lora_name = lora_name - self.lora_dim = lora_dim if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features - self.lora_dim = min(self.lora_dim, in_dim, out_dim) - if self.lora_dim != lora_dim: - print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + if org_module.__class__.__name__ == 'Conv2d': kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: - in_dim = org_module.in_features - out_dim = org_module.out_features - self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) - self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = lora_dim if alpha is None or alpha == 0 else alpha + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える @@ -149,12 +153,13 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un return network -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location='cpu') +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file, safe_open + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location='cpu') # get dim/alpha mapping modules_dim = {} @@ -174,7 +179,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa # support old LoRA without alpha for key in modules_dim.keys(): if key not in modules_alpha: - modules_alpha = modules_dim[key] + modules_alpha = modules_dim[key] network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) network.weights_sd = weights_sd @@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa class LoRANetwork(torch.nn.Module): # is it possible to apply conv_in and conv_out? - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' @@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module): text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") self.weights_sd = None @@ -371,7 +382,7 @@ class LoRANetwork(torch.nn.Module): else: torch.save(state_dict, file) - @staticmethod + @ staticmethod def set_regions(networks, image): image = image.astype(np.float32) / 255.0 for i, network in enumerate(networks[:3]): diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 271de8e..09a19c1 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -1,14 +1,15 @@ # Convert LoRA to different rank approximation (should only be used to go to lower rank) # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py -# Thanks to cloneofsimo and kohya +# Thanks to cloneofsimo import argparse -import os import torch from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm from library import train_util, model_util +import numpy as np +MIN_SV = 1e-6 def load_state_dict(file_name, dtype): if model_util.is_safetensors(file_name): @@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) -def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0)/original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + if index >= len(S): + index = len(S) - 1 + + return index + + +def index_sv_fro(S, target): + S_squared = S.pow(2) + s_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + if index >= len(S): + index = len(S) - 1 + + return index + + +# Modified from Kohaku-blueleaf's extract/merge functions +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size, kernel_size, _ = weight.size() + U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() + del U, S, Vh, weight + return param_dict + + +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size = weight.size() + + U, S, Vh = torch.linalg.svd(weight.to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() + del U, S, Vh, weight + return param_dict + + +def merge_conv(lora_down, lora_up, device): + in_rank, in_size, kernel_size, k_ = lora_down.shape + out_size, out_rank, _, _ = lora_up.shape + assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) + weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) + del lora_up, lora_down + return weight + + +def merge_linear(lora_down, lora_up, device): + in_rank, in_size = lora_down.shape + out_size, out_rank = lora_up.shape + assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + weight = lora_up @ lora_down + del lora_up, lora_down + return weight + + +def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): + param_dict = {} + + if dynamic_method=="sv_ratio": + # Calculate new dim and alpha based off ratio + max_sv = S[0] + min_sv = max_sv/dynamic_param + new_rank = max(torch.sum(S > min_sv).item(),1) + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_cumulative": + # Calculate new dim and alpha based off cumulative sum + new_rank = index_sv_cumulative(S, dynamic_param) + new_rank = max(new_rank, 1) + new_alpha = float(scale*new_rank) + + elif dynamic_method=="sv_fro": + # Calculate new dim and alpha based off sqrt sum of squares + new_rank = index_sv_fro(S, dynamic_param) + new_rank = min(max(new_rank, 1), len(S)-1) + new_alpha = float(scale*new_rank) + else: + new_rank = rank + new_alpha = float(scale*new_rank) + + + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + new_rank = 1 + new_alpha = float(scale*new_rank) + elif new_rank > rank: # cap max rank at rank + new_rank = rank + new_alpha = float(scale*new_rank) + + + # Calculate resize info + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) + fro_percent = float(s_red_fro/s_fro) + + param_dict["new_rank"] = new_rank + param_dict["new_alpha"] = new_alpha + param_dict["sum_retained"] = (s_rank)/s_sum + param_dict["fro_retained"] = fro_percent + param_dict["max_ratio"] = S[0]/S[new_rank] + + return param_dict + + +def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): network_alpha = None network_dim = None verbose_str = "\n" - - CLAMP_QUANTILE = 0.99 + fro_list = [] # Extract loaded lora dim and alpha for key, value in lora_sd.items(): @@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): network_alpha = network_dim scale = network_alpha/network_dim - new_alpha = float(scale*new_rank) # calculate new alpha from scale - print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}") + if dynamic_method: + print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") lora_down_weight = None lora_up_weight = None @@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): block_down_name = None block_up_name = None - print("resizing lora...") with torch.no_grad(): for key, value in tqdm(lora_sd.items()): if 'lora_down' in key: @@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): conv2d = (len(lora_down_weight.size()) == 4) if conv2d: - lora_down_weight = lora_down_weight.squeeze() - lora_up_weight = lora_up_weight.squeeze() - - if device: - org_device = lora_up_weight.device - lora_up_weight = lora_up_weight.to(args.device) - lora_down_weight = lora_down_weight.to(args.device) - - full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) - - U, S, Vh = torch.linalg.svd(full_weight_matrix) + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) if verbose: - s_sum = torch.sum(torch.abs(S)) - s_rank = torch.sum(torch.abs(S[:new_rank])) - verbose_str+=f"{block_down_name:76} | " - verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" + max_ratio = param_dict['max_ratio'] + sum_retained = param_dict['sum_retained'] + fro_retained = param_dict['fro_retained'] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) - U = U[:, :new_rank] - S = S[:new_rank] - U = U @ torch.diag(S) + verbose_str+=f"{block_down_name:75} | " + verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" - Vh = Vh[:new_rank, :] + if verbose and dynamic_method: + verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str+=f"\n" - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val - - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.unsqueeze(2).unsqueeze(3) - Vh = Vh.unsqueeze(2).unsqueeze(3) - - if device: - U = U.to(org_device) - Vh = Vh.to(org_device) - - o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype) + new_alpha = param_dict['new_alpha'] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) block_down_name = None block_up_name = None lora_down_weight = None lora_up_weight = None weights_loaded = False + del param_dict if verbose: print(verbose_str) + + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") print("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -151,6 +274,9 @@ def resize(args): return torch.bfloat16 return None + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: @@ -159,17 +285,23 @@ def resize(args): print("loading Model...") lora_sd, metadata = load_state_dict(args.model, merge_dtype) - print("resizing rank...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) + print("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) # update metadata if metadata is None: metadata = {} comment = metadata.get("ss_training_comment", "") - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) + + if not args.dynamic_method: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + metadata["ss_network_dim"] = 'Dynamic' + metadata["ss_network_alpha"] = 'Dynamic' model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -193,6 +325,11 @@ if __name__ == '__main__': parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する") + parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") + parser.add_argument("--dynamic_param", type=float, default=None, + help="Specify target for dynamic reduction") + args = parser.parse_args() resize(args) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index c8e39b8..d907b43 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype): return sd -def save_to_file(file_name, model, state_dict, dtype): +def save_to_file(file_name, state_dict, dtype): if dtype is not None: for key in list(state_dict.keys()): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) + save_file(state_dict, file_name) else: - torch.save(model, file_name) + torch.save(state_dict, file_name) def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): @@ -76,7 +76,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty down_weight = down_weight.to(device) # W <- W + U * D - scale = (alpha / network_dim) + scale = (alpha / network_dim).to(device) if not conv2d: # linear weight = weight + ratio * (up_weight @ down_weight) * scale elif kernel_size == (1, 1): @@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty mat = mat.squeeze() module_new_rank = new_conv_rank if conv2d_3x3 else new_rank + module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim U, S, Vh = torch.linalg.svd(mat) @@ -114,12 +115,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty Vh = Vh[:module_new_rank, :] - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val + # dist = torch.cat([U.flatten(), Vh.flatten()]) + # hi_val = torch.quantile(dist, CLAMP_QUANTILE) + # low_val = -hi_val - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) + # U = U.clamp(low_val, hi_val) + # Vh = Vh.clamp(low_val, hi_val) if conv2d: U = U.reshape(out_dim, module_new_rank, 1, 1) @@ -156,7 +157,7 @@ def merge(args): state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) + save_to_file(args.save_to, state_dict, save_dtype) if __name__ == '__main__': diff --git a/train_README-ja.md b/train_README-ja.md index 479f960..d5f1b5f 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -502,6 +502,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。 +- `--persistent_data_loader_workers` + + Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。 + +- `--max_data_loader_n_workers` + + データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。 + - `--logging_dir` / `--log_prefix` 学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 diff --git a/train_network.py b/train_network.py index cf64c89..5aa8af4 100644 --- a/train_network.py +++ b/train_network.py @@ -106,6 +106,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -134,6 +135,8 @@ def train(args): gc.collect() # prepare network + import sys + sys.path.append(os.path.dirname(__file__)) print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) @@ -175,12 +178,13 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) + if is_main_process: + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps, num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする @@ -251,15 +255,17 @@ def train(args): # 学習する # TODO: find a way to handle total batch size when there are multiple datasets total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + if is_main_process: + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") # TODO refactor metadata creation and move to util metadata = { @@ -471,7 +477,8 @@ def train(args): loss_list = [] loss_total = 0.0 for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") + if is_main_process: + print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) metadata["ss_epoch"] = str(epoch+1) @@ -583,9 +590,10 @@ def train(args): print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + if is_main_process: + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -594,7 +602,6 @@ def train(args): metadata["ss_epoch"] = str(num_train_epochs) metadata["ss_training_finished_at"] = str(time.time()) - is_main_process = accelerator.is_main_process if is_main_process: network = unwrap_model(network) diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 4a79a6f..79d1709 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -64,6 +64,10 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py * LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。 * `--network_alpha` * アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。 +* `--persistent_data_loader_workers` + * Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。 +* `--max_data_loader_n_workers` + * データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。 * `--network_weights` * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。 * `--network_train_unet_only` From d1962d72400237b42d69312a989e59ba9b5e7508 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Fri, 10 Mar 2023 11:49:05 -0500 Subject: [PATCH 2/3] Switch to networks version of resize lora --- library/resize_lora_gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index 6b4396b..527ff67 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -48,7 +48,7 @@ def resize_lora( if device == '': device = 'cuda' - run_cmd = f'{PYTHON} "{os.path.join("tools","resize_lora.py")}"' + run_cmd = f'{PYTHON} "{os.path.join("networks","resize_lora.py")}"' run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_to {save_to}' run_cmd += f' --model {model}' From a65555ea67c8e1519977cb91bfd9ba648350ee51 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Fri, 10 Mar 2023 20:05:38 -0500 Subject: [PATCH 3/3] Add support to load a config without opening the UI to get the file name --- dreambooth_gui.py | 19 ++++++++++++++++--- finetune_gui.py | 31 ++++++++++++++++++++++--------- library/common_gui.py | 3 +++ lora_gui.py | 19 ++++++++++++++++--- textual_inversion_gui.py | 19 ++++++++++++++++--- 5 files changed, 73 insertions(+), 18 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index df40784..dee017c 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -152,6 +152,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -213,9 +214,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -231,7 +236,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -506,6 +511,7 @@ def dreambooth_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -775,7 +781,14 @@ def dreambooth_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) diff --git a/finetune_gui.py b/finetune_gui.py index 3ef1cbd..59dffd8 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -149,7 +149,8 @@ def save_configuration( return file_path -def open_config_file( +def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -217,9 +218,13 @@ def open_config_file( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -235,7 +240,7 @@ def open_config_file( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -492,8 +497,8 @@ def remove_doublequote(file_path): def finetune_tab(): - dummy_ft_true = gr.Label(value=True, visible=False) - dummy_ft_false = gr.Label(value=False, visible=False) + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) gr.Markdown('Train a custom model using kohya finetune python code...') ( @@ -501,6 +506,7 @@ def finetune_tab(): button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -770,22 +776,29 @@ def finetune_tab(): button_run.click(train_model, inputs=settings_list) button_open_config.click( - open_config_file, - inputs=[config_file_name] + settings_list, + open_configuration, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_save_config.click( save_configuration, - inputs=[dummy_ft_false, config_file_name] + settings_list, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name], show_progress=False, ) button_save_as_config.click( save_configuration, - inputs=[dummy_ft_true, config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name], show_progress=False, ) diff --git a/library/common_gui.py b/library/common_gui.py index b22594f..e200141 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -405,11 +405,14 @@ def gradio_config(): placeholder="type the configuration file path or use the 'Open' button above to select it...", interactive=True, ) + button_load_config = gr.Button('Load 💾', elem_id='open_folder') + config_file_name.change(remove_doublequote, inputs=[config_file_name], outputs=[config_file_name]) return ( button_open_config, button_save_config, button_save_as_config, config_file_name, + button_load_config, ) diff --git a/lora_gui.py b/lora_gui.py index 49918de..23da712 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -168,6 +168,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -239,9 +240,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -257,7 +262,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) # This next section is about making the LoCon parameters visible if LoRA_type = 'Standard' @@ -610,6 +615,7 @@ def lora_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -974,7 +980,14 @@ def lora_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list + [LoCon_row], + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, ) diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index c92bdc0..ed3c33a 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -158,6 +158,7 @@ def save_configuration( def open_configuration( + ask_for_file, file_path, pretrained_model_name_or_path, v2, @@ -225,9 +226,13 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - file_path = get_file_path(file_path) + + if ask_for_file: + file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file @@ -243,7 +248,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['file_path']: + if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) return tuple(values) @@ -548,6 +553,7 @@ def ti_tab( button_save_config, button_save_as_config, config_file_name, + button_load_config, ) = gradio_config() ( @@ -865,7 +871,14 @@ def ti_tab( button_open_config.click( open_configuration, - inputs=[config_file_name] + settings_list, + inputs=[dummy_db_true, config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + show_progress=False, + ) + + button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, )