Add support for LoRA resizing

This commit is contained in:
bmaltais 2023-02-04 11:55:06 -05:00
parent 045750b46a
commit 2626214f8a
11 changed files with 337 additions and 23 deletions

View File

@ -143,7 +143,13 @@ Then redo the installation instruction within the kohya_ss venv.
## Change history ## Change history
* 2023/02/03 * 2023/02/04 (v20.6.1)
- ``--persistent_data_loader_workers`` option is added to ``fine_tune.py``, ``train_db.py`` and ``train_network.py``. This option may significantly reduce the waiting time between epochs. Thanks to hitomi!
- ``--debug_dataset`` option is now working on non-Windows environment. Thanks to tsukimiya!
- ``networks/resize_lora.py`` script is added. This can approximate the higher-rank (dim) LoRA model by a lower-rank LoRA model, e.g. 128 by 4. Thanks to mgz-dev!
- ``--help`` option shows usage.
- Currently the metadata is not copied. This will be fixed in the near future.
* 2023/02/03 (v20.6.0)
- Increase max LoRA rank (dim) size to 1024. - Increase max LoRA rank (dim) size to 1024.
- Update finetune preprocessing scripts. - Update finetune preprocessing scripts.
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev! - ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!

View File

@ -83,6 +83,7 @@ def save_configuration(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, keep_tokens, model_list, keep_tokens,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -167,6 +168,7 @@ def open_configuration(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, keep_tokens, model_list, keep_tokens,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -236,6 +238,7 @@ def train_model(
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
keep_tokens, keep_tokens,
persistent_data_loader_workers,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -398,6 +401,7 @@ def train_model(
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
) )
print(run_cmd) print(run_cmd)
@ -605,6 +609,7 @@ def dreambooth_tab(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -669,6 +674,7 @@ def dreambooth_tab(
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
] ]
button_open_config.click( button_open_config.click(

View File

@ -79,6 +79,7 @@ def save_configuration(
model_list, model_list,
cache_latents, cache_latents,
use_latent_files, keep_tokens, use_latent_files, keep_tokens,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -169,6 +170,7 @@ def open_config_file(
model_list, model_list,
cache_latents, cache_latents,
use_latent_files, keep_tokens, use_latent_files, keep_tokens,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -244,6 +246,7 @@ def train_model(
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
cache_latents, cache_latents,
use_latent_files, keep_tokens, use_latent_files, keep_tokens,
persistent_data_loader_workers,
): ):
# create caption json file # create caption json file
if generate_caption_database: if generate_caption_database:
@ -382,6 +385,7 @@ def train_model(
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
) )
print(run_cmd) print(run_cmd)
@ -587,6 +591,7 @@ def finetune_tab():
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -647,6 +652,7 @@ def finetune_tab():
cache_latents, cache_latents,
use_latent_files, use_latent_files,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

View File

@ -510,31 +510,12 @@ def run_cmd_training(**kwargs):
def gradio_advanced_training(): def gradio_advanced_training():
with gr.Row(): with gr.Row():
full_fp16 = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
)
gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False
)
shuffle_caption = gr.Checkbox(
label='Shuffle caption', value=False
)
keep_tokens = gr.Slider( keep_tokens = gr.Slider(
label='Keep n tokens', value='0', minimum=0, maximum=32, step=1 label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
) )
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Row():
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
clip_skip = gr.Slider( clip_skip = gr.Slider(
label='Clip skip', value='1', minimum=1, maximum=12, step=1 label='Clip skip', value='1', minimum=1, maximum=12, step=1
) )
mem_eff_attn = gr.Checkbox(
label='Memory efficient attention', value=False
)
max_token_length = gr.Dropdown( max_token_length = gr.Dropdown(
label='Max Token Length', label='Max Token Length',
choices=[ choices=[
@ -544,6 +525,29 @@ def gradio_advanced_training():
], ],
value='75', value='75',
) )
full_fp16 = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
)
with gr.Row():
gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False
)
shuffle_caption = gr.Checkbox(
label='Shuffle caption', value=False
)
persistent_data_loader_workers = gr.Checkbox(
label='Persistent data loader', value=False
)
mem_eff_attn = gr.Checkbox(
label='Memory efficient attention', value=False
)
with gr.Row():
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False) save_state = gr.Checkbox(label='Save training state', value=False)
resume = gr.Textbox( resume = gr.Textbox(
@ -576,6 +580,7 @@ def gradio_advanced_training():
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
) )
def run_cmd_advanced_training(**kwargs): def run_cmd_advanced_training(**kwargs):
@ -622,6 +627,8 @@ def run_cmd_advanced_training(**kwargs):
' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '', ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
' --persistent_data_loader_workers' if kwargs.get('persistent_data_loader_workers') else '',
] ]
run_cmd = ''.join(options) run_cmd = ''.join(options)
return run_cmd return run_cmd

104
library/resize_lora_gui.py Normal file
View File

@ -0,0 +1,104 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_file_path
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def resize_lora(
model, new_rank, save_to, save_precision, device,
):
# Check for caption_text_input
if model == '':
msgbox('Invalid model file')
return
# Check if source model exist
if not os.path.isfile(model):
msgbox('The provided model is not a file')
return
if device == '':
device = 'cuda'
run_cmd = f'.\\venv\Scripts\python.exe "networks\\resize_lora.py"'
run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --save_to {save_to}'
run_cmd += f' --model {model}'
run_cmd += f' --new_rank {new_rank}'
run_cmd += f' --device {device}'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
###
# Gradio UI
###
def gradio_resize_lora_tab():
with gr.Tab('Resize LoRA'):
gr.Markdown(
'This utility can resize a LoRA.'
)
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Row():
model = gr.Textbox(
label='Source LoRA',
placeholder='Path to the LoRA to resize',
interactive=True,
)
button_lora_a_model_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_lora_a_model_file.click(
get_file_path,
inputs=[model, lora_ext, lora_ext_name],
outputs=model,
)
with gr.Row():
new_rank = gr.Slider(label="Desired LoRA rank", minimum=1, maximum=1024, step=1, value=4,
interactive=True,)
with gr.Row():
save_to = gr.Textbox(
label='Save to',
placeholder='path for the LoRA file to save...',
interactive=True,
)
button_save_to = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
)
save_precision = gr.Dropdown(
label='Save precison',
choices=['fp16', 'bf16', 'float'],
value='fp16',
interactive=True,
)
device = gr.Textbox(
label='Device',
placeholder='{Optional) device to use, cuda for GPU. Default: cuda',
interactive=True,
)
convert_button = gr.Button('Resize model')
convert_button.click(
resize_lora,
inputs=[model, new_rank, save_to, save_precision, device,
],
)

View File

@ -772,7 +772,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV) im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
cv2.imshow("img", im) if os.name == 'nt': # only windows
cv2.imshow("img", im)
k = cv2.waitKey() k = cv2.waitKey()
cv2.destroyAllWindows() cv2.destroyAllWindows()
if k == 27: if k == 27:
@ -1194,6 +1195,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします") help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります") help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります")
parser.add_argument("--persistent_data_loader_workers", action="store_true",
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true", parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする") help="enable gradient checkpointing / grandient checkpointingを有効にする")

View File

@ -32,6 +32,7 @@ from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab from library.utilities import utilities_tab
from library.merge_lora_gui import gradio_merge_lora_tab from library.merge_lora_gui import gradio_merge_lora_tab
from library.verify_lora_gui import gradio_verify_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from easygui import msgbox from easygui import msgbox
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -92,6 +93,7 @@ def save_configuration(
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power, lr_scheduler_num_cycles, lr_scheduler_power,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -182,6 +184,7 @@ def open_configuration(
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power, lr_scheduler_num_cycles, lr_scheduler_power,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -256,6 +259,7 @@ def train_model(
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power, lr_scheduler_num_cycles, lr_scheduler_power,
persistent_data_loader_workers,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -446,6 +450,7 @@ def train_model(
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
) )
print(run_cmd) print(run_cmd)
@ -689,6 +694,7 @@ def lora_tab(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -708,6 +714,7 @@ def lora_tab(
) )
gradio_dataset_balancing_tab() gradio_dataset_balancing_tab()
gradio_merge_lora_tab() gradio_merge_lora_tab()
gradio_resize_lora_tab()
gradio_verify_lora_tab() gradio_verify_lora_tab()
@ -764,6 +771,7 @@ def lora_tab(
training_comment, training_comment,
keep_tokens, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power, lr_scheduler_num_cycles, lr_scheduler_power,
persistent_data_loader_workers,
] ]
button_open_config.click( button_open_config.click(

166
networks/resize_lora.py Normal file
View File

@ -0,0 +1,166 @@
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo and kohya
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
print("Loading Model...")
lora_sd = load_state_dict(model, merge_dtype)
network_alpha = None
network_dim = None
CLAMP_QUANTILE = 0.99
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}")
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
print("resizing lora...")
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
if (block_down_name == block_up_name) and weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4)
if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if args.device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
U, S, Vh = torch.linalg.svd(full_weight_matrix)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
if args.device:
U = U.to(org_device)
Vh = Vh.to(org_device)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
print("resizing complete")
return o_lora_sd
def resize(args):
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args()
resize(args)

View File

@ -83,6 +83,7 @@ def save_configuration(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -171,6 +172,7 @@ def open_configuration(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
persistent_data_loader_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -240,6 +242,7 @@ def train_model(
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
persistent_data_loader_workers,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -417,6 +420,7 @@ def train_model(
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
) )
run_cmd += f' --token_string="{token_string}"' run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"' run_cmd += f' --init_word="{init_word}"'
@ -671,6 +675,7 @@ def ti_tab(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -736,6 +741,7 @@ def ti_tab(
model_list, model_list,
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
keep_tokens, keep_tokens,
persistent_data_loader_workers,
] ]
button_open_config.click( button_open_config.click(

View File

@ -133,7 +133,7 @@ def train(args):
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
@ -176,6 +176,8 @@ def train(args):
# epoch数を計算する # epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する # 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

View File

@ -214,7 +214,7 @@ def train(args):
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None: