- Finetune, add xformers, 8bit adam, min bucket, max bucket, batch size and flip augmentation support for dataset preparation
- Finetune, add "Dataset preparation" tab to group task specific options
This commit is contained in:
parent
1d460a09fd
commit
9d3c402973
@ -30,6 +30,9 @@ Once you have created the LoRA network you can generate images via auto1111 by i
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 2023/01/02 (v19.2) update:
|
||||||
|
- Finetune, add xformers, 8bit adam, min bucket, max bucket, batch size and flip augmentation support for dataset preparation
|
||||||
|
- Finetune, add "Dataset preparation" tab to group task specific options
|
||||||
* 2023/01/01 (v19.2) update:
|
* 2023/01/01 (v19.2) update:
|
||||||
- add support for color and flip augmentation to "Dreambooth LoRA"
|
- add support for color and flip augmentation to "Dreambooth LoRA"
|
||||||
* 2023/01/01 (v19.1) update:
|
* 2023/01/01 (v19.1) update:
|
||||||
|
@ -17,7 +17,7 @@ from library.common_gui import (
|
|||||||
get_file_path,
|
get_file_path,
|
||||||
get_any_file_path,
|
get_any_file_path,
|
||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
color_aug_changed
|
color_aug_changed,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -66,7 +66,9 @@ def save_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -163,7 +165,9 @@ def open_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
):
|
):
|
||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
@ -254,7 +258,9 @@ def train_model(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
def save_inference_file(output_dir, v2, v_parameterization):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -774,10 +780,12 @@ def dreambooth_tab(
|
|||||||
color_aug = gr.Checkbox(
|
color_aug = gr.Checkbox(
|
||||||
label='Color augmentation', value=False
|
label='Color augmentation', value=False
|
||||||
)
|
)
|
||||||
flip_aug = gr.Checkbox(
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
label='Flip augmentation', value=False
|
color_aug.change(
|
||||||
|
color_aug_changed,
|
||||||
|
inputs=[color_aug],
|
||||||
|
outputs=[cache_latent_input],
|
||||||
)
|
)
|
||||||
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
resume = gr.Textbox(
|
resume = gr.Textbox(
|
||||||
label='Resume from saved training state',
|
label='Resume from saved training state',
|
||||||
@ -789,7 +797,9 @@ def dreambooth_tab(
|
|||||||
label='Prior loss weight', value=1.0
|
label='Prior loss weight', value=1.0
|
||||||
)
|
)
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown('This section provide Dreambooth tools to help setup your dataset...')
|
gr.Markdown(
|
||||||
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
|
)
|
||||||
gradio_dreambooth_folder_creation_tab(
|
gradio_dreambooth_folder_creation_tab(
|
||||||
train_data_dir_input=train_data_dir_input,
|
train_data_dir_input=train_data_dir_input,
|
||||||
reg_data_dir_input=reg_data_dir_input,
|
reg_data_dir_input=reg_data_dir_input,
|
||||||
@ -835,7 +845,9 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
config_file_name,
|
config_file_name,
|
||||||
@ -870,7 +882,9 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -910,7 +924,9 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -951,7 +967,9 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -990,7 +1008,9 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
341
finetune_gui.py
341
finetune_gui.py
@ -31,6 +31,13 @@ def save_configuration(
|
|||||||
output_dir,
|
output_dir,
|
||||||
logging_dir,
|
logging_dir,
|
||||||
max_resolution,
|
max_resolution,
|
||||||
|
min_bucket_reso,
|
||||||
|
max_bucket_reso,
|
||||||
|
batch_size,
|
||||||
|
flip_aug,
|
||||||
|
caption_metadata_filename,
|
||||||
|
latent_metadata_filename,
|
||||||
|
full_path,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
lr_warmup,
|
lr_warmup,
|
||||||
@ -43,10 +50,12 @@ def save_configuration(
|
|||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
train_text_encoder,
|
train_text_encoder,
|
||||||
create_buckets,
|
|
||||||
create_caption,
|
create_caption,
|
||||||
|
create_buckets,
|
||||||
save_model_as,
|
save_model_as,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -75,6 +84,13 @@ def save_configuration(
|
|||||||
'output_dir': output_dir,
|
'output_dir': output_dir,
|
||||||
'logging_dir': logging_dir,
|
'logging_dir': logging_dir,
|
||||||
'max_resolution': max_resolution,
|
'max_resolution': max_resolution,
|
||||||
|
'min_bucket_reso': min_bucket_reso,
|
||||||
|
'max_bucket_reso': max_bucket_reso,
|
||||||
|
'batch_size': batch_size,
|
||||||
|
'flip_aug': flip_aug,
|
||||||
|
'caption_metadata_filename': caption_metadata_filename,
|
||||||
|
'latent_metadata_filename': latent_metadata_filename,
|
||||||
|
'full_path': full_path,
|
||||||
'learning_rate': learning_rate,
|
'learning_rate': learning_rate,
|
||||||
'lr_scheduler': lr_scheduler,
|
'lr_scheduler': lr_scheduler,
|
||||||
'lr_warmup': lr_warmup,
|
'lr_warmup': lr_warmup,
|
||||||
@ -91,6 +107,8 @@ def save_configuration(
|
|||||||
'create_caption': create_caption,
|
'create_caption': create_caption,
|
||||||
'save_model_as': save_model_as,
|
'save_model_as': save_model_as,
|
||||||
'caption_extension': caption_extension,
|
'caption_extension': caption_extension,
|
||||||
|
'use_8bit_adam': use_8bit_adam,
|
||||||
|
'xformers': xformers,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the data to the selected file
|
# Save the data to the selected file
|
||||||
@ -110,6 +128,13 @@ def open_config_file(
|
|||||||
output_dir,
|
output_dir,
|
||||||
logging_dir,
|
logging_dir,
|
||||||
max_resolution,
|
max_resolution,
|
||||||
|
min_bucket_reso,
|
||||||
|
max_bucket_reso,
|
||||||
|
batch_size,
|
||||||
|
flip_aug,
|
||||||
|
caption_metadata_filename,
|
||||||
|
latent_metadata_filename,
|
||||||
|
full_path,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
lr_warmup,
|
lr_warmup,
|
||||||
@ -122,10 +147,12 @@ def open_config_file(
|
|||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
train_text_encoder,
|
train_text_encoder,
|
||||||
create_buckets,
|
|
||||||
create_caption,
|
create_caption,
|
||||||
|
create_buckets,
|
||||||
save_model_as,
|
save_model_as,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
file_path = get_file_path(file_path)
|
file_path = get_file_path(file_path)
|
||||||
@ -152,6 +179,13 @@ def open_config_file(
|
|||||||
my_data.get('output_dir', output_dir),
|
my_data.get('output_dir', output_dir),
|
||||||
my_data.get('logging_dir', logging_dir),
|
my_data.get('logging_dir', logging_dir),
|
||||||
my_data.get('max_resolution', max_resolution),
|
my_data.get('max_resolution', max_resolution),
|
||||||
|
my_data.get('min_bucket_reso', min_bucket_reso),
|
||||||
|
my_data.get('max_bucket_reso', max_bucket_reso),
|
||||||
|
my_data.get('batch_size', batch_size),
|
||||||
|
my_data.get('flip_aug', flip_aug),
|
||||||
|
my_data.get('caption_metadata_filename', caption_metadata_filename),
|
||||||
|
my_data.get('latent_metadata_filename', latent_metadata_filename),
|
||||||
|
my_data.get('full_path', full_path),
|
||||||
my_data.get('learning_rate', learning_rate),
|
my_data.get('learning_rate', learning_rate),
|
||||||
my_data.get('lr_scheduler', lr_scheduler),
|
my_data.get('lr_scheduler', lr_scheduler),
|
||||||
my_data.get('lr_warmup', lr_warmup),
|
my_data.get('lr_warmup', lr_warmup),
|
||||||
@ -170,12 +204,12 @@ def open_config_file(
|
|||||||
my_data.get('create_caption', create_caption),
|
my_data.get('create_caption', create_caption),
|
||||||
my_data.get('save_model_as', save_model_as),
|
my_data.get('save_model_as', save_model_as),
|
||||||
my_data.get('caption_extension', caption_extension),
|
my_data.get('caption_extension', caption_extension),
|
||||||
|
my_data.get('use_8bit_adam', use_8bit_adam),
|
||||||
|
my_data.get('xformers', xformers),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_model(
|
def train_model(
|
||||||
generate_caption_database,
|
|
||||||
generate_image_buckets,
|
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
@ -184,6 +218,13 @@ def train_model(
|
|||||||
output_dir,
|
output_dir,
|
||||||
logging_dir,
|
logging_dir,
|
||||||
max_resolution,
|
max_resolution,
|
||||||
|
min_bucket_reso,
|
||||||
|
max_bucket_reso,
|
||||||
|
batch_size,
|
||||||
|
flip_aug,
|
||||||
|
caption_metadata_filename,
|
||||||
|
latent_metadata_filename,
|
||||||
|
full_path,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
lr_warmup,
|
lr_warmup,
|
||||||
@ -196,8 +237,12 @@ def train_model(
|
|||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
train_text_encoder,
|
train_text_encoder,
|
||||||
|
generate_caption_database,
|
||||||
|
generate_image_buckets,
|
||||||
save_model_as,
|
save_model_as,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
def save_inference_file(output_dir, v2, v_parameterization):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -227,8 +272,9 @@ def train_model(
|
|||||||
else:
|
else:
|
||||||
run_cmd += f' --caption_extension={caption_extension}'
|
run_cmd += f' --caption_extension={caption_extension}'
|
||||||
run_cmd += f' {image_folder}'
|
run_cmd += f' {image_folder}'
|
||||||
run_cmd += f' {train_dir}/meta_cap.json'
|
run_cmd += f' {train_dir}/{caption_metadata_filename}'
|
||||||
run_cmd += f' --full_path'
|
if full_path:
|
||||||
|
run_cmd += f' --full_path'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
|
|
||||||
@ -237,26 +283,27 @@ def train_model(
|
|||||||
|
|
||||||
# create images buckets
|
# create images buckets
|
||||||
if generate_image_buckets:
|
if generate_image_buckets:
|
||||||
command = [
|
run_cmd = (
|
||||||
'./venv/Scripts/python.exe',
|
f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
|
||||||
'finetune/prepare_buckets_latents.py',
|
)
|
||||||
image_folder,
|
run_cmd += f' {image_folder}'
|
||||||
'{}/meta_cap.json'.format(train_dir),
|
run_cmd += f' {train_dir}/{caption_metadata_filename}'
|
||||||
'{}/meta_lat.json'.format(train_dir),
|
run_cmd += f' {train_dir}/{latent_metadata_filename}'
|
||||||
pretrained_model_name_or_path,
|
run_cmd += f' {pretrained_model_name_or_path}'
|
||||||
'--batch_size',
|
run_cmd += f' --batch_size={batch_size}'
|
||||||
'4',
|
run_cmd += f' --max_resolution={max_resolution}'
|
||||||
'--max_resolution',
|
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
||||||
max_resolution,
|
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
||||||
'--mixed_precision',
|
run_cmd += f' --mixed_precision={mixed_precision}'
|
||||||
mixed_precision,
|
if flip_aug:
|
||||||
'--full_path',
|
run_cmd += f' --flip_aug'
|
||||||
]
|
if full_path:
|
||||||
|
run_cmd += f' --full_path'
|
||||||
|
|
||||||
print(command)
|
print(run_cmd)
|
||||||
|
|
||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(command)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
image_num = len(
|
image_num = len(
|
||||||
[f for f in os.listdir(image_folder) if f.endswith('.npz')]
|
[f for f in os.listdir(image_folder) if f.endswith('.npz')]
|
||||||
@ -270,11 +317,14 @@ def train_model(
|
|||||||
max_train_steps = int(
|
max_train_steps = int(
|
||||||
math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
|
math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Divide by two because flip augmentation create two copied of the source images
|
||||||
|
if flip_aug:
|
||||||
|
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
|
||||||
|
|
||||||
print(f'max_train_steps = {max_train_steps}')
|
print(f'max_train_steps = {max_train_steps}')
|
||||||
|
|
||||||
lr_warmup_steps = round(
|
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
||||||
float(int(lr_warmup) * int(max_train_steps) / 100)
|
|
||||||
)
|
|
||||||
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
||||||
|
|
||||||
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
|
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
|
||||||
@ -284,10 +334,14 @@ def train_model(
|
|||||||
run_cmd += ' --v_parameterization'
|
run_cmd += ' --v_parameterization'
|
||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
run_cmd += ' --train_text_encoder'
|
run_cmd += ' --train_text_encoder'
|
||||||
|
if use_8bit_adam:
|
||||||
|
run_cmd += f' --use_8bit_adam'
|
||||||
|
if xformers:
|
||||||
|
run_cmd += f' --xformers'
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||||
)
|
)
|
||||||
run_cmd += f' --in_json={train_dir}/meta_lat.json'
|
run_cmd += f' --in_json={train_dir}/{latent_metadata_filename}'
|
||||||
run_cmd += f' --train_data_dir={image_folder}'
|
run_cmd += f' --train_data_dir={image_folder}'
|
||||||
run_cmd += f' --output_dir={output_dir}'
|
run_cmd += f' --output_dir={output_dir}'
|
||||||
if not logging_dir == '':
|
if not logging_dir == '':
|
||||||
@ -298,8 +352,6 @@ def train_model(
|
|||||||
run_cmd += f' --lr_scheduler={lr_scheduler}'
|
run_cmd += f' --lr_scheduler={lr_scheduler}'
|
||||||
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
|
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
|
||||||
run_cmd += f' --max_train_steps={max_train_steps}'
|
run_cmd += f' --max_train_steps={max_train_steps}'
|
||||||
run_cmd += f' --use_8bit_adam'
|
|
||||||
run_cmd += f' --xformers'
|
|
||||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
run_cmd += f' --mixed_precision={mixed_precision}'
|
||||||
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
|
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
|
||||||
run_cmd += f' --seed={seed}'
|
run_cmd += f' --seed={seed}'
|
||||||
@ -389,9 +441,9 @@ def UI(username, password):
|
|||||||
interface = gr.Blocks(css=css)
|
interface = gr.Blocks(css=css)
|
||||||
|
|
||||||
with interface:
|
with interface:
|
||||||
with gr.Tab("Finetune"):
|
with gr.Tab('Finetune'):
|
||||||
finetune_tab()
|
finetune_tab()
|
||||||
with gr.Tab("Utilities"):
|
with gr.Tab('Utilities'):
|
||||||
utilities_tab(enable_dreambooth_tab=False)
|
utilities_tab(enable_dreambooth_tab=False)
|
||||||
|
|
||||||
# Show the interface
|
# Show the interface
|
||||||
@ -400,12 +452,11 @@ def UI(username, password):
|
|||||||
else:
|
else:
|
||||||
interface.launch()
|
interface.launch()
|
||||||
|
|
||||||
|
|
||||||
def finetune_tab():
|
def finetune_tab():
|
||||||
dummy_ft_true = gr.Label(value=True, visible=False)
|
dummy_ft_true = gr.Label(value=True, visible=False)
|
||||||
dummy_ft_false = gr.Label(value=False, visible=False)
|
dummy_ft_false = gr.Label(value=False, visible=False)
|
||||||
gr.Markdown(
|
gr.Markdown('Train a custom model using kohya finetune python code...')
|
||||||
'Train a custom model using kohya finetune python code...'
|
|
||||||
)
|
|
||||||
with gr.Accordion('Configuration file', open=False):
|
with gr.Accordion('Configuration file', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
button_open_config = gr.Button(
|
button_open_config = gr.Button(
|
||||||
@ -496,9 +547,7 @@ def finetune_tab():
|
|||||||
train_dir_folder = gr.Button(
|
train_dir_folder = gr.Button(
|
||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
train_dir_folder.click(
|
train_dir_folder.click(get_folder_path, outputs=train_dir_input)
|
||||||
get_folder_path, outputs=train_dir_input
|
|
||||||
)
|
|
||||||
|
|
||||||
image_folder_input = gr.Textbox(
|
image_folder_input = gr.Textbox(
|
||||||
label='Training Image folder',
|
label='Training Image folder',
|
||||||
@ -547,11 +596,31 @@ def finetune_tab():
|
|||||||
inputs=[output_dir_input],
|
inputs=[output_dir_input],
|
||||||
outputs=[output_dir_input],
|
outputs=[output_dir_input],
|
||||||
)
|
)
|
||||||
|
with gr.Tab('Dataset preparation'):
|
||||||
|
with gr.Row():
|
||||||
|
max_resolution_input = gr.Textbox(
|
||||||
|
label='Resolution (width,height)', value='512,512'
|
||||||
|
)
|
||||||
|
min_bucket_reso = gr.Textbox(
|
||||||
|
label='Min bucket resolution', value='256'
|
||||||
|
)
|
||||||
|
max_bucket_reso = gr.Textbox(
|
||||||
|
label='Max bucket resolution', value='1024'
|
||||||
|
)
|
||||||
|
batch_size = gr.Textbox(label='Batch size', value='1')
|
||||||
|
with gr.Accordion('Advanced parameters', open=False):
|
||||||
|
with gr.Row():
|
||||||
|
caption_metadata_filename = gr.Textbox(
|
||||||
|
label='Caption metadata filename', value='meta_cap.json'
|
||||||
|
)
|
||||||
|
latent_metadata_filename = gr.Textbox(
|
||||||
|
label='Latent metadata filename', value='meta_lat.json'
|
||||||
|
)
|
||||||
|
full_path = gr.Checkbox(label='Use full path', value=True)
|
||||||
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
with gr.Tab('Training parameters'):
|
with gr.Tab('Training parameters'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
learning_rate_input = gr.Textbox(
|
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
||||||
label='Learning rate', value=1e-6
|
|
||||||
)
|
|
||||||
lr_scheduler_input = gr.Dropdown(
|
lr_scheduler_input = gr.Dropdown(
|
||||||
label='LR Scheduler',
|
label='LR Scheduler',
|
||||||
choices=[
|
choices=[
|
||||||
@ -606,11 +675,7 @@ def finetune_tab():
|
|||||||
label='Number of CPU threads per process',
|
label='Number of CPU threads per process',
|
||||||
value=os.cpu_count(),
|
value=os.cpu_count(),
|
||||||
)
|
)
|
||||||
with gr.Row():
|
|
||||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
seed_input = gr.Textbox(label='Seed', value=1234)
|
||||||
max_resolution_input = gr.Textbox(
|
|
||||||
label='Max resolution', value='512,512'
|
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
caption_extention_input = gr.Textbox(
|
caption_extention_input = gr.Textbox(
|
||||||
label='Caption Extension',
|
label='Caption Extension',
|
||||||
@ -619,168 +684,74 @@ def finetune_tab():
|
|||||||
train_text_encoder_input = gr.Checkbox(
|
train_text_encoder_input = gr.Checkbox(
|
||||||
label='Train text encoder', value=True
|
label='Train text encoder', value=True
|
||||||
)
|
)
|
||||||
|
with gr.Accordion('Advanced parameters', open=False):
|
||||||
|
with gr.Row():
|
||||||
|
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||||
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
create_caption = gr.Checkbox(
|
create_caption = gr.Checkbox(
|
||||||
label='Generate caption database', value=True
|
label='Generate caption metadata', value=True
|
||||||
)
|
)
|
||||||
create_buckets = gr.Checkbox(
|
create_buckets = gr.Checkbox(
|
||||||
label='Generate image buckets', value=True
|
label='Generate image buckets metadata', value=True
|
||||||
)
|
)
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
button_run.click(
|
settings_list = [
|
||||||
train_model,
|
pretrained_model_name_or_path_input,
|
||||||
inputs=[
|
v2_input,
|
||||||
create_caption,
|
v_parameterization_input,
|
||||||
create_buckets,
|
train_dir_input,
|
||||||
pretrained_model_name_or_path_input,
|
image_folder_input,
|
||||||
v2_input,
|
output_dir_input,
|
||||||
v_parameterization_input,
|
logging_dir_input,
|
||||||
train_dir_input,
|
max_resolution_input,
|
||||||
image_folder_input,
|
min_bucket_reso,
|
||||||
output_dir_input,
|
max_bucket_reso,
|
||||||
logging_dir_input,
|
batch_size,
|
||||||
max_resolution_input,
|
flip_aug,
|
||||||
learning_rate_input,
|
caption_metadata_filename,
|
||||||
lr_scheduler_input,
|
latent_metadata_filename,
|
||||||
lr_warmup_input,
|
full_path,
|
||||||
dataset_repeats_input,
|
learning_rate_input,
|
||||||
train_batch_size_input,
|
lr_scheduler_input,
|
||||||
epoch_input,
|
lr_warmup_input,
|
||||||
save_every_n_epochs_input,
|
dataset_repeats_input,
|
||||||
mixed_precision_input,
|
train_batch_size_input,
|
||||||
save_precision_input,
|
epoch_input,
|
||||||
seed_input,
|
save_every_n_epochs_input,
|
||||||
num_cpu_threads_per_process_input,
|
mixed_precision_input,
|
||||||
train_text_encoder_input,
|
save_precision_input,
|
||||||
save_model_as_dropdown,
|
seed_input,
|
||||||
caption_extention_input,
|
num_cpu_threads_per_process_input,
|
||||||
],
|
train_text_encoder_input,
|
||||||
)
|
create_caption,
|
||||||
|
create_buckets,
|
||||||
|
save_model_as_dropdown,
|
||||||
|
caption_extention_input,
|
||||||
|
use_8bit_adam,
|
||||||
|
xformers,
|
||||||
|
]
|
||||||
|
|
||||||
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_config_file,
|
open_config_file,
|
||||||
inputs=[
|
inputs=[config_file_name] + settings_list,
|
||||||
config_file_name,
|
outputs=[config_file_name] + settings_list,
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
train_dir_input,
|
|
||||||
image_folder_input,
|
|
||||||
output_dir_input,
|
|
||||||
logging_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
dataset_repeats_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
train_text_encoder_input,
|
|
||||||
create_buckets,
|
|
||||||
create_caption,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
caption_extention_input,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
config_file_name,
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
train_dir_input,
|
|
||||||
image_folder_input,
|
|
||||||
output_dir_input,
|
|
||||||
logging_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
dataset_repeats_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
train_text_encoder_input,
|
|
||||||
create_buckets,
|
|
||||||
create_caption,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
caption_extention_input,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
button_save_config.click(
|
button_save_config.click(
|
||||||
save_configuration,
|
save_configuration,
|
||||||
inputs=[
|
inputs=[dummy_ft_false, config_file_name] + settings_list,
|
||||||
dummy_ft_false,
|
|
||||||
config_file_name,
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
train_dir_input,
|
|
||||||
image_folder_input,
|
|
||||||
output_dir_input,
|
|
||||||
logging_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
dataset_repeats_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
train_text_encoder_input,
|
|
||||||
create_buckets,
|
|
||||||
create_caption,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
caption_extention_input,
|
|
||||||
],
|
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
|
|
||||||
button_save_as_config.click(
|
button_save_as_config.click(
|
||||||
save_configuration,
|
save_configuration,
|
||||||
inputs=[
|
inputs=[dummy_ft_true, config_file_name] + settings_list,
|
||||||
dummy_ft_true,
|
|
||||||
config_file_name,
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
train_dir_input,
|
|
||||||
image_folder_input,
|
|
||||||
output_dir_input,
|
|
||||||
logging_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
dataset_repeats_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
train_text_encoder_input,
|
|
||||||
create_buckets,
|
|
||||||
create_caption,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
caption_extention_input,
|
|
||||||
],
|
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
73
lora_gui.py
73
lora_gui.py
@ -17,7 +17,7 @@ from library.common_gui import (
|
|||||||
get_file_path,
|
get_file_path,
|
||||||
get_any_file_path,
|
get_any_file_path,
|
||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
color_aug_changed
|
color_aug_changed,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -65,7 +65,13 @@ def save_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
text_encoder_lr,
|
||||||
|
unet_lr,
|
||||||
|
network_dim,
|
||||||
|
lora_network_weights,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -164,7 +170,13 @@ def open_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
text_encoder_lr,
|
||||||
|
unet_lr,
|
||||||
|
network_dim,
|
||||||
|
lora_network_weights,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
):
|
):
|
||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
@ -257,7 +269,13 @@ def train_model(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
text_encoder_lr,
|
||||||
|
unet_lr,
|
||||||
|
network_dim,
|
||||||
|
lora_network_weights,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
def save_inference_file(output_dir, v2, v_parameterization):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -296,11 +314,15 @@ def train_model(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# If string is empty set string to 0.
|
# If string is empty set string to 0.
|
||||||
if text_encoder_lr == '': text_encoder_lr = 0
|
if text_encoder_lr == '':
|
||||||
if unet_lr == '': unet_lr = 0
|
text_encoder_lr = 0
|
||||||
|
if unet_lr == '':
|
||||||
|
unet_lr = 0
|
||||||
|
|
||||||
if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
|
if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
|
||||||
msgbox('At least one Learning Rate value for "Text encoder" or "Unet" need to be provided')
|
msgbox(
|
||||||
|
'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get a list of all subfolders in train_data_dir
|
# Get a list of all subfolders in train_data_dir
|
||||||
@ -446,7 +468,6 @@ def train_model(
|
|||||||
if not lora_network_weights == '':
|
if not lora_network_weights == '':
|
||||||
run_cmd += f' --network_weights={lora_network_weights}'
|
run_cmd += f' --network_weights={lora_network_weights}'
|
||||||
|
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
@ -552,7 +573,9 @@ def lora_tab(
|
|||||||
):
|
):
|
||||||
dummy_db_true = gr.Label(value=True, visible=False)
|
dummy_db_true = gr.Label(value=True, visible=False)
|
||||||
dummy_db_false = gr.Label(value=False, visible=False)
|
dummy_db_false = gr.Label(value=False, visible=False)
|
||||||
gr.Markdown('Train a custom model using kohya train network LoRA python code...')
|
gr.Markdown(
|
||||||
|
'Train a custom model using kohya train network LoRA python code...'
|
||||||
|
)
|
||||||
with gr.Accordion('Configuration file', open=False):
|
with gr.Accordion('Configuration file', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
button_open_config = gr.Button('Open 📂', elem_id='open_folder')
|
button_open_config = gr.Button('Open 📂', elem_id='open_folder')
|
||||||
@ -729,8 +752,14 @@ def lora_tab(
|
|||||||
)
|
)
|
||||||
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_encoder_lr = gr.Textbox(label='Text Encoder learning rate', value=1e-6, placeholder='Optional')
|
text_encoder_lr = gr.Textbox(
|
||||||
unet_lr = gr.Textbox(label='Unet learning rate', value=1e-4, placeholder='Optional')
|
label='Text Encoder learning rate',
|
||||||
|
value=1e-6,
|
||||||
|
placeholder='Optional',
|
||||||
|
)
|
||||||
|
unet_lr = gr.Textbox(
|
||||||
|
label='Unet learning rate', value=1e-4, placeholder='Optional'
|
||||||
|
)
|
||||||
# network_train = gr.Dropdown(
|
# network_train = gr.Dropdown(
|
||||||
# label='Network to train',
|
# label='Network to train',
|
||||||
# choices=[
|
# choices=[
|
||||||
@ -747,7 +776,7 @@ def lora_tab(
|
|||||||
label='Network Dimension',
|
label='Network Dimension',
|
||||||
value=4,
|
value=4,
|
||||||
step=1,
|
step=1,
|
||||||
interactive=True
|
interactive=True,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_batch_size_input = gr.Slider(
|
train_batch_size_input = gr.Slider(
|
||||||
@ -837,10 +866,12 @@ def lora_tab(
|
|||||||
color_aug = gr.Checkbox(
|
color_aug = gr.Checkbox(
|
||||||
label='Color augmentation', value=False
|
label='Color augmentation', value=False
|
||||||
)
|
)
|
||||||
flip_aug = gr.Checkbox(
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
label='Flip augmentation', value=False
|
color_aug.change(
|
||||||
|
color_aug_changed,
|
||||||
|
inputs=[color_aug],
|
||||||
|
outputs=[cache_latent_input],
|
||||||
)
|
)
|
||||||
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
resume = gr.Textbox(
|
resume = gr.Textbox(
|
||||||
label='Resume from saved training state',
|
label='Resume from saved training state',
|
||||||
@ -852,7 +883,9 @@ def lora_tab(
|
|||||||
label='Prior loss weight', value=1.0
|
label='Prior loss weight', value=1.0
|
||||||
)
|
)
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown('This section provide Dreambooth tools to help setup your dataset...')
|
gr.Markdown(
|
||||||
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
|
)
|
||||||
gradio_dreambooth_folder_creation_tab(
|
gradio_dreambooth_folder_creation_tab(
|
||||||
train_data_dir_input=train_data_dir_input,
|
train_data_dir_input=train_data_dir_input,
|
||||||
reg_data_dir_input=reg_data_dir_input,
|
reg_data_dir_input=reg_data_dir_input,
|
||||||
@ -895,7 +928,13 @@ def lora_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
prior_loss_weight,
|
||||||
|
text_encoder_lr,
|
||||||
|
unet_lr,
|
||||||
|
network_dim,
|
||||||
|
lora_network_weights,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
Loading…
Reference in New Issue
Block a user