Add device support
This commit is contained in:
parent
1eef89c581
commit
f83638b8a6
@ -213,6 +213,7 @@ This will store your a backup file with your current locally installed pip packa
|
|||||||
- Support extensions with upper cases for images for not Windows environment.
|
- Support extensions with upper cases for images for not Windows environment.
|
||||||
- Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
|
- Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
|
||||||
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/406
|
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/406
|
||||||
|
- Add device support to LoRA extract.
|
||||||
* 2023/03/19 (v21.2.5):
|
* 2023/03/19 (v21.2.5):
|
||||||
- Fix basic captioning logic
|
- Fix basic captioning logic
|
||||||
- Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1.
|
- Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1.
|
||||||
|
@ -338,12 +338,21 @@ def train_model(
|
|||||||
if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.')
|
if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.')
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Check if subfolders are present. If not let the user know and return
|
||||||
|
if not subfolders:
|
||||||
|
print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train\...033[0m')
|
||||||
|
return
|
||||||
|
|
||||||
total_steps = 0
|
total_steps = 0
|
||||||
|
|
||||||
# Loop through each subfolder and extract the number of repeats
|
# Loop through each subfolder and extract the number of repeats
|
||||||
for folder in subfolders:
|
for folder in subfolders:
|
||||||
# Extract the number of repeats from the folder name
|
# Extract the number of repeats from the folder name
|
||||||
repeats = int(folder.split('_')[0])
|
try:
|
||||||
|
repeats = int(folder.split('_')[0])
|
||||||
|
except ValueError:
|
||||||
|
print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m')
|
||||||
|
continue
|
||||||
|
|
||||||
# Count the number of images in the folder
|
# Count the number of images in the folder
|
||||||
num_images = len(
|
num_images = len(
|
||||||
@ -357,12 +366,19 @@ def train_model(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate the total number of steps for this folder
|
if num_images == 0:
|
||||||
steps = repeats * num_images
|
print(f'{folder} folder contain no images, skipping...')
|
||||||
total_steps += steps
|
else:
|
||||||
|
# Calculate the total number of steps for this folder
|
||||||
|
steps = repeats * num_images
|
||||||
|
total_steps += steps
|
||||||
|
|
||||||
# Print the result
|
# Print the result
|
||||||
print(f'Folder {folder}: {steps} steps')
|
print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
|
||||||
|
|
||||||
|
if total_steps == 0:
|
||||||
|
print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m')
|
||||||
|
return
|
||||||
|
|
||||||
# Print the result
|
# Print the result
|
||||||
# print(f"{total_steps} total steps")
|
# print(f"{total_steps} total steps")
|
||||||
@ -370,9 +386,7 @@ def train_model(
|
|||||||
if reg_data_dir == '':
|
if reg_data_dir == '':
|
||||||
reg_factor = 1
|
reg_factor = 1
|
||||||
else:
|
else:
|
||||||
print(
|
print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m')
|
||||||
'Regularisation images are used... Will double the number of steps required...'
|
|
||||||
)
|
|
||||||
reg_factor = 2
|
reg_factor = 2
|
||||||
|
|
||||||
# calculate max_train_steps
|
# calculate max_train_steps
|
||||||
|
@ -932,7 +932,8 @@ def gradio_advanced_training():
|
|||||||
label='VAE batch size',
|
label='VAE batch size',
|
||||||
minimum=0,
|
minimum=0,
|
||||||
maximum=32,
|
maximum=32,
|
||||||
value=0
|
value=0,
|
||||||
|
every=1
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||||
|
@ -23,6 +23,7 @@ def extract_lora(
|
|||||||
dim,
|
dim,
|
||||||
v2,
|
v2,
|
||||||
conv_dim,
|
conv_dim,
|
||||||
|
device,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if model_tuned == '':
|
if model_tuned == '':
|
||||||
@ -50,6 +51,7 @@ def extract_lora(
|
|||||||
run_cmd += f' --model_org "{model_org}"'
|
run_cmd += f' --model_org "{model_org}"'
|
||||||
run_cmd += f' --model_tuned "{model_tuned}"'
|
run_cmd += f' --model_tuned "{model_tuned}"'
|
||||||
run_cmd += f' --dim {dim}'
|
run_cmd += f' --dim {dim}'
|
||||||
|
run_cmd += f' --device {device}'
|
||||||
if conv_dim > 0:
|
if conv_dim > 0:
|
||||||
run_cmd += f' --conv_dim {conv_dim}'
|
run_cmd += f' --conv_dim {conv_dim}'
|
||||||
if v2:
|
if v2:
|
||||||
@ -148,6 +150,15 @@ def gradio_extract_lora_tab():
|
|||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
|
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
|
||||||
|
device = gr.Dropdown(
|
||||||
|
label='Device',
|
||||||
|
choices=[
|
||||||
|
'cpu',
|
||||||
|
'cuda',
|
||||||
|
],
|
||||||
|
value='cuda',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
extract_button = gr.Button('Extract LoRA model')
|
extract_button = gr.Button('Extract LoRA model')
|
||||||
|
|
||||||
@ -161,6 +172,7 @@ def gradio_extract_lora_tab():
|
|||||||
dim,
|
dim,
|
||||||
v2,
|
v2,
|
||||||
conv_dim,
|
conv_dim,
|
||||||
|
device
|
||||||
],
|
],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
@ -11,8 +11,8 @@ import library.model_util as model_util
|
|||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 1
|
||||||
MIN_DIFF = 1e-6
|
MIN_DIFF = 1e-8
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, model, state_dict, dtype):
|
||||||
@ -121,12 +121,12 @@ def svd(args):
|
|||||||
|
|
||||||
Vh = Vh[:rank, :]
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
# dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
# hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
low_val = -hi_val
|
# low_val = -hi_val
|
||||||
|
|
||||||
U = U.clamp(low_val, hi_val)
|
# U = U.clamp(low_val, hi_val)
|
||||||
Vh = Vh.clamp(low_val, hi_val)
|
# Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
|
||||||
if conv2d:
|
if conv2d:
|
||||||
U = U.reshape(out_dim, rank, 1, 1)
|
U = U.reshape(out_dim, rank, 1, 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user