Add device support

This commit is contained in:
bmaltais 2023-03-22 12:55:30 -04:00
parent 1eef89c581
commit f83638b8a6
5 changed files with 45 additions and 17 deletions

View File

@ -213,6 +213,7 @@ This will store your a backup file with your current locally installed pip packa
- Support extensions with upper cases for images for not Windows environment. - Support extensions with upper cases for images for not Windows environment.
- Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki! - Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/406 - Fix issue: https://github.com/bmaltais/kohya_ss/issues/406
- Add device support to LoRA extract.
* 2023/03/19 (v21.2.5): * 2023/03/19 (v21.2.5):
- Fix basic captioning logic - Fix basic captioning logic
- Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1. - Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1.

View File

@ -338,12 +338,21 @@ def train_model(
if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.') if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.')
] ]
# Check if subfolders are present. If not let the user know and return
if not subfolders:
print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train\...033[0m')
return
total_steps = 0 total_steps = 0
# Loop through each subfolder and extract the number of repeats # Loop through each subfolder and extract the number of repeats
for folder in subfolders: for folder in subfolders:
# Extract the number of repeats from the folder name # Extract the number of repeats from the folder name
repeats = int(folder.split('_')[0]) try:
repeats = int(folder.split('_')[0])
except ValueError:
print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m')
continue
# Count the number of images in the folder # Count the number of images in the folder
num_images = len( num_images = len(
@ -356,13 +365,20 @@ def train_model(
or f.endswith('.webp') or f.endswith('.webp')
] ]
) )
if num_images == 0:
print(f'{folder} folder contain no images, skipping...')
else:
# Calculate the total number of steps for this folder
steps = repeats * num_images
total_steps += steps
# Calculate the total number of steps for this folder # Print the result
steps = repeats * num_images print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
total_steps += steps
# Print the result if total_steps == 0:
print(f'Folder {folder}: {steps} steps') print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m')
return
# Print the result # Print the result
# print(f"{total_steps} total steps") # print(f"{total_steps} total steps")
@ -370,9 +386,7 @@ def train_model(
if reg_data_dir == '': if reg_data_dir == '':
reg_factor = 1 reg_factor = 1
else: else:
print( print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m')
'Regularisation images are used... Will double the number of steps required...'
)
reg_factor = 2 reg_factor = 2
# calculate max_train_steps # calculate max_train_steps

View File

@ -932,7 +932,8 @@ def gradio_advanced_training():
label='VAE batch size', label='VAE batch size',
minimum=0, minimum=0,
maximum=32, maximum=32,
value=0 value=0,
every=1
) )
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False) save_state = gr.Checkbox(label='Save training state', value=False)

View File

@ -23,6 +23,7 @@ def extract_lora(
dim, dim,
v2, v2,
conv_dim, conv_dim,
device,
): ):
# Check for caption_text_input # Check for caption_text_input
if model_tuned == '': if model_tuned == '':
@ -50,6 +51,7 @@ def extract_lora(
run_cmd += f' --model_org "{model_org}"' run_cmd += f' --model_org "{model_org}"'
run_cmd += f' --model_tuned "{model_tuned}"' run_cmd += f' --model_tuned "{model_tuned}"'
run_cmd += f' --dim {dim}' run_cmd += f' --dim {dim}'
run_cmd += f' --device {device}'
if conv_dim > 0: if conv_dim > 0:
run_cmd += f' --conv_dim {conv_dim}' run_cmd += f' --conv_dim {conv_dim}'
if v2: if v2:
@ -148,6 +150,15 @@ def gradio_extract_lora_tab():
interactive=True, interactive=True,
) )
v2 = gr.Checkbox(label='v2', value=False, interactive=True) v2 = gr.Checkbox(label='v2', value=False, interactive=True)
device = gr.Dropdown(
label='Device',
choices=[
'cpu',
'cuda',
],
value='cuda',
interactive=True,
)
extract_button = gr.Button('Extract LoRA model') extract_button = gr.Button('Extract LoRA model')
@ -161,6 +172,7 @@ def gradio_extract_lora_tab():
dim, dim,
v2, v2,
conv_dim, conv_dim,
device
], ],
show_progress=False, show_progress=False,
) )

View File

@ -11,8 +11,8 @@ import library.model_util as model_util
import lora import lora
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 1
MIN_DIFF = 1e-6 MIN_DIFF = 1e-8
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype):
@ -121,12 +121,12 @@ def svd(args):
Vh = Vh[:rank, :] Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()]) # dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE) # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val # low_val = -hi_val
U = U.clamp(low_val, hi_val) # U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val) # Vh = Vh.clamp(low_val, hi_val)
if conv2d: if conv2d:
U = U.reshape(out_dim, rank, 1, 1) U = U.reshape(out_dim, rank, 1, 1)