From 2626214f8a7f02b0c1be21b6a41f65f6eea3edf6 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sat, 4 Feb 2023 11:55:06 -0500 Subject: [PATCH] Add support for LoRA resizing --- README.md | 8 +- dreambooth_gui.py | 6 ++ finetune_gui.py | 6 ++ library/common_gui.py | 45 +++++----- library/resize_lora_gui.py | 104 +++++++++++++++++++++++ library/train_util.py | 5 +- lora_gui.py | 8 ++ networks/resize_lora.py | 166 +++++++++++++++++++++++++++++++++++++ textual_inversion_gui.py | 6 ++ train_db.py | 4 +- train_network.py | 2 +- 11 files changed, 337 insertions(+), 23 deletions(-) create mode 100644 library/resize_lora_gui.py create mode 100644 networks/resize_lora.py diff --git a/README.md b/README.md index 254b0c2..fabfcd0 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,13 @@ Then redo the installation instruction within the kohya_ss venv. ## Change history -* 2023/02/03 +* 2023/02/04 (v20.6.1) + - ``--persistent_data_loader_workers`` option is added to ``fine_tune.py``, ``train_db.py`` and ``train_network.py``. This option may significantly reduce the waiting time between epochs. Thanks to hitomi! + - ``--debug_dataset`` option now works on non-Windows environments. Thanks to tsukimiya! + - ``networks/resize_lora.py`` script is added. It approximates a higher-rank (dim) LoRA model with a lower-rank one, e.g. reducing rank 128 to rank 4. Thanks to mgz-dev! + - ``--help`` option shows usage. + - Currently the metadata is not copied. This will be fixed in the near future. +* 2023/02/03 (v20.6.0) - Increase max LoRA rank (dim) size to 1024. - Update finetune preprocessing scripts. - ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev! diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 24ccfe7..f64480b 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -83,6 +83,7 @@ def save_configuration( mem_eff_attn, gradient_accumulation_steps, model_list, keep_tokens, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -167,6 +168,7 @@ def open_configuration( mem_eff_attn, gradient_accumulation_steps, model_list, keep_tokens, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -236,6 +238,7 @@ def train_model( gradient_accumulation_steps, model_list, # Keep this.
Yes, it is unused here but required given the common list used keep_tokens, + persistent_data_loader_workers, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -398,6 +401,7 @@ def train_model( xformers=xformers, use_8bit_adam=use_8bit_adam, keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, ) print(run_cmd) @@ -605,6 +609,7 @@ def dreambooth_tab( max_train_epochs, max_data_loader_n_workers, keep_tokens, + persistent_data_loader_workers, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -669,6 +674,7 @@ def dreambooth_tab( gradient_accumulation_steps, model_list, keep_tokens, + persistent_data_loader_workers, ] button_open_config.click( diff --git a/finetune_gui.py b/finetune_gui.py index 981be39..49dcd52 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -79,6 +79,7 @@ def save_configuration( model_list, cache_latents, use_latent_files, keep_tokens, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -169,6 +170,7 @@ def open_config_file( model_list, cache_latents, use_latent_files, keep_tokens, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -244,6 +246,7 @@ def train_model( model_list, # Keep this. Yes, it is unused here but required given the common list used cache_latents, use_latent_files, keep_tokens, + persistent_data_loader_workers, ): # create caption json file if generate_caption_database: @@ -382,6 +385,7 @@ def train_model( xformers=xformers, use_8bit_adam=use_8bit_adam, keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, ) print(run_cmd) @@ -587,6 +591,7 @@ def finetune_tab(): max_train_epochs, max_data_loader_n_workers, keep_tokens, + persistent_data_loader_workers, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -647,6 +652,7 @@ def finetune_tab(): cache_latents, use_latent_files, keep_tokens, + persistent_data_loader_workers, ] button_run.click(train_model, inputs=settings_list) diff --git a/library/common_gui.py b/library/common_gui.py index 2595540..4634ce2 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -510,31 +510,12 @@ def run_cmd_training(**kwargs): def gradio_advanced_training(): with gr.Row(): - full_fp16 = gr.Checkbox( - label='Full fp16 training (experimental)', value=False - ) - gradient_checkpointing = gr.Checkbox( - label='Gradient checkpointing', value=False - ) - shuffle_caption = gr.Checkbox( - label='Shuffle caption', value=False - ) keep_tokens = gr.Slider( label='Keep n tokens', value='0', minimum=0, maximum=32, step=1 ) - use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) - xformers = gr.Checkbox(label='Use xformers', value=True) - with gr.Row(): - color_aug = gr.Checkbox( - label='Color augmentation', value=False - ) - flip_aug = gr.Checkbox(label='Flip augmentation', value=False) clip_skip = gr.Slider( label='Clip skip', value='1', minimum=1, maximum=12, step=1 ) - mem_eff_attn = gr.Checkbox( - label='Memory efficient attention', value=False - ) max_token_length = gr.Dropdown( label='Max Token Length', choices=[ @@ -544,6 +525,29 @@ def gradio_advanced_training(): ], value='75', ) + full_fp16 = gr.Checkbox( + label='Full fp16 training (experimental)', value=False + ) + with gr.Row(): + gradient_checkpointing = gr.Checkbox( + label='Gradient checkpointing', value=False + ) + shuffle_caption = gr.Checkbox( + 
label='Shuffle caption', value=False + ) + persistent_data_loader_workers = gr.Checkbox( + label='Persistent data loader', value=False + ) + mem_eff_attn = gr.Checkbox( + label='Memory efficient attention', value=False + ) + with gr.Row(): + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) + xformers = gr.Checkbox(label='Use xformers', value=True) + color_aug = gr.Checkbox( + label='Color augmentation', value=False + ) + flip_aug = gr.Checkbox(label='Flip augmentation', value=False) with gr.Row(): save_state = gr.Checkbox(label='Save training state', value=False) resume = gr.Textbox( @@ -576,6 +580,7 @@ def gradio_advanced_training(): max_train_epochs, max_data_loader_n_workers, keep_tokens, + persistent_data_loader_workers, ) def run_cmd_advanced_training(**kwargs): @@ -622,6 +627,8 @@ def run_cmd_advanced_training(**kwargs): ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '', + ' --persistent_data_loader_workers' if kwargs.get('persistent_data_loader_workers') else '', + ] run_cmd = ''.join(options) return run_cmd diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py new file mode 100644 index 0000000..1ce6ebf --- /dev/null +++ b/library/resize_lora_gui.py @@ -0,0 +1,104 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import get_saveasfilename_path, get_file_path + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 + + +def resize_lora( + model, new_rank, save_to, save_precision, device, +): + # Check that a source model was provided + if model == '': + msgbox('Invalid model file') + return + + # Check that the source model exists + if not os.path.isfile(model): + msgbox('The provided model is not a file') + return + + if device == '': + device = 'cuda' + + run_cmd = f'.\\venv\\Scripts\\python.exe "networks\\resize_lora.py"' + run_cmd += f' --save_precision {save_precision}' + run_cmd += f' --save_to "{save_to}"' + run_cmd += f' --model "{model}"' + run_cmd += f' --new_rank {new_rank}' + run_cmd += f' --device {device}' + + print(run_cmd) + + # Run the command + subprocess.run(run_cmd) + + +### +# Gradio UI +### + + +def gradio_resize_lora_tab(): + with gr.Tab('Resize LoRA'): + gr.Markdown( + 'This utility can resize a LoRA.' + ) + + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + model = gr.Textbox( + label='Source LoRA', + placeholder='Path to the LoRA to resize', + interactive=True, + ) + button_lora_a_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_a_model_file.click( + get_file_path, + inputs=[model, lora_ext, lora_ext_name], + outputs=model, + ) + with gr.Row(): + new_rank = gr.Slider(label="Desired LoRA rank", minimum=1, maximum=1024, step=1, value=4, + interactive=True,) + + with gr.Row(): + save_to = gr.Textbox( + label='Save to', + placeholder='path for the LoRA file to save...', + interactive=True, + ) + button_save_to = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_save_to.click( + get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to + ) + save_precision = gr.Dropdown( + label='Save precision', + choices=['fp16', 'bf16', 'float'], + value='fp16', + interactive=True, + ) + device = gr.Textbox( + label='Device', + placeholder='(Optional) device to use, cuda for GPU.
Default: cuda', + interactive=True, + ) + + convert_button = gr.Button('Resize model') + + convert_button.click( + resize_lora, + inputs=[model, new_rank, save_to, save_precision, device, + ], + ) diff --git a/library/train_util.py b/library/train_util.py index 86508a3..ea261be 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -772,7 +772,8 @@ def debug_dataset(train_dataset, show_input_ids=False): im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - cv2.imshow("img", im) + if os.name == 'nt': # only windows + cv2.imshow("img", im) k = cv2.waitKey() cv2.destroyAllWindows() if k == 27: @@ -1194,6 +1195,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") + parser.add_argument("--persistent_data_loader_workers", action="store_true", + help="persistent DataLoader workers (useful for reducing the time gap between epochs, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする") diff --git a/lora_gui.py b/lora_gui.py index de7c36c..3e44683 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -32,6 +32,7 @@ from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab from library.merge_lora_gui import gradio_merge_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -92,6 +93,7 @@ def save_configuration( network_alpha, training_comment, keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -182,6 +184,7 @@ def open_configuration( network_alpha, training_comment, keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -256,6 +259,7 @@ def train_model( network_alpha, training_comment, keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + persistent_data_loader_workers, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -446,6 +450,7 @@ def train_model( xformers=xformers, use_8bit_adam=use_8bit_adam, keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, ) print(run_cmd) @@ -689,6 +694,7 @@ def lora_tab( max_train_epochs, max_data_loader_n_workers, keep_tokens, + persistent_data_loader_workers, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -708,6 +714,7 @@ ) gradio_dataset_balancing_tab() gradio_merge_lora_tab() + gradio_resize_lora_tab() gradio_verify_lora_tab() @@ -764,6 +771,7 @@ def lora_tab( training_comment, keep_tokens, lr_scheduler_num_cycles, lr_scheduler_power, + persistent_data_loader_workers, ]
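For reference, the Resize LoRA tab above shells out to the new networks/resize_lora.py script (its source follows below). A usage sketch of the equivalent direct invocation, using only the flags defined by the script; the file paths are hypothetical examples:

.\venv\Scripts\python.exe "networks\resize_lora.py" --save_precision fp16 --save_to "D:\lora\my_lora_rank4.safetensors" --model "D:\lora\my_lora_rank128.safetensors" --new_rank 4 --device cuda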
button_open_config.click( diff --git a/networks/resize_lora.py b/networks/resize_lora.py new file mode 100644 index 0000000..e10d35b --- /dev/null +++ b/networks/resize_lora.py @@ -0,0 +1,166 @@ +# Convert LoRA to different rank approximation (should only be used to go to lower rank) +# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo and kohya + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == '.safetensors': + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location='cpu') + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + + +def resize_lora_model(model, new_rank, merge_dtype, save_dtype): + print("Loading Model...") + lora_sd = load_state_dict(model, merge_dtype) + + network_alpha = None + network_dim = None + + CLAMP_QUANTILE = 0.99 + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and 'alpha' in key: + network_alpha = value + if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha/network_dim + new_alpha = float(scale*new_rank) # calculate new alpha from scale + + print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}") + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + print("resizing lora...") + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + if 'lora_down' in key: + block_down_name = key.split(".")[0] + lora_down_weight = value + if 'lora_up' in key: + block_up_name = key.split(".")[0] + lora_up_weight = value + + weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) + + if (block_down_name == block_up_name) and weights_loaded: + + conv2d = (len(lora_down_weight.size()) == 4) + + if conv2d: + lora_down_weight = lora_down_weight.squeeze() + lora_up_weight = lora_up_weight.squeeze() + + if args.device: + org_device = lora_up_weight.device + lora_up_weight = lora_up_weight.to(args.device) + lora_down_weight = lora_down_weight.to(args.device) + + full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) + + U, S, Vh = torch.linalg.svd(full_weight_matrix) + + U = U[:, :new_rank] + S = S[:new_rank] + U = U @ torch.diag(S) + + Vh = Vh[:new_rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + if conv2d: + U = U.unsqueeze(2).unsqueeze(3) + Vh = Vh.unsqueeze(2).unsqueeze(3) + + if args.device: + U = U.to(org_device) + Vh = Vh.to(org_device) + + o_lora_sd[block_down_name + "." 
+ "lora_down.weight"] = Vh.to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + + print("resizing complete") + return o_lora_sd + +def resize(args): + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") + parser.add_argument("--new_rank", type=int, default=4, + help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--model", type=str, default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + + args = parser.parse_args() + resize(args) diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 0f21a8d..d35a78c 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -83,6 +83,7 @@ def save_configuration( mem_eff_attn, gradient_accumulation_steps, model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -171,6 +172,7 @@ def open_configuration( mem_eff_attn, gradient_accumulation_steps, model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, + persistent_data_loader_workers, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -240,6 +242,7 @@ def train_model( gradient_accumulation_steps, model_list, # Keep this. 
Yes, it is unused here but required given the common list used token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, + persistent_data_loader_workers, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -417,6 +420,7 @@ def train_model( xformers=xformers, use_8bit_adam=use_8bit_adam, keep_tokens=keep_tokens, + persistent_data_loader_workers=persistent_data_loader_workers, ) run_cmd += f' --token_string="{token_string}"' run_cmd += f' --init_word="{init_word}"' @@ -671,6 +675,7 @@ def ti_tab( max_train_epochs, max_data_loader_n_workers, keep_tokens, + persistent_data_loader_workers, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -736,6 +741,7 @@ def ti_tab( model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, + persistent_data_loader_workers, ] button_open_config.click( diff --git a/train_db.py b/train_db.py index 8ac503e..bf25aae 100644 --- a/train_db.py +++ b/train_db.py @@ -133,7 +133,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -176,6 +176,8 @@ def train(args): # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/train_network.py b/train_network.py index 8840522..6dd1a73 100644 --- a/train_network.py +++ b/train_network.py @@ -214,7 +214,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None:
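For readers curious about the core of networks/resize_lora.py above, here is a minimal standalone sketch of the same SVD truncation idea. It is not part of the patch; the layer shapes and random tensors are made up purely for illustration.

import torch

# Illustrative shapes: a rank-128 LoRA factor pair for a 320x320 layer
out_dim, in_dim, old_rank, new_rank = 320, 320, 128, 4
lora_up = torch.randn(out_dim, old_rank)
lora_down = torch.randn(old_rank, in_dim)

# Reconstruct the full weight delta applied by the LoRA pair
full = lora_up @ lora_down

# Truncated SVD: keep the top new_rank singular directions and
# fold the singular values into the up factor
U, S, Vh = torch.linalg.svd(full)
new_up = U[:, :new_rank] @ torch.diag(S[:new_rank])
new_down = Vh[:new_rank, :]

# new_up @ new_down is the best rank-4 approximation of full in the
# Frobenius norm sense (Eckart-Young theorem)
err = torch.linalg.norm(full - new_up @ new_down) / torch.linalg.norm(full)
print(f"relative approximation error: {err.item():.3f}")

The script itself additionally handles conv2d LoRA weights (squeezing and restoring the extra dimensions) and clamps U and Vh to the 0.99 quantile before saving them as the new lora_up/lora_down tensors.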