Add device support

This commit is contained in:
bmaltais 2023-03-22 12:55:30 -04:00
parent 1eef89c581
commit f83638b8a6
5 changed files with 45 additions and 17 deletions

View File

@ -213,6 +213,7 @@ This will store your a backup file with your current locally installed pip packa
- Support extensions with upper cases for images for not Windows environment.
- Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/406
- Add device support to LoRA extract.
* 2023/03/19 (v21.2.5):
- Fix basic captioning logic
- Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1.

View File

@ -338,12 +338,21 @@ def train_model(
if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.')
]
# Check if subfolders are present. If not let the user know and return
if not subfolders:
print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train\...033[0m')
return
total_steps = 0
# Loop through each subfolder and extract the number of repeats
for folder in subfolders:
# Extract the number of repeats from the folder name
try:
repeats = int(folder.split('_')[0])
except ValueError:
print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m')
continue
# Count the number of images in the folder
num_images = len(
@ -357,12 +366,19 @@ def train_model(
]
)
if num_images == 0:
print(f'{folder} folder contain no images, skipping...')
else:
# Calculate the total number of steps for this folder
steps = repeats * num_images
total_steps += steps
# Print the result
print(f'Folder {folder}: {steps} steps')
print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
if total_steps == 0:
print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m')
return
# Print the result
# print(f"{total_steps} total steps")
@ -370,9 +386,7 @@ def train_model(
if reg_data_dir == '':
reg_factor = 1
else:
print(
'Regularisation images are used... Will double the number of steps required...'
)
print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m')
reg_factor = 2
# calculate max_train_steps

View File

@ -932,7 +932,8 @@ def gradio_advanced_training():
label='VAE batch size',
minimum=0,
maximum=32,
value=0
value=0,
every=1
)
with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False)

View File

@ -23,6 +23,7 @@ def extract_lora(
dim,
v2,
conv_dim,
device,
):
# Check for caption_text_input
if model_tuned == '':
@ -50,6 +51,7 @@ def extract_lora(
run_cmd += f' --model_org "{model_org}"'
run_cmd += f' --model_tuned "{model_tuned}"'
run_cmd += f' --dim {dim}'
run_cmd += f' --device {device}'
if conv_dim > 0:
run_cmd += f' --conv_dim {conv_dim}'
if v2:
@ -148,6 +150,15 @@ def gradio_extract_lora_tab():
interactive=True,
)
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
device = gr.Dropdown(
label='Device',
choices=[
'cpu',
'cuda',
],
value='cuda',
interactive=True,
)
extract_button = gr.Button('Extract LoRA model')
@ -161,6 +172,7 @@ def gradio_extract_lora_tab():
dim,
v2,
conv_dim,
device
],
show_progress=False,
)

View File

@ -11,8 +11,8 @@ import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6
CLAMP_QUANTILE = 1
MIN_DIFF = 1e-8
def save_to_file(file_name, model, state_dict, dtype):
@ -121,12 +121,12 @@ def svd(args):
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
# dist = torch.cat([U.flatten(), Vh.flatten()])
# hi_val = torch.quantile(dist, CLAMP_QUANTILE)
# low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
# U = U.clamp(low_val, hi_val)
# Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)