From f83638b8a65f7e7b9b9a612ff6d095f68310f8b2 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 22 Mar 2023 12:55:30 -0400 Subject: [PATCH] Add device support --- README.md | 1 + dreambooth_gui.py | 32 ++++++++++++++++++++-------- library/common_gui.py | 3 ++- library/extract_lora_gui.py | 12 +++++++++++ networks/extract_lora_from_models.py | 14 ++++++------ 5 files changed, 45 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index c198d09..caef52e 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,7 @@ This will store your a backup file with your current locally installed pip packa - Support extensions with upper cases for images for not Windows environment. - Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki! - Fix issue: https://github.com/bmaltais/kohya_ss/issues/406 + - Add device support to LoRA extract. * 2023/03/19 (v21.2.5): - Fix basic captioning logic - Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1. diff --git a/dreambooth_gui.py b/dreambooth_gui.py index c3458be..a72acff 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -338,12 +338,21 @@ def train_model( if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.') ] + # Check if subfolders are present. If not let the user know and return + if not subfolders: + print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train\...033[0m') + return + total_steps = 0 # Loop through each subfolder and extract the number of repeats for folder in subfolders: # Extract the number of repeats from the folder name - repeats = int(folder.split('_')[0]) + try: + repeats = int(folder.split('_')[0]) + except ValueError: + print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m') + continue # Count the number of images in the folder num_images = len( @@ -356,13 +365,20 @@ def train_model( or f.endswith('.webp') ] ) + + if num_images == 0: + print(f'{folder} folder contain no images, skipping...') + else: + # Calculate the total number of steps for this folder + steps = repeats * num_images + total_steps += steps - # Calculate the total number of steps for this folder - steps = repeats * num_images - total_steps += steps + # Print the result + print('\033[33mFolder', folder, ':', steps, 'steps\033[0m') - # Print the result - print(f'Folder {folder}: {steps} steps') + if total_steps == 0: + print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m') + return # Print the result # print(f"{total_steps} total steps") @@ -370,9 +386,7 @@ def train_model( if reg_data_dir == '': reg_factor = 1 else: - print( - 'Regularisation images are used... Will double the number of steps required...' - ) + print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m') reg_factor = 2 # calculate max_train_steps diff --git a/library/common_gui.py b/library/common_gui.py index 14c448c..0d37cc5 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -932,7 +932,8 @@ def gradio_advanced_training(): label='VAE batch size', minimum=0, maximum=32, - value=0 + value=0, + every=1 ) with gr.Row(): save_state = gr.Checkbox(label='Save training state', value=False) diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 6088805..5f48686 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -23,6 +23,7 @@ def extract_lora( dim, v2, conv_dim, + device, ): # Check for caption_text_input if model_tuned == '': @@ -50,6 +51,7 @@ def extract_lora( run_cmd += f' --model_org "{model_org}"' run_cmd += f' --model_tuned "{model_tuned}"' run_cmd += f' --dim {dim}' + run_cmd += f' --device {device}' if conv_dim > 0: run_cmd += f' --conv_dim {conv_dim}' if v2: @@ -148,6 +150,15 @@ def gradio_extract_lora_tab(): interactive=True, ) v2 = gr.Checkbox(label='v2', value=False, interactive=True) + device = gr.Dropdown( + label='Device', + choices=[ + 'cpu', + 'cuda', + ], + value='cuda', + interactive=True, + ) extract_button = gr.Button('Extract LoRA model') @@ -161,6 +172,7 @@ def gradio_extract_lora_tab(): dim, v2, conv_dim, + device ], show_progress=False, ) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 9aa2848..783fa1b 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -11,8 +11,8 @@ import library.model_util as model_util import lora -CLAMP_QUANTILE = 0.99 -MIN_DIFF = 1e-6 +CLAMP_QUANTILE = 1 +MIN_DIFF = 1e-8 def save_to_file(file_name, model, state_dict, dtype): @@ -121,12 +121,12 @@ def svd(args): Vh = Vh[:rank, :] - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val + # dist = torch.cat([U.flatten(), Vh.flatten()]) + # hi_val = torch.quantile(dist, CLAMP_QUANTILE) + # low_val = -hi_val - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) + # U = U.clamp(low_val, hi_val) + # Vh = Vh.clamp(low_val, hi_val) if conv2d: U = U.reshape(out_dim, rank, 1, 1)