- Finetune, add xformers, 8bit adam, min bucket, max bucket, batch size and flip augmentation support for dataset preparation

- Finetune, add "Dataset preparation" tab to group task specific options
This commit is contained in:
bmaltais 2023-01-02 13:07:17 -05:00
parent 1d460a09fd
commit 9d3c402973
4 changed files with 254 additions and 221 deletions

View File

@ -30,6 +30,9 @@ Once you have created the LoRA network you can generate images via auto1111 by i
## Change history ## Change history
* 2023/01/02 (v19.2) update:
- Finetune, add xformers, 8bit adam, min bucket, max bucket, batch size and flip augmentation support for dataset preparation
- Finetune, add "Dataset preparation" tab to group task specific options
* 2023/01/01 (v19.2) update: * 2023/01/01 (v19.2) update:
- add support for color and flip augmentation to "Dreambooth LoRA" - add support for color and flip augmentation to "Dreambooth LoRA"
* 2023/01/01 (v19.1) update: * 2023/01/01 (v19.1) update:

View File

@ -17,7 +17,7 @@ from library.common_gui import (
get_file_path, get_file_path,
get_any_file_path, get_any_file_path,
get_saveasfile_path, get_saveasfile_path,
color_aug_changed color_aug_changed,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -66,7 +66,9 @@ def save_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
): ):
original_file_path = file_path original_file_path = file_path
@ -163,7 +165,9 @@ def open_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
): ):
original_file_path = file_path original_file_path = file_path
@ -254,7 +258,9 @@ def train_model(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
): ):
def save_inference_file(output_dir, v2, v_parameterization): def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required # Copy inference model for v2 if required
@ -774,10 +780,12 @@ def dreambooth_tab(
color_aug = gr.Checkbox( color_aug = gr.Checkbox(
label='Color augmentation', value=False label='Color augmentation', value=False
) )
flip_aug = gr.Checkbox( flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
label='Flip augmentation', value=False color_aug.change(
color_aug_changed,
inputs=[color_aug],
outputs=[cache_latent_input],
) )
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
with gr.Row(): with gr.Row():
resume = gr.Textbox( resume = gr.Textbox(
label='Resume from saved training state', label='Resume from saved training state',
@ -789,7 +797,9 @@ def dreambooth_tab(
label='Prior loss weight', value=1.0 label='Prior loss weight', value=1.0
) )
with gr.Tab('Tools'): with gr.Tab('Tools'):
gr.Markdown('This section provide Dreambooth tools to help setup your dataset...') gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab( gradio_dreambooth_folder_creation_tab(
train_data_dir_input=train_data_dir_input, train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input, reg_data_dir_input=reg_data_dir_input,
@ -835,7 +845,9 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
], ],
outputs=[ outputs=[
config_file_name, config_file_name,
@ -870,7 +882,9 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
], ],
) )
@ -910,7 +924,9 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
], ],
outputs=[config_file_name], outputs=[config_file_name],
) )
@ -951,7 +967,9 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
], ],
outputs=[config_file_name], outputs=[config_file_name],
) )
@ -990,7 +1008,9 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, color_aug, flip_aug prior_loss_weight,
color_aug,
flip_aug,
], ],
) )

View File

@ -31,6 +31,13 @@ def save_configuration(
output_dir, output_dir,
logging_dir, logging_dir,
max_resolution, max_resolution,
min_bucket_reso,
max_bucket_reso,
batch_size,
flip_aug,
caption_metadata_filename,
latent_metadata_filename,
full_path,
learning_rate, learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
@ -43,10 +50,12 @@ def save_configuration(
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
train_text_encoder, train_text_encoder,
create_buckets,
create_caption, create_caption,
create_buckets,
save_model_as, save_model_as,
caption_extension, caption_extension,
use_8bit_adam,
xformers,
): ):
original_file_path = file_path original_file_path = file_path
@ -75,6 +84,13 @@ def save_configuration(
'output_dir': output_dir, 'output_dir': output_dir,
'logging_dir': logging_dir, 'logging_dir': logging_dir,
'max_resolution': max_resolution, 'max_resolution': max_resolution,
'min_bucket_reso': min_bucket_reso,
'max_bucket_reso': max_bucket_reso,
'batch_size': batch_size,
'flip_aug': flip_aug,
'caption_metadata_filename': caption_metadata_filename,
'latent_metadata_filename': latent_metadata_filename,
'full_path': full_path,
'learning_rate': learning_rate, 'learning_rate': learning_rate,
'lr_scheduler': lr_scheduler, 'lr_scheduler': lr_scheduler,
'lr_warmup': lr_warmup, 'lr_warmup': lr_warmup,
@ -91,6 +107,8 @@ def save_configuration(
'create_caption': create_caption, 'create_caption': create_caption,
'save_model_as': save_model_as, 'save_model_as': save_model_as,
'caption_extension': caption_extension, 'caption_extension': caption_extension,
'use_8bit_adam': use_8bit_adam,
'xformers': xformers,
} }
# Save the data to the selected file # Save the data to the selected file
@ -110,6 +128,13 @@ def open_config_file(
output_dir, output_dir,
logging_dir, logging_dir,
max_resolution, max_resolution,
min_bucket_reso,
max_bucket_reso,
batch_size,
flip_aug,
caption_metadata_filename,
latent_metadata_filename,
full_path,
learning_rate, learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
@ -122,10 +147,12 @@ def open_config_file(
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
train_text_encoder, train_text_encoder,
create_buckets,
create_caption, create_caption,
create_buckets,
save_model_as, save_model_as,
caption_extension, caption_extension,
use_8bit_adam,
xformers,
): ):
original_file_path = file_path original_file_path = file_path
file_path = get_file_path(file_path) file_path = get_file_path(file_path)
@ -152,6 +179,13 @@ def open_config_file(
my_data.get('output_dir', output_dir), my_data.get('output_dir', output_dir),
my_data.get('logging_dir', logging_dir), my_data.get('logging_dir', logging_dir),
my_data.get('max_resolution', max_resolution), my_data.get('max_resolution', max_resolution),
my_data.get('min_bucket_reso', min_bucket_reso),
my_data.get('max_bucket_reso', max_bucket_reso),
my_data.get('batch_size', batch_size),
my_data.get('flip_aug', flip_aug),
my_data.get('caption_metadata_filename', caption_metadata_filename),
my_data.get('latent_metadata_filename', latent_metadata_filename),
my_data.get('full_path', full_path),
my_data.get('learning_rate', learning_rate), my_data.get('learning_rate', learning_rate),
my_data.get('lr_scheduler', lr_scheduler), my_data.get('lr_scheduler', lr_scheduler),
my_data.get('lr_warmup', lr_warmup), my_data.get('lr_warmup', lr_warmup),
@ -170,12 +204,12 @@ def open_config_file(
my_data.get('create_caption', create_caption), my_data.get('create_caption', create_caption),
my_data.get('save_model_as', save_model_as), my_data.get('save_model_as', save_model_as),
my_data.get('caption_extension', caption_extension), my_data.get('caption_extension', caption_extension),
my_data.get('use_8bit_adam', use_8bit_adam),
my_data.get('xformers', xformers),
) )
def train_model( def train_model(
generate_caption_database,
generate_image_buckets,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
v_parameterization, v_parameterization,
@ -184,6 +218,13 @@ def train_model(
output_dir, output_dir,
logging_dir, logging_dir,
max_resolution, max_resolution,
min_bucket_reso,
max_bucket_reso,
batch_size,
flip_aug,
caption_metadata_filename,
latent_metadata_filename,
full_path,
learning_rate, learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
@ -196,8 +237,12 @@ def train_model(
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
train_text_encoder, train_text_encoder,
generate_caption_database,
generate_image_buckets,
save_model_as, save_model_as,
caption_extension, caption_extension,
use_8bit_adam,
xformers,
): ):
def save_inference_file(output_dir, v2, v_parameterization): def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required # Copy inference model for v2 if required
@ -227,8 +272,9 @@ def train_model(
else: else:
run_cmd += f' --caption_extension={caption_extension}' run_cmd += f' --caption_extension={caption_extension}'
run_cmd += f' {image_folder}' run_cmd += f' {image_folder}'
run_cmd += f' {train_dir}/meta_cap.json' run_cmd += f' {train_dir}/{caption_metadata_filename}'
run_cmd += f' --full_path' if full_path:
run_cmd += f' --full_path'
print(run_cmd) print(run_cmd)
@ -237,26 +283,27 @@ def train_model(
# create images buckets # create images buckets
if generate_image_buckets: if generate_image_buckets:
command = [ run_cmd = (
'./venv/Scripts/python.exe', f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
'finetune/prepare_buckets_latents.py', )
image_folder, run_cmd += f' {image_folder}'
'{}/meta_cap.json'.format(train_dir), run_cmd += f' {train_dir}/{caption_metadata_filename}'
'{}/meta_lat.json'.format(train_dir), run_cmd += f' {train_dir}/{latent_metadata_filename}'
pretrained_model_name_or_path, run_cmd += f' {pretrained_model_name_or_path}'
'--batch_size', run_cmd += f' --batch_size={batch_size}'
'4', run_cmd += f' --max_resolution={max_resolution}'
'--max_resolution', run_cmd += f' --min_bucket_reso={min_bucket_reso}'
max_resolution, run_cmd += f' --max_bucket_reso={max_bucket_reso}'
'--mixed_precision', run_cmd += f' --mixed_precision={mixed_precision}'
mixed_precision, if flip_aug:
'--full_path', run_cmd += f' --flip_aug'
] if full_path:
run_cmd += f' --full_path'
print(command) print(run_cmd)
# Run the command # Run the command
subprocess.run(command) subprocess.run(run_cmd)
image_num = len( image_num = len(
[f for f in os.listdir(image_folder) if f.endswith('.npz')] [f for f in os.listdir(image_folder) if f.endswith('.npz')]
@ -270,11 +317,14 @@ def train_model(
max_train_steps = int( max_train_steps = int(
math.ceil(float(repeats) / int(train_batch_size) * int(epoch)) math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
) )
# Divide by two because flip augmentation create two copied of the source images
if flip_aug:
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
print(f'max_train_steps = {max_train_steps}') print(f'max_train_steps = {max_train_steps}')
lr_warmup_steps = round( lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
float(int(lr_warmup) * int(max_train_steps) / 100)
)
print(f'lr_warmup_steps = {lr_warmup_steps}') print(f'lr_warmup_steps = {lr_warmup_steps}')
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"' run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
@ -284,10 +334,14 @@ def train_model(
run_cmd += ' --v_parameterization' run_cmd += ' --v_parameterization'
if train_text_encoder: if train_text_encoder:
run_cmd += ' --train_text_encoder' run_cmd += ' --train_text_encoder'
if use_8bit_adam:
run_cmd += f' --use_8bit_adam'
if xformers:
run_cmd += f' --xformers'
run_cmd += ( run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
) )
run_cmd += f' --in_json={train_dir}/meta_lat.json' run_cmd += f' --in_json={train_dir}/{latent_metadata_filename}'
run_cmd += f' --train_data_dir={image_folder}' run_cmd += f' --train_data_dir={image_folder}'
run_cmd += f' --output_dir={output_dir}' run_cmd += f' --output_dir={output_dir}'
if not logging_dir == '': if not logging_dir == '':
@ -298,8 +352,6 @@ def train_model(
run_cmd += f' --lr_scheduler={lr_scheduler}' run_cmd += f' --lr_scheduler={lr_scheduler}'
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}' run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
run_cmd += f' --max_train_steps={max_train_steps}' run_cmd += f' --max_train_steps={max_train_steps}'
run_cmd += f' --use_8bit_adam'
run_cmd += f' --xformers'
run_cmd += f' --mixed_precision={mixed_precision}' run_cmd += f' --mixed_precision={mixed_precision}'
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}' run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
run_cmd += f' --seed={seed}' run_cmd += f' --seed={seed}'
@ -389,9 +441,9 @@ def UI(username, password):
interface = gr.Blocks(css=css) interface = gr.Blocks(css=css)
with interface: with interface:
with gr.Tab("Finetune"): with gr.Tab('Finetune'):
finetune_tab() finetune_tab()
with gr.Tab("Utilities"): with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False) utilities_tab(enable_dreambooth_tab=False)
# Show the interface # Show the interface
@ -400,12 +452,11 @@ def UI(username, password):
else: else:
interface.launch() interface.launch()
def finetune_tab(): def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False) dummy_ft_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False) dummy_ft_false = gr.Label(value=False, visible=False)
gr.Markdown( gr.Markdown('Train a custom model using kohya finetune python code...')
'Train a custom model using kohya finetune python code...'
)
with gr.Accordion('Configuration file', open=False): with gr.Accordion('Configuration file', open=False):
with gr.Row(): with gr.Row():
button_open_config = gr.Button( button_open_config = gr.Button(
@ -496,9 +547,7 @@ def finetune_tab():
train_dir_folder = gr.Button( train_dir_folder = gr.Button(
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
train_dir_folder.click( train_dir_folder.click(get_folder_path, outputs=train_dir_input)
get_folder_path, outputs=train_dir_input
)
image_folder_input = gr.Textbox( image_folder_input = gr.Textbox(
label='Training Image folder', label='Training Image folder',
@ -547,11 +596,31 @@ def finetune_tab():
inputs=[output_dir_input], inputs=[output_dir_input],
outputs=[output_dir_input], outputs=[output_dir_input],
) )
with gr.Tab('Dataset preparation'):
with gr.Row():
max_resolution_input = gr.Textbox(
label='Resolution (width,height)', value='512,512'
)
min_bucket_reso = gr.Textbox(
label='Min bucket resolution', value='256'
)
max_bucket_reso = gr.Textbox(
label='Max bucket resolution', value='1024'
)
batch_size = gr.Textbox(label='Batch size', value='1')
with gr.Accordion('Advanced parameters', open=False):
with gr.Row():
caption_metadata_filename = gr.Textbox(
label='Caption metadata filename', value='meta_cap.json'
)
latent_metadata_filename = gr.Textbox(
label='Latent metadata filename', value='meta_lat.json'
)
full_path = gr.Checkbox(label='Use full path', value=True)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
with gr.Tab('Training parameters'): with gr.Tab('Training parameters'):
with gr.Row(): with gr.Row():
learning_rate_input = gr.Textbox( learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
label='Learning rate', value=1e-6
)
lr_scheduler_input = gr.Dropdown( lr_scheduler_input = gr.Dropdown(
label='LR Scheduler', label='LR Scheduler',
choices=[ choices=[
@ -606,11 +675,7 @@ def finetune_tab():
label='Number of CPU threads per process', label='Number of CPU threads per process',
value=os.cpu_count(), value=os.cpu_count(),
) )
with gr.Row():
seed_input = gr.Textbox(label='Seed', value=1234) seed_input = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox(
label='Max resolution', value='512,512'
)
with gr.Row(): with gr.Row():
caption_extention_input = gr.Textbox( caption_extention_input = gr.Textbox(
label='Caption Extension', label='Caption Extension',
@ -619,168 +684,74 @@ def finetune_tab():
train_text_encoder_input = gr.Checkbox( train_text_encoder_input = gr.Checkbox(
label='Train text encoder', value=True label='Train text encoder', value=True
) )
with gr.Accordion('Advanced parameters', open=False):
with gr.Row():
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
create_caption = gr.Checkbox( create_caption = gr.Checkbox(
label='Generate caption database', value=True label='Generate caption metadata', value=True
) )
create_buckets = gr.Checkbox( create_buckets = gr.Checkbox(
label='Generate image buckets', value=True label='Generate image buckets metadata', value=True
) )
button_run = gr.Button('Train model') button_run = gr.Button('Train model')
button_run.click( settings_list = [
train_model, pretrained_model_name_or_path_input,
inputs=[ v2_input,
create_caption, v_parameterization_input,
create_buckets, train_dir_input,
pretrained_model_name_or_path_input, image_folder_input,
v2_input, output_dir_input,
v_parameterization_input, logging_dir_input,
train_dir_input, max_resolution_input,
image_folder_input, min_bucket_reso,
output_dir_input, max_bucket_reso,
logging_dir_input, batch_size,
max_resolution_input, flip_aug,
learning_rate_input, caption_metadata_filename,
lr_scheduler_input, latent_metadata_filename,
lr_warmup_input, full_path,
dataset_repeats_input, learning_rate_input,
train_batch_size_input, lr_scheduler_input,
epoch_input, lr_warmup_input,
save_every_n_epochs_input, dataset_repeats_input,
mixed_precision_input, train_batch_size_input,
save_precision_input, epoch_input,
seed_input, save_every_n_epochs_input,
num_cpu_threads_per_process_input, mixed_precision_input,
train_text_encoder_input, save_precision_input,
save_model_as_dropdown, seed_input,
caption_extention_input, num_cpu_threads_per_process_input,
], train_text_encoder_input,
) create_caption,
create_buckets,
save_model_as_dropdown,
caption_extention_input,
use_8bit_adam,
xformers,
]
button_run.click(train_model, inputs=settings_list)
button_open_config.click( button_open_config.click(
open_config_file, open_config_file,
inputs=[ inputs=[config_file_name] + settings_list,
config_file_name, outputs=[config_file_name] + settings_list,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
train_dir_input,
image_folder_input,
output_dir_input,
logging_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
dataset_repeats_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
train_text_encoder_input,
create_buckets,
create_caption,
save_model_as_dropdown,
caption_extention_input,
],
outputs=[
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
train_dir_input,
image_folder_input,
output_dir_input,
logging_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
dataset_repeats_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
train_text_encoder_input,
create_buckets,
create_caption,
save_model_as_dropdown,
caption_extention_input,
],
) )
button_save_config.click( button_save_config.click(
save_configuration, save_configuration,
inputs=[ inputs=[dummy_ft_false, config_file_name] + settings_list,
dummy_ft_false,
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
train_dir_input,
image_folder_input,
output_dir_input,
logging_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
dataset_repeats_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
train_text_encoder_input,
create_buckets,
create_caption,
save_model_as_dropdown,
caption_extention_input,
],
outputs=[config_file_name], outputs=[config_file_name],
) )
button_save_as_config.click( button_save_as_config.click(
save_configuration, save_configuration,
inputs=[ inputs=[dummy_ft_true, config_file_name] + settings_list,
dummy_ft_true,
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
train_dir_input,
image_folder_input,
output_dir_input,
logging_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
dataset_repeats_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
train_text_encoder_input,
create_buckets,
create_caption,
save_model_as_dropdown,
caption_extention_input,
],
outputs=[config_file_name], outputs=[config_file_name],
) )

View File

@ -17,7 +17,7 @@ from library.common_gui import (
get_file_path, get_file_path,
get_any_file_path, get_any_file_path,
get_saveasfile_path, get_saveasfile_path,
color_aug_changed color_aug_changed,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -65,7 +65,13 @@ def save_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug prior_loss_weight,
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
color_aug,
flip_aug,
): ):
original_file_path = file_path original_file_path = file_path
@ -164,7 +170,13 @@ def open_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug prior_loss_weight,
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
color_aug,
flip_aug,
): ):
original_file_path = file_path original_file_path = file_path
@ -257,7 +269,13 @@ def train_model(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug prior_loss_weight,
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
color_aug,
flip_aug,
): ):
def save_inference_file(output_dir, v2, v_parameterization): def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required # Copy inference model for v2 if required
@ -296,11 +314,15 @@ def train_model(
return return
# If string is empty set string to 0. # If string is empty set string to 0.
if text_encoder_lr == '': text_encoder_lr = 0 if text_encoder_lr == '':
if unet_lr == '': unet_lr = 0 text_encoder_lr = 0
if unet_lr == '':
unet_lr = 0
if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0): if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
msgbox('At least one Learning Rate value for "Text encoder" or "Unet" need to be provided') msgbox(
'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
)
return return
# Get a list of all subfolders in train_data_dir # Get a list of all subfolders in train_data_dir
@ -446,7 +468,6 @@ def train_model(
if not lora_network_weights == '': if not lora_network_weights == '':
run_cmd += f' --network_weights={lora_network_weights}' run_cmd += f' --network_weights={lora_network_weights}'
print(run_cmd) print(run_cmd)
# Run the command # Run the command
subprocess.run(run_cmd) subprocess.run(run_cmd)
@ -552,7 +573,9 @@ def lora_tab(
): ):
dummy_db_true = gr.Label(value=True, visible=False) dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False) dummy_db_false = gr.Label(value=False, visible=False)
gr.Markdown('Train a custom model using kohya train network LoRA python code...') gr.Markdown(
'Train a custom model using kohya train network LoRA python code...'
)
with gr.Accordion('Configuration file', open=False): with gr.Accordion('Configuration file', open=False):
with gr.Row(): with gr.Row():
button_open_config = gr.Button('Open 📂', elem_id='open_folder') button_open_config = gr.Button('Open 📂', elem_id='open_folder')
@ -729,8 +752,14 @@ def lora_tab(
) )
lr_warmup_input = gr.Textbox(label='LR warmup', value=0) lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
with gr.Row(): with gr.Row():
text_encoder_lr = gr.Textbox(label='Text Encoder learning rate', value=1e-6, placeholder='Optional') text_encoder_lr = gr.Textbox(
unet_lr = gr.Textbox(label='Unet learning rate', value=1e-4, placeholder='Optional') label='Text Encoder learning rate',
value=1e-6,
placeholder='Optional',
)
unet_lr = gr.Textbox(
label='Unet learning rate', value=1e-4, placeholder='Optional'
)
# network_train = gr.Dropdown( # network_train = gr.Dropdown(
# label='Network to train', # label='Network to train',
# choices=[ # choices=[
@ -747,7 +776,7 @@ def lora_tab(
label='Network Dimension', label='Network Dimension',
value=4, value=4,
step=1, step=1,
interactive=True interactive=True,
) )
with gr.Row(): with gr.Row():
train_batch_size_input = gr.Slider( train_batch_size_input = gr.Slider(
@ -837,10 +866,12 @@ def lora_tab(
color_aug = gr.Checkbox( color_aug = gr.Checkbox(
label='Color augmentation', value=False label='Color augmentation', value=False
) )
flip_aug = gr.Checkbox( flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
label='Flip augmentation', value=False color_aug.change(
color_aug_changed,
inputs=[color_aug],
outputs=[cache_latent_input],
) )
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
with gr.Row(): with gr.Row():
resume = gr.Textbox( resume = gr.Textbox(
label='Resume from saved training state', label='Resume from saved training state',
@ -852,7 +883,9 @@ def lora_tab(
label='Prior loss weight', value=1.0 label='Prior loss weight', value=1.0
) )
with gr.Tab('Tools'): with gr.Tab('Tools'):
gr.Markdown('This section provide Dreambooth tools to help setup your dataset...') gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab( gradio_dreambooth_folder_creation_tab(
train_data_dir_input=train_data_dir_input, train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input, reg_data_dir_input=reg_data_dir_input,
@ -895,7 +928,13 @@ def lora_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug prior_loss_weight,
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
color_aug,
flip_aug,
] ]
button_open_config.click( button_open_config.click(