Add device support

Commit f83638b8a6 (parent 1eef89c581)
@@ -213,6 +213,7 @@ This will store a backup file with your current locally installed pip packa
 - Support extensions with upper cases for images for non-Windows environments.
 - Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
 - Fix issue: https://github.com/bmaltais/kohya_ss/issues/406
+- Add device support to LoRA extract.
 * 2023/03/19 (v21.2.5):
 - Fix basic captioning logic
 - Add possibility to not train the TE in Dreambooth by setting `Step text encoder training` to -1.
@@ -338,12 +338,21 @@ def train_model(
         if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.')
     ]
 
+    # Check if subfolders are present. If not, let the user know and return
+    if not subfolders:
+        print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train...\033[0m')
+        return
+
     total_steps = 0
 
     # Loop through each subfolder and extract the number of repeats
     for folder in subfolders:
         # Extract the number of repeats from the folder name
-        repeats = int(folder.split('_')[0])
+        try:
+            repeats = int(folder.split('_')[0])
+        except ValueError:
+            print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m')
+            continue
 
         # Count the number of images in the folder
         num_images = len(
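For readers unfamiliar with the convention: kohya_ss encodes the repeat count in each training subfolder's name as `<repeats>_<name>`, and the `try/except` added above guards against folders that break it. A minimal, self-contained sketch of that parsing (the folder names are hypothetical, not from the commit):

```python
# Parse the `<repeats>_<name>` subfolder convention; skip malformed names.
folders = ['40_mychar', '10_style', 'notes']  # hypothetical examples

for folder in folders:
    try:
        repeats = int(folder.split('_')[0])
    except ValueError:
        print(f'{folder}: no numeric repeat prefix, skipping')
        continue
    print(f'{folder}: {repeats} repeats')
```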
@@ -356,13 +365,20 @@ def train_model(
                 or f.endswith('.webp')
             ]
         )
 
-        # Calculate the total number of steps for this folder
-        steps = repeats * num_images
-        total_steps += steps
+        if num_images == 0:
+            print(f'{folder} folder contains no images, skipping...')
+        else:
+            # Calculate the total number of steps for this folder
+            steps = repeats * num_images
+            total_steps += steps
 
-        # Print the result
-        print(f'Folder {folder}: {steps} steps')
+            # Print the result
+            print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
+
+    if total_steps == 0:
+        print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m')
+        return
 
     # Print the result
     # print(f"{total_steps} total steps")
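As a worked example of the arithmetic above (the numbers are made up): a `40_char` folder holding 25 images contributes 40 × 25 = 1000 steps, and folders are simply summed:

```python
# Hypothetical folders -> image counts; steps = repeats * num_images, summed.
dataset = {'40_char': 25, '10_style': 100}
total_steps = sum(int(name.split('_')[0]) * n for name, n in dataset.items())
print(total_steps)  # 40*25 + 10*100 = 2000
```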
@@ -370,9 +386,7 @@ def train_model(
     if reg_data_dir == '':
         reg_factor = 1
     else:
-        print(
-            'Regularisation images are used... Will double the number of steps required...'
-        )
+        print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m')
         reg_factor = 2
 
     # calculate max_train_steps
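`reg_factor` then scales the step budget computed under the `# calculate max_train_steps` comment. The formula itself sits outside this hunk; the sketch below is only an assumed shape of that calculation, not the commit's code:

```python
import math

# Assumed shape of the downstream step calculation (not shown in this diff):
# regularisation images double the effective dataset, hence reg_factor = 2.
total_steps, train_batch_size, epoch, reg_factor = 2000, 2, 1, 2
max_train_steps = int(math.ceil(total_steps / train_batch_size * epoch * reg_factor))
print(max_train_steps)  # 2000
```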
@@ -932,7 +932,8 @@ def gradio_advanced_training():
             label='VAE batch size',
             minimum=0,
             maximum=32,
-            value=0
+            value=0,
+            every=1
         )
         with gr.Row():
             save_state = gr.Checkbox(label='Save training state', value=False)
@@ -23,6 +23,7 @@ def extract_lora(
     dim,
     v2,
     conv_dim,
+    device,
 ):
     # Check for caption_text_input
     if model_tuned == '':
@@ -50,6 +51,7 @@ def extract_lora(
     run_cmd += f' --model_org "{model_org}"'
     run_cmd += f' --model_tuned "{model_tuned}"'
     run_cmd += f' --dim {dim}'
+    run_cmd += f' --device {device}'
     if conv_dim > 0:
         run_cmd += f' --conv_dim {conv_dim}'
     if v2:
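For orientation, `extract_lora` assembles a shell command for the extraction script and appends optional flags conditionally. A hedged sketch of that pattern (the script name and all values are placeholders, not the GUI's actual state):

```python
# Placeholder values standing in for the GUI inputs.
model_org = 'base.safetensors'
model_tuned = 'tuned.safetensors'
dim, conv_dim, device = 128, 0, 'cuda'

run_cmd = 'python extract_lora_from_models.py'  # assumed script name
run_cmd += f' --model_org "{model_org}"'
run_cmd += f' --model_tuned "{model_tuned}"'
run_cmd += f' --dim {dim}'
run_cmd += f' --device {device}'
if conv_dim > 0:
    run_cmd += f' --conv_dim {conv_dim}'
print(run_cmd)
```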
@@ -148,6 +150,15 @@ def gradio_extract_lora_tab():
             interactive=True,
         )
         v2 = gr.Checkbox(label='v2', value=False, interactive=True)
+        device = gr.Dropdown(
+            label='Device',
+            choices=[
+                'cpu',
+                'cuda',
+            ],
+            value='cuda',
+            interactive=True,
+        )
 
     extract_button = gr.Button('Extract LoRA model')
 
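With `'cuda'` as the hard-coded default, CPU-only machines must remember to switch the dropdown. One common alternative (an assumption here, not what this commit does) is to derive the default from the runtime:

```python
import torch

# Pick a sensible default for the Device dropdown at startup.
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(default_device)
```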
@@ -161,6 +172,7 @@ def gradio_extract_lora_tab():
             dim,
             v2,
             conv_dim,
+            device
         ],
         show_progress=False,
     )
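Gradio passes `inputs` to the callback positionally, so `device` must occupy the same slot in this list as in the `extract_lora` signature above. A minimal, self-contained sketch of that wiring (the components and callback are hypothetical):

```python
import gradio as gr

# inputs are mapped to fn's parameters in list order
def extract(model, device):
    return f'extracting {model} on {device}'

with gr.Blocks() as demo:
    model = gr.Textbox(label='Model')
    device = gr.Dropdown(choices=['cpu', 'cuda'], value='cuda', label='Device')
    out = gr.Textbox()
    gr.Button('Extract').click(extract, inputs=[model, device], outputs=out)
```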
@@ -11,8 +11,8 @@ import library.model_util as model_util
 import lora
 
 
-CLAMP_QUANTILE = 0.99
-MIN_DIFF = 1e-6
+CLAMP_QUANTILE = 1
+MIN_DIFF = 1e-8
 
 
 def save_to_file(file_name, model, state_dict, dtype):
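`CLAMP_QUANTILE` caps the magnitude of the extracted SVD factors at a given quantile of their combined values; raising it to 1, together with commenting out the clamp in the next hunk, effectively disables the clipping. A self-contained sketch of what the clamp did at 0.99:

```python
import torch

# Clip U and Vh to the CLAMP_QUANTILE quantile of their combined entries,
# as the (now disabled) code in svd() did.
CLAMP_QUANTILE = 0.99
U, Vh = torch.randn(8, 4), torch.randn(4, 8)
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
U = U.clamp(-hi_val, hi_val)
Vh = Vh.clamp(-hi_val, hi_val)
```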
@@ -121,12 +121,12 @@ def svd(args):
 
         Vh = Vh[:rank, :]
 
-        dist = torch.cat([U.flatten(), Vh.flatten()])
-        hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-        low_val = -hi_val
+        # dist = torch.cat([U.flatten(), Vh.flatten()])
+        # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+        # low_val = -hi_val
 
-        U = U.clamp(low_val, hi_val)
-        Vh = Vh.clamp(low_val, hi_val)
+        # U = U.clamp(low_val, hi_val)
+        # Vh = Vh.clamp(low_val, hi_val)
 
         if conv2d:
             U = U.reshape(out_dim, rank, 1, 1)
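For the surrounding context: `svd()` factors the weight delta between the tuned and original model into low-rank up/down matrices. A hedged, self-contained sketch of that truncated-SVD extraction (the shapes and noise model are illustrative only):

```python
import torch

# Factor a weight delta into U (out_dim x rank) and Vh (rank x in_dim).
rank = 4
w_org = torch.randn(64, 32)
w_tuned = w_org + 0.01 * torch.randn(64, 32)
delta = w_tuned - w_org

U, S, Vh = torch.linalg.svd(delta)
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)   # fold singular values into U
Vh = Vh[:rank, :]

print(torch.dist(delta, U @ Vh))  # error of the rank-4 approximation
```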