Add support for LoRA resizing
This commit is contained in:
parent
045750b46a
commit
2626214f8a
@ -143,7 +143,13 @@ Then redo the installation instruction within the kohya_ss venv.
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
* 2023/02/03
|
* 2023/02/04 (v20.6.1)
|
||||||
|
- ``--persistent_data_loader_workers`` option is added to ``fine_tune.py``, ``train_db.py`` and ``train_network.py``. This option may significantly reduce the waiting time between epochs. Thanks to hitomi!
|
||||||
|
- ``--debug_dataset`` option is now working on non-Windows environment. Thanks to tsukimiya!
|
||||||
|
- ``networks/resize_lora.py`` script is added. This can approximate the higher-rank (dim) LoRA model by a lower-rank LoRA model, e.g. 128 by 4. Thanks to mgz-dev!
|
||||||
|
- ``--help`` option shows usage.
|
||||||
|
- Currently the metadata is not copied. This will be fixed in the near future.
|
||||||
|
* 2023/02/03 (v20.6.0)
|
||||||
- Increase max LoRA rank (dim) size to 1024.
|
- Increase max LoRA rank (dim) size to 1024.
|
||||||
- Update finetune preprocessing scripts.
|
- Update finetune preprocessing scripts.
|
||||||
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
|
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
|
||||||
|
@ -83,6 +83,7 @@ def save_configuration(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, keep_tokens,
|
model_list, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -167,6 +168,7 @@ def open_configuration(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, keep_tokens,
|
model_list, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -236,6 +238,7 @@ def train_model(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -398,6 +401,7 @@ def train_model(
|
|||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -605,6 +609,7 @@ def dreambooth_tab(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -669,6 +674,7 @@ def dreambooth_tab(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -79,6 +79,7 @@ def save_configuration(
|
|||||||
model_list,
|
model_list,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files, keep_tokens,
|
use_latent_files, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -169,6 +170,7 @@ def open_config_file(
|
|||||||
model_list,
|
model_list,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files, keep_tokens,
|
use_latent_files, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -244,6 +246,7 @@ def train_model(
|
|||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files, keep_tokens,
|
use_latent_files, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -382,6 +385,7 @@ def train_model(
|
|||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -587,6 +591,7 @@ def finetune_tab():
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -647,6 +652,7 @@ def finetune_tab():
|
|||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files,
|
use_latent_files,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
@ -510,31 +510,12 @@ def run_cmd_training(**kwargs):
|
|||||||
|
|
||||||
def gradio_advanced_training():
|
def gradio_advanced_training():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
full_fp16 = gr.Checkbox(
|
|
||||||
label='Full fp16 training (experimental)', value=False
|
|
||||||
)
|
|
||||||
gradient_checkpointing = gr.Checkbox(
|
|
||||||
label='Gradient checkpointing', value=False
|
|
||||||
)
|
|
||||||
shuffle_caption = gr.Checkbox(
|
|
||||||
label='Shuffle caption', value=False
|
|
||||||
)
|
|
||||||
keep_tokens = gr.Slider(
|
keep_tokens = gr.Slider(
|
||||||
label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
|
label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
|
||||||
)
|
)
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
|
||||||
with gr.Row():
|
|
||||||
color_aug = gr.Checkbox(
|
|
||||||
label='Color augmentation', value=False
|
|
||||||
)
|
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
|
||||||
clip_skip = gr.Slider(
|
clip_skip = gr.Slider(
|
||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||||
)
|
)
|
||||||
mem_eff_attn = gr.Checkbox(
|
|
||||||
label='Memory efficient attention', value=False
|
|
||||||
)
|
|
||||||
max_token_length = gr.Dropdown(
|
max_token_length = gr.Dropdown(
|
||||||
label='Max Token Length',
|
label='Max Token Length',
|
||||||
choices=[
|
choices=[
|
||||||
@ -544,6 +525,29 @@ def gradio_advanced_training():
|
|||||||
],
|
],
|
||||||
value='75',
|
value='75',
|
||||||
)
|
)
|
||||||
|
full_fp16 = gr.Checkbox(
|
||||||
|
label='Full fp16 training (experimental)', value=False
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
gradient_checkpointing = gr.Checkbox(
|
||||||
|
label='Gradient checkpointing', value=False
|
||||||
|
)
|
||||||
|
shuffle_caption = gr.Checkbox(
|
||||||
|
label='Shuffle caption', value=False
|
||||||
|
)
|
||||||
|
persistent_data_loader_workers = gr.Checkbox(
|
||||||
|
label='Persistent data loader', value=False
|
||||||
|
)
|
||||||
|
mem_eff_attn = gr.Checkbox(
|
||||||
|
label='Memory efficient attention', value=False
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||||
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
|
color_aug = gr.Checkbox(
|
||||||
|
label='Color augmentation', value=False
|
||||||
|
)
|
||||||
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||||
resume = gr.Textbox(
|
resume = gr.Textbox(
|
||||||
@ -576,6 +580,7 @@ def gradio_advanced_training():
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_cmd_advanced_training(**kwargs):
|
def run_cmd_advanced_training(**kwargs):
|
||||||
@ -622,6 +627,8 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
|
|
||||||
' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
|
' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
|
||||||
|
|
||||||
|
' --persistent_data_loader_workers' if kwargs.get('persistent_data_loader_workers') else '',
|
||||||
|
|
||||||
]
|
]
|
||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
|
104
library/resize_lora_gui.py
Normal file
104
library/resize_lora_gui.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from easygui import msgbox
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
from .common_gui import get_saveasfilename_path, get_file_path
|
||||||
|
|
||||||
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
|
|
||||||
|
def resize_lora(
|
||||||
|
model, new_rank, save_to, save_precision, device,
|
||||||
|
):
|
||||||
|
# Check for caption_text_input
|
||||||
|
if model == '':
|
||||||
|
msgbox('Invalid model file')
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if source model exist
|
||||||
|
if not os.path.isfile(model):
|
||||||
|
msgbox('The provided model is not a file')
|
||||||
|
return
|
||||||
|
|
||||||
|
if device == '':
|
||||||
|
device = 'cuda'
|
||||||
|
|
||||||
|
run_cmd = f'.\\venv\Scripts\python.exe "networks\\resize_lora.py"'
|
||||||
|
run_cmd += f' --save_precision {save_precision}'
|
||||||
|
run_cmd += f' --save_to {save_to}'
|
||||||
|
run_cmd += f' --model {model}'
|
||||||
|
run_cmd += f' --new_rank {new_rank}'
|
||||||
|
run_cmd += f' --device {device}'
|
||||||
|
|
||||||
|
print(run_cmd)
|
||||||
|
|
||||||
|
# Run the command
|
||||||
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
|
|
||||||
|
###
|
||||||
|
# Gradio UI
|
||||||
|
###
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_resize_lora_tab():
|
||||||
|
with gr.Tab('Resize LoRA'):
|
||||||
|
gr.Markdown(
|
||||||
|
'This utility can resize a LoRA.'
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
||||||
|
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
model = gr.Textbox(
|
||||||
|
label='Source LoRA',
|
||||||
|
placeholder='Path to the LoRA to resize',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
button_lora_a_model_file = gr.Button(
|
||||||
|
folder_symbol, elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
button_lora_a_model_file.click(
|
||||||
|
get_file_path,
|
||||||
|
inputs=[model, lora_ext, lora_ext_name],
|
||||||
|
outputs=model,
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
new_rank = gr.Slider(label="Desired LoRA rank", minimum=1, maximum=1024, step=1, value=4,
|
||||||
|
interactive=True,)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
save_to = gr.Textbox(
|
||||||
|
label='Save to',
|
||||||
|
placeholder='path for the LoRA file to save...',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
button_save_to = gr.Button(
|
||||||
|
folder_symbol, elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
button_save_to.click(
|
||||||
|
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
|
||||||
|
)
|
||||||
|
save_precision = gr.Dropdown(
|
||||||
|
label='Save precison',
|
||||||
|
choices=['fp16', 'bf16', 'float'],
|
||||||
|
value='fp16',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
device = gr.Textbox(
|
||||||
|
label='Device',
|
||||||
|
placeholder='{Optional) device to use, cuda for GPU. Default: cuda',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
convert_button = gr.Button('Resize model')
|
||||||
|
|
||||||
|
convert_button.click(
|
||||||
|
resize_lora,
|
||||||
|
inputs=[model, new_rank, save_to, save_precision, device,
|
||||||
|
],
|
||||||
|
)
|
@ -772,6 +772,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||||
|
if os.name == 'nt': # only windows
|
||||||
cv2.imshow("img", im)
|
cv2.imshow("img", im)
|
||||||
k = cv2.waitKey()
|
k = cv2.waitKey()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
@ -1194,6 +1195,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
||||||
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||||
|
parser.add_argument("--persistent_data_loader_workers", action="store_true",
|
||||||
|
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
|
||||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||||
|
@ -32,6 +32,7 @@ from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
|||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
from library.merge_lora_gui import gradio_merge_lora_tab
|
from library.merge_lora_gui import gradio_merge_lora_tab
|
||||||
from library.verify_lora_gui import gradio_verify_lora_tab
|
from library.verify_lora_gui import gradio_verify_lora_tab
|
||||||
|
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
@ -92,6 +93,7 @@ def save_configuration(
|
|||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment, keep_tokens,
|
training_comment, keep_tokens,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
lr_scheduler_num_cycles, lr_scheduler_power,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -182,6 +184,7 @@ def open_configuration(
|
|||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment, keep_tokens,
|
training_comment, keep_tokens,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
lr_scheduler_num_cycles, lr_scheduler_power,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -256,6 +259,7 @@ def train_model(
|
|||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment, keep_tokens,
|
training_comment, keep_tokens,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
lr_scheduler_num_cycles, lr_scheduler_power,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -446,6 +450,7 @@ def train_model(
|
|||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -689,6 +694,7 @@ def lora_tab(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -708,6 +714,7 @@ def lora_tab(
|
|||||||
)
|
)
|
||||||
gradio_dataset_balancing_tab()
|
gradio_dataset_balancing_tab()
|
||||||
gradio_merge_lora_tab()
|
gradio_merge_lora_tab()
|
||||||
|
gradio_resize_lora_tab()
|
||||||
gradio_verify_lora_tab()
|
gradio_verify_lora_tab()
|
||||||
|
|
||||||
|
|
||||||
@ -764,6 +771,7 @@ def lora_tab(
|
|||||||
training_comment,
|
training_comment,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
lr_scheduler_num_cycles, lr_scheduler_power,
|
||||||
|
persistent_data_loader_workers,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
166
networks/resize_lora.py
Normal file
166
networks/resize_lora.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||||
|
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||||
|
# Thanks to cloneofsimo and kohya
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def load_state_dict(file_name, dtype):
|
||||||
|
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||||
|
sd = load_file(file_name)
|
||||||
|
else:
|
||||||
|
sd = torch.load(file_name, map_location='cpu')
|
||||||
|
for key in list(sd.keys()):
|
||||||
|
if type(sd[key]) == torch.Tensor:
|
||||||
|
sd[key] = sd[key].to(dtype)
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_file(file_name, model, state_dict, dtype):
|
||||||
|
if dtype is not None:
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
if type(state_dict[key]) == torch.Tensor:
|
||||||
|
state_dict[key] = state_dict[key].to(dtype)
|
||||||
|
|
||||||
|
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||||
|
save_file(model, file_name)
|
||||||
|
else:
|
||||||
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||||
|
print("Loading Model...")
|
||||||
|
lora_sd = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
|
network_alpha = None
|
||||||
|
network_dim = None
|
||||||
|
|
||||||
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
|
# Extract loaded lora dim and alpha
|
||||||
|
for key, value in lora_sd.items():
|
||||||
|
if network_alpha is None and 'alpha' in key:
|
||||||
|
network_alpha = value
|
||||||
|
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
||||||
|
network_dim = value.size()[0]
|
||||||
|
if network_alpha is not None and network_dim is not None:
|
||||||
|
break
|
||||||
|
if network_alpha is None:
|
||||||
|
network_alpha = network_dim
|
||||||
|
|
||||||
|
scale = network_alpha/network_dim
|
||||||
|
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||||
|
|
||||||
|
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||||
|
|
||||||
|
lora_down_weight = None
|
||||||
|
lora_up_weight = None
|
||||||
|
|
||||||
|
o_lora_sd = lora_sd.copy()
|
||||||
|
block_down_name = None
|
||||||
|
block_up_name = None
|
||||||
|
|
||||||
|
print("resizing lora...")
|
||||||
|
with torch.no_grad():
|
||||||
|
for key, value in tqdm(lora_sd.items()):
|
||||||
|
if 'lora_down' in key:
|
||||||
|
block_down_name = key.split(".")[0]
|
||||||
|
lora_down_weight = value
|
||||||
|
if 'lora_up' in key:
|
||||||
|
block_up_name = key.split(".")[0]
|
||||||
|
lora_up_weight = value
|
||||||
|
|
||||||
|
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
||||||
|
|
||||||
|
if (block_down_name == block_up_name) and weights_loaded:
|
||||||
|
|
||||||
|
conv2d = (len(lora_down_weight.size()) == 4)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
lora_down_weight = lora_down_weight.squeeze()
|
||||||
|
lora_up_weight = lora_up_weight.squeeze()
|
||||||
|
|
||||||
|
if args.device:
|
||||||
|
org_device = lora_up_weight.device
|
||||||
|
lora_up_weight = lora_up_weight.to(args.device)
|
||||||
|
lora_down_weight = lora_down_weight.to(args.device)
|
||||||
|
|
||||||
|
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
|
||||||
|
|
||||||
|
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||||
|
|
||||||
|
U = U[:, :new_rank]
|
||||||
|
S = S[:new_rank]
|
||||||
|
U = U @ torch.diag(S)
|
||||||
|
|
||||||
|
Vh = Vh[:new_rank, :]
|
||||||
|
|
||||||
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
|
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
|
low_val = -hi_val
|
||||||
|
|
||||||
|
U = U.clamp(low_val, hi_val)
|
||||||
|
Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
U = U.unsqueeze(2).unsqueeze(3)
|
||||||
|
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
||||||
|
|
||||||
|
if args.device:
|
||||||
|
U = U.to(org_device)
|
||||||
|
Vh = Vh.to(org_device)
|
||||||
|
|
||||||
|
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
|
||||||
|
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
|
||||||
|
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
|
||||||
|
|
||||||
|
block_down_name = None
|
||||||
|
block_up_name = None
|
||||||
|
lora_down_weight = None
|
||||||
|
lora_up_weight = None
|
||||||
|
weights_loaded = False
|
||||||
|
|
||||||
|
print("resizing complete")
|
||||||
|
return o_lora_sd
|
||||||
|
|
||||||
|
def resize(args):
|
||||||
|
|
||||||
|
def str_to_dtype(p):
|
||||||
|
if p == 'float':
|
||||||
|
return torch.float
|
||||||
|
if p == 'fp16':
|
||||||
|
return torch.float16
|
||||||
|
if p == 'bf16':
|
||||||
|
return torch.bfloat16
|
||||||
|
return None
|
||||||
|
|
||||||
|
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
||||||
|
save_dtype = str_to_dtype(args.save_precision)
|
||||||
|
if save_dtype is None:
|
||||||
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
|
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype)
|
||||||
|
|
||||||
|
print(f"saving model to: {args.save_to}")
|
||||||
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--save_precision", type=str, default=None,
|
||||||
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
|
||||||
|
parser.add_argument("--new_rank", type=int, default=4,
|
||||||
|
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||||
|
parser.add_argument("--save_to", type=str, default=None,
|
||||||
|
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||||
|
parser.add_argument("--model", type=str, default=None,
|
||||||
|
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
|
||||||
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
resize(args)
|
@ -83,6 +83,7 @@ def save_configuration(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -171,6 +172,7 @@ def open_configuration(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -240,6 +242,7 @@ def train_model(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -417,6 +420,7 @@ def train_model(
|
|||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
run_cmd += f' --token_string="{token_string}"'
|
run_cmd += f' --token_string="{token_string}"'
|
||||||
run_cmd += f' --init_word="{init_word}"'
|
run_cmd += f' --init_word="{init_word}"'
|
||||||
@ -671,6 +675,7 @@ def ti_tab(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -736,6 +741,7 @@ def ti_tab(
|
|||||||
model_list,
|
model_list,
|
||||||
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
persistent_data_loader_workers,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -133,7 +133,7 @@ def train(args):
|
|||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
@ -176,6 +176,8 @@ def train(args):
|
|||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
@ -214,7 +214,7 @@ def train(args):
|
|||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user