Merge pull request #24 from bmaltais/dev
Add multiple options to fine_tune.py
This commit is contained in:
commit
0251bf3064
@ -30,6 +30,11 @@ Once you have created the LoRA network you can generate images via auto1111 by i
|
||||
|
||||
## Change history
|
||||
|
||||
* 2023/01/02 (v19.2) update:
|
||||
- Finetune, add xformers, 8bit adam, min bucket, max bucket, batch size and flip augmentation support for dataset preparation
|
||||
- Finetune, add "Dataset preparation" tab to group task specific options
|
||||
* 2023/01/01 (v19.2) update:
|
||||
- add support for color and flip augmentation to "Dreambooth LoRA"
|
||||
* 2023/01/01 (v19.1) update:
|
||||
- merge kohys_ss upstream code updates
|
||||
- rework Dreambooth LoRA GUI
|
||||
|
@ -17,6 +17,7 @@ from library.common_gui import (
|
||||
get_file_path,
|
||||
get_any_file_path,
|
||||
get_saveasfile_path,
|
||||
color_aug_changed,
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
@ -66,6 +67,8 @@ def save_configuration(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
):
|
||||
original_file_path = file_path
|
||||
|
||||
@ -118,6 +121,8 @@ def save_configuration(
|
||||
'save_state': save_state,
|
||||
'resume': resume,
|
||||
'prior_loss_weight': prior_loss_weight,
|
||||
'color_aug': color_aug,
|
||||
'flip_aug': flip_aug,
|
||||
}
|
||||
|
||||
# Save the data to the selected file
|
||||
@ -161,6 +166,8 @@ def open_configuration(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
):
|
||||
|
||||
original_file_path = file_path
|
||||
@ -214,6 +221,8 @@ def open_configuration(
|
||||
my_data.get('save_state', save_state),
|
||||
my_data.get('resume', resume),
|
||||
my_data.get('prior_loss_weight', prior_loss_weight),
|
||||
my_data.get('color_aug', color_aug),
|
||||
my_data.get('flip_aug', flip_aug),
|
||||
)
|
||||
|
||||
|
||||
@ -250,6 +259,8 @@ def train_model(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
):
|
||||
def save_inference_file(output_dir, v2, v_parameterization):
|
||||
# Copy inference model for v2 if required
|
||||
@ -377,6 +388,10 @@ def train_model(
|
||||
run_cmd += ' --shuffle_caption'
|
||||
if save_state:
|
||||
run_cmd += ' --save_state'
|
||||
if color_aug:
|
||||
run_cmd += ' --color_aug'
|
||||
if flip_aug:
|
||||
run_cmd += ' --flip_aug'
|
||||
run_cmd += (
|
||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||
)
|
||||
@ -762,6 +777,15 @@ def dreambooth_tab(
|
||||
save_state = gr.Checkbox(
|
||||
label='Save training state', value=False
|
||||
)
|
||||
color_aug = gr.Checkbox(
|
||||
label='Color augmentation', value=False
|
||||
)
|
||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
outputs=[cache_latent_input],
|
||||
)
|
||||
with gr.Row():
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
@ -773,7 +797,9 @@ def dreambooth_tab(
|
||||
label='Prior loss weight', value=1.0
|
||||
)
|
||||
with gr.Tab('Tools'):
|
||||
gr.Markdown('This section provide Dreambooth tools to help setup your dataset...')
|
||||
gr.Markdown(
|
||||
'This section provide Dreambooth tools to help setup your dataset...'
|
||||
)
|
||||
gradio_dreambooth_folder_creation_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
@ -820,6 +846,8 @@ def dreambooth_tab(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
],
|
||||
outputs=[
|
||||
config_file_name,
|
||||
@ -855,6 +883,8 @@ def dreambooth_tab(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
],
|
||||
)
|
||||
|
||||
@ -895,6 +925,8 @@ def dreambooth_tab(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
],
|
||||
outputs=[config_file_name],
|
||||
)
|
||||
@ -936,6 +968,8 @@ def dreambooth_tab(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
],
|
||||
outputs=[config_file_name],
|
||||
)
|
||||
@ -975,6 +1009,8 @@ def dreambooth_tab(
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
],
|
||||
)
|
||||
|
||||
|
341
finetune_gui.py
341
finetune_gui.py
@ -31,6 +31,13 @@ def save_configuration(
|
||||
output_dir,
|
||||
logging_dir,
|
||||
max_resolution,
|
||||
min_bucket_reso,
|
||||
max_bucket_reso,
|
||||
batch_size,
|
||||
flip_aug,
|
||||
caption_metadata_filename,
|
||||
latent_metadata_filename,
|
||||
full_path,
|
||||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
@ -43,10 +50,12 @@ def save_configuration(
|
||||
seed,
|
||||
num_cpu_threads_per_process,
|
||||
train_text_encoder,
|
||||
create_buckets,
|
||||
create_caption,
|
||||
create_buckets,
|
||||
save_model_as,
|
||||
caption_extension,
|
||||
use_8bit_adam,
|
||||
xformers,
|
||||
):
|
||||
original_file_path = file_path
|
||||
|
||||
@ -75,6 +84,13 @@ def save_configuration(
|
||||
'output_dir': output_dir,
|
||||
'logging_dir': logging_dir,
|
||||
'max_resolution': max_resolution,
|
||||
'min_bucket_reso': min_bucket_reso,
|
||||
'max_bucket_reso': max_bucket_reso,
|
||||
'batch_size': batch_size,
|
||||
'flip_aug': flip_aug,
|
||||
'caption_metadata_filename': caption_metadata_filename,
|
||||
'latent_metadata_filename': latent_metadata_filename,
|
||||
'full_path': full_path,
|
||||
'learning_rate': learning_rate,
|
||||
'lr_scheduler': lr_scheduler,
|
||||
'lr_warmup': lr_warmup,
|
||||
@ -91,6 +107,8 @@ def save_configuration(
|
||||
'create_caption': create_caption,
|
||||
'save_model_as': save_model_as,
|
||||
'caption_extension': caption_extension,
|
||||
'use_8bit_adam': use_8bit_adam,
|
||||
'xformers': xformers,
|
||||
}
|
||||
|
||||
# Save the data to the selected file
|
||||
@ -110,6 +128,13 @@ def open_config_file(
|
||||
output_dir,
|
||||
logging_dir,
|
||||
max_resolution,
|
||||
min_bucket_reso,
|
||||
max_bucket_reso,
|
||||
batch_size,
|
||||
flip_aug,
|
||||
caption_metadata_filename,
|
||||
latent_metadata_filename,
|
||||
full_path,
|
||||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
@ -122,10 +147,12 @@ def open_config_file(
|
||||
seed,
|
||||
num_cpu_threads_per_process,
|
||||
train_text_encoder,
|
||||
create_buckets,
|
||||
create_caption,
|
||||
create_buckets,
|
||||
save_model_as,
|
||||
caption_extension,
|
||||
use_8bit_adam,
|
||||
xformers,
|
||||
):
|
||||
original_file_path = file_path
|
||||
file_path = get_file_path(file_path)
|
||||
@ -152,6 +179,13 @@ def open_config_file(
|
||||
my_data.get('output_dir', output_dir),
|
||||
my_data.get('logging_dir', logging_dir),
|
||||
my_data.get('max_resolution', max_resolution),
|
||||
my_data.get('min_bucket_reso', min_bucket_reso),
|
||||
my_data.get('max_bucket_reso', max_bucket_reso),
|
||||
my_data.get('batch_size', batch_size),
|
||||
my_data.get('flip_aug', flip_aug),
|
||||
my_data.get('caption_metadata_filename', caption_metadata_filename),
|
||||
my_data.get('latent_metadata_filename', latent_metadata_filename),
|
||||
my_data.get('full_path', full_path),
|
||||
my_data.get('learning_rate', learning_rate),
|
||||
my_data.get('lr_scheduler', lr_scheduler),
|
||||
my_data.get('lr_warmup', lr_warmup),
|
||||
@ -170,12 +204,12 @@ def open_config_file(
|
||||
my_data.get('create_caption', create_caption),
|
||||
my_data.get('save_model_as', save_model_as),
|
||||
my_data.get('caption_extension', caption_extension),
|
||||
my_data.get('use_8bit_adam', use_8bit_adam),
|
||||
my_data.get('xformers', xformers),
|
||||
)
|
||||
|
||||
|
||||
def train_model(
|
||||
generate_caption_database,
|
||||
generate_image_buckets,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
@ -184,6 +218,13 @@ def train_model(
|
||||
output_dir,
|
||||
logging_dir,
|
||||
max_resolution,
|
||||
min_bucket_reso,
|
||||
max_bucket_reso,
|
||||
batch_size,
|
||||
flip_aug,
|
||||
caption_metadata_filename,
|
||||
latent_metadata_filename,
|
||||
full_path,
|
||||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
@ -196,8 +237,12 @@ def train_model(
|
||||
seed,
|
||||
num_cpu_threads_per_process,
|
||||
train_text_encoder,
|
||||
generate_caption_database,
|
||||
generate_image_buckets,
|
||||
save_model_as,
|
||||
caption_extension,
|
||||
use_8bit_adam,
|
||||
xformers,
|
||||
):
|
||||
def save_inference_file(output_dir, v2, v_parameterization):
|
||||
# Copy inference model for v2 if required
|
||||
@ -227,8 +272,9 @@ def train_model(
|
||||
else:
|
||||
run_cmd += f' --caption_extension={caption_extension}'
|
||||
run_cmd += f' {image_folder}'
|
||||
run_cmd += f' {train_dir}/meta_cap.json'
|
||||
run_cmd += f' --full_path'
|
||||
run_cmd += f' {train_dir}/{caption_metadata_filename}'
|
||||
if full_path:
|
||||
run_cmd += f' --full_path'
|
||||
|
||||
print(run_cmd)
|
||||
|
||||
@ -237,26 +283,27 @@ def train_model(
|
||||
|
||||
# create images buckets
|
||||
if generate_image_buckets:
|
||||
command = [
|
||||
'./venv/Scripts/python.exe',
|
||||
'finetune/prepare_buckets_latents.py',
|
||||
image_folder,
|
||||
'{}/meta_cap.json'.format(train_dir),
|
||||
'{}/meta_lat.json'.format(train_dir),
|
||||
pretrained_model_name_or_path,
|
||||
'--batch_size',
|
||||
'4',
|
||||
'--max_resolution',
|
||||
max_resolution,
|
||||
'--mixed_precision',
|
||||
mixed_precision,
|
||||
'--full_path',
|
||||
]
|
||||
run_cmd = (
|
||||
f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
|
||||
)
|
||||
run_cmd += f' {image_folder}'
|
||||
run_cmd += f' {train_dir}/{caption_metadata_filename}'
|
||||
run_cmd += f' {train_dir}/{latent_metadata_filename}'
|
||||
run_cmd += f' {pretrained_model_name_or_path}'
|
||||
run_cmd += f' --batch_size={batch_size}'
|
||||
run_cmd += f' --max_resolution={max_resolution}'
|
||||
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
||||
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
||||
if flip_aug:
|
||||
run_cmd += f' --flip_aug'
|
||||
if full_path:
|
||||
run_cmd += f' --full_path'
|
||||
|
||||
print(command)
|
||||
print(run_cmd)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(command)
|
||||
subprocess.run(run_cmd)
|
||||
|
||||
image_num = len(
|
||||
[f for f in os.listdir(image_folder) if f.endswith('.npz')]
|
||||
@ -270,11 +317,14 @@ def train_model(
|
||||
max_train_steps = int(
|
||||
math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
|
||||
)
|
||||
|
||||
# Divide by two because flip augmentation create two copied of the source images
|
||||
if flip_aug:
|
||||
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
|
||||
|
||||
print(f'max_train_steps = {max_train_steps}')
|
||||
|
||||
lr_warmup_steps = round(
|
||||
float(int(lr_warmup) * int(max_train_steps) / 100)
|
||||
)
|
||||
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
||||
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
||||
|
||||
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
|
||||
@ -284,10 +334,14 @@ def train_model(
|
||||
run_cmd += ' --v_parameterization'
|
||||
if train_text_encoder:
|
||||
run_cmd += ' --train_text_encoder'
|
||||
if use_8bit_adam:
|
||||
run_cmd += f' --use_8bit_adam'
|
||||
if xformers:
|
||||
run_cmd += f' --xformers'
|
||||
run_cmd += (
|
||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||
)
|
||||
run_cmd += f' --in_json={train_dir}/meta_lat.json'
|
||||
run_cmd += f' --in_json={train_dir}/{latent_metadata_filename}'
|
||||
run_cmd += f' --train_data_dir={image_folder}'
|
||||
run_cmd += f' --output_dir={output_dir}'
|
||||
if not logging_dir == '':
|
||||
@ -298,8 +352,6 @@ def train_model(
|
||||
run_cmd += f' --lr_scheduler={lr_scheduler}'
|
||||
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
|
||||
run_cmd += f' --max_train_steps={max_train_steps}'
|
||||
run_cmd += f' --use_8bit_adam'
|
||||
run_cmd += f' --xformers'
|
||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
||||
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
|
||||
run_cmd += f' --seed={seed}'
|
||||
@ -389,9 +441,9 @@ def UI(username, password):
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab("Finetune"):
|
||||
with gr.Tab('Finetune'):
|
||||
finetune_tab()
|
||||
with gr.Tab("Utilities"):
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(enable_dreambooth_tab=False)
|
||||
|
||||
# Show the interface
|
||||
@ -400,12 +452,11 @@ def UI(username, password):
|
||||
else:
|
||||
interface.launch()
|
||||
|
||||
|
||||
def finetune_tab():
|
||||
dummy_ft_true = gr.Label(value=True, visible=False)
|
||||
dummy_ft_false = gr.Label(value=False, visible=False)
|
||||
gr.Markdown(
|
||||
'Train a custom model using kohya finetune python code...'
|
||||
)
|
||||
gr.Markdown('Train a custom model using kohya finetune python code...')
|
||||
with gr.Accordion('Configuration file', open=False):
|
||||
with gr.Row():
|
||||
button_open_config = gr.Button(
|
||||
@ -496,9 +547,7 @@ def finetune_tab():
|
||||
train_dir_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
train_dir_folder.click(
|
||||
get_folder_path, outputs=train_dir_input
|
||||
)
|
||||
train_dir_folder.click(get_folder_path, outputs=train_dir_input)
|
||||
|
||||
image_folder_input = gr.Textbox(
|
||||
label='Training Image folder',
|
||||
@ -547,11 +596,31 @@ def finetune_tab():
|
||||
inputs=[output_dir_input],
|
||||
outputs=[output_dir_input],
|
||||
)
|
||||
with gr.Tab('Dataset preparation'):
|
||||
with gr.Row():
|
||||
max_resolution_input = gr.Textbox(
|
||||
label='Resolution (width,height)', value='512,512'
|
||||
)
|
||||
min_bucket_reso = gr.Textbox(
|
||||
label='Min bucket resolution', value='256'
|
||||
)
|
||||
max_bucket_reso = gr.Textbox(
|
||||
label='Max bucket resolution', value='1024'
|
||||
)
|
||||
batch_size = gr.Textbox(label='Batch size', value='1')
|
||||
with gr.Accordion('Advanced parameters', open=False):
|
||||
with gr.Row():
|
||||
caption_metadata_filename = gr.Textbox(
|
||||
label='Caption metadata filename', value='meta_cap.json'
|
||||
)
|
||||
latent_metadata_filename = gr.Textbox(
|
||||
label='Latent metadata filename', value='meta_lat.json'
|
||||
)
|
||||
full_path = gr.Checkbox(label='Use full path', value=True)
|
||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||
with gr.Tab('Training parameters'):
|
||||
with gr.Row():
|
||||
learning_rate_input = gr.Textbox(
|
||||
label='Learning rate', value=1e-6
|
||||
)
|
||||
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
||||
lr_scheduler_input = gr.Dropdown(
|
||||
label='LR Scheduler',
|
||||
choices=[
|
||||
@ -606,11 +675,7 @@ def finetune_tab():
|
||||
label='Number of CPU threads per process',
|
||||
value=os.cpu_count(),
|
||||
)
|
||||
with gr.Row():
|
||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
||||
max_resolution_input = gr.Textbox(
|
||||
label='Max resolution', value='512,512'
|
||||
)
|
||||
with gr.Row():
|
||||
caption_extention_input = gr.Textbox(
|
||||
label='Caption Extension',
|
||||
@ -619,168 +684,74 @@ def finetune_tab():
|
||||
train_text_encoder_input = gr.Checkbox(
|
||||
label='Train text encoder', value=True
|
||||
)
|
||||
with gr.Accordion('Advanced parameters', open=False):
|
||||
with gr.Row():
|
||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
create_caption = gr.Checkbox(
|
||||
label='Generate caption database', value=True
|
||||
label='Generate caption metadata', value=True
|
||||
)
|
||||
create_buckets = gr.Checkbox(
|
||||
label='Generate image buckets', value=True
|
||||
label='Generate image buckets metadata', value=True
|
||||
)
|
||||
|
||||
button_run = gr.Button('Train model')
|
||||
|
||||
button_run.click(
|
||||
train_model,
|
||||
inputs=[
|
||||
create_caption,
|
||||
create_buckets,
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
train_dir_input,
|
||||
image_folder_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
max_resolution_input,
|
||||
learning_rate_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
dataset_repeats_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
train_text_encoder_input,
|
||||
save_model_as_dropdown,
|
||||
caption_extention_input,
|
||||
],
|
||||
)
|
||||
settings_list = [
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
train_dir_input,
|
||||
image_folder_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
max_resolution_input,
|
||||
min_bucket_reso,
|
||||
max_bucket_reso,
|
||||
batch_size,
|
||||
flip_aug,
|
||||
caption_metadata_filename,
|
||||
latent_metadata_filename,
|
||||
full_path,
|
||||
learning_rate_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
dataset_repeats_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
train_text_encoder_input,
|
||||
create_caption,
|
||||
create_buckets,
|
||||
save_model_as_dropdown,
|
||||
caption_extention_input,
|
||||
use_8bit_adam,
|
||||
xformers,
|
||||
]
|
||||
|
||||
button_run.click(train_model, inputs=settings_list)
|
||||
|
||||
button_open_config.click(
|
||||
open_config_file,
|
||||
inputs=[
|
||||
config_file_name,
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
train_dir_input,
|
||||
image_folder_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
max_resolution_input,
|
||||
learning_rate_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
dataset_repeats_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
train_text_encoder_input,
|
||||
create_buckets,
|
||||
create_caption,
|
||||
save_model_as_dropdown,
|
||||
caption_extention_input,
|
||||
],
|
||||
outputs=[
|
||||
config_file_name,
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
train_dir_input,
|
||||
image_folder_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
max_resolution_input,
|
||||
learning_rate_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
dataset_repeats_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
train_text_encoder_input,
|
||||
create_buckets,
|
||||
create_caption,
|
||||
save_model_as_dropdown,
|
||||
caption_extention_input,
|
||||
],
|
||||
inputs=[config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
)
|
||||
|
||||
button_save_config.click(
|
||||
save_configuration,
|
||||
inputs=[
|
||||
dummy_ft_false,
|
||||
config_file_name,
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
train_dir_input,
|
||||
image_folder_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
max_resolution_input,
|
||||
learning_rate_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
dataset_repeats_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
train_text_encoder_input,
|
||||
create_buckets,
|
||||
create_caption,
|
||||
save_model_as_dropdown,
|
||||
caption_extention_input,
|
||||
],
|
||||
inputs=[dummy_ft_false, config_file_name] + settings_list,
|
||||
outputs=[config_file_name],
|
||||
)
|
||||
|
||||
button_save_as_config.click(
|
||||
save_configuration,
|
||||
inputs=[
|
||||
dummy_ft_true,
|
||||
config_file_name,
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
train_dir_input,
|
||||
image_folder_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
max_resolution_input,
|
||||
learning_rate_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
dataset_repeats_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
train_text_encoder_input,
|
||||
create_buckets,
|
||||
create_caption,
|
||||
save_model_as_dropdown,
|
||||
caption_extention_input,
|
||||
],
|
||||
inputs=[dummy_ft_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name],
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
from tkinter import filedialog, Tk
|
||||
import os
|
||||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
|
||||
|
||||
def get_file_path(file_path='', defaultextension='.json'):
|
||||
@ -107,3 +109,10 @@ def add_pre_postfix(
|
||||
f.seek(0, 0)
|
||||
f.write(f'{prefix}{content}{postfix}')
|
||||
f.close()
|
||||
|
||||
def color_aug_changed(color_aug):
|
||||
if color_aug:
|
||||
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
|
||||
return gr.Checkbox.update(value=False, interactive=False)
|
||||
else:
|
||||
return gr.Checkbox.update(value=True, interactive=True)
|
91
lora_gui.py
91
lora_gui.py
@ -17,6 +17,7 @@ from library.common_gui import (
|
||||
get_file_path,
|
||||
get_any_file_path,
|
||||
get_saveasfile_path,
|
||||
color_aug_changed,
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
@ -64,7 +65,13 @@ def save_configuration(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight,
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
network_dim,
|
||||
lora_network_weights,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
):
|
||||
original_file_path = file_path
|
||||
|
||||
@ -120,6 +127,8 @@ def save_configuration(
|
||||
'unet_lr': unet_lr,
|
||||
'network_dim': network_dim,
|
||||
'lora_network_weights': lora_network_weights,
|
||||
'color_aug': color_aug,
|
||||
'flip_aug': flip_aug,
|
||||
}
|
||||
|
||||
# Save the data to the selected file
|
||||
@ -161,7 +170,13 @@ def open_configuration(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight,
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
network_dim,
|
||||
lora_network_weights,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
):
|
||||
|
||||
original_file_path = file_path
|
||||
@ -218,6 +233,8 @@ def open_configuration(
|
||||
my_data.get('unet_lr', unet_lr),
|
||||
my_data.get('network_dim', network_dim),
|
||||
my_data.get('lora_network_weights', lora_network_weights),
|
||||
my_data.get('color_aug', color_aug),
|
||||
my_data.get('flip_aug', flip_aug),
|
||||
)
|
||||
|
||||
|
||||
@ -252,7 +269,13 @@ def train_model(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight,
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
network_dim,
|
||||
lora_network_weights,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
):
|
||||
def save_inference_file(output_dir, v2, v_parameterization):
|
||||
# Copy inference model for v2 if required
|
||||
@ -289,13 +312,17 @@ def train_model(
|
||||
if output_dir == '':
|
||||
msgbox('Output folder path is missing')
|
||||
return
|
||||
|
||||
|
||||
# If string is empty set string to 0.
|
||||
if text_encoder_lr == '': text_encoder_lr = 0
|
||||
if unet_lr == '': unet_lr = 0
|
||||
|
||||
if text_encoder_lr == '':
|
||||
text_encoder_lr = 0
|
||||
if unet_lr == '':
|
||||
unet_lr = 0
|
||||
|
||||
if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
|
||||
msgbox('At least one Learning Rate value for "Text encoder" or "Unet" need to be provided')
|
||||
msgbox(
|
||||
'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
|
||||
)
|
||||
return
|
||||
|
||||
# Get a list of all subfolders in train_data_dir
|
||||
@ -388,6 +415,10 @@ def train_model(
|
||||
run_cmd += ' --shuffle_caption'
|
||||
if save_state:
|
||||
run_cmd += ' --save_state'
|
||||
if color_aug:
|
||||
run_cmd += ' --color_aug'
|
||||
if flip_aug:
|
||||
run_cmd += ' --flip_aug'
|
||||
run_cmd += (
|
||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||
)
|
||||
@ -436,7 +467,6 @@ def train_model(
|
||||
run_cmd += f' --network_dim={network_dim}'
|
||||
if not lora_network_weights == '':
|
||||
run_cmd += f' --network_weights={lora_network_weights}'
|
||||
|
||||
|
||||
print(run_cmd)
|
||||
# Run the command
|
||||
@ -543,7 +573,9 @@ def lora_tab(
|
||||
):
|
||||
dummy_db_true = gr.Label(value=True, visible=False)
|
||||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
gr.Markdown('Train a custom model using kohya train network LoRA python code...')
|
||||
gr.Markdown(
|
||||
'Train a custom model using kohya train network LoRA python code...'
|
||||
)
|
||||
with gr.Accordion('Configuration file', open=False):
|
||||
with gr.Row():
|
||||
button_open_config = gr.Button('Open 📂', elem_id='open_folder')
|
||||
@ -606,7 +638,7 @@ def lora_tab(
|
||||
],
|
||||
value='same as source model',
|
||||
)
|
||||
|
||||
|
||||
with gr.Row():
|
||||
v2_input = gr.Checkbox(label='v2', value=True)
|
||||
v_parameterization_input = gr.Checkbox(
|
||||
@ -720,8 +752,14 @@ def lora_tab(
|
||||
)
|
||||
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
||||
with gr.Row():
|
||||
text_encoder_lr = gr.Textbox(label='Text Encoder learning rate', value=1e-6, placeholder='Optional')
|
||||
unet_lr = gr.Textbox(label='Unet learning rate', value=1e-4, placeholder='Optional')
|
||||
text_encoder_lr = gr.Textbox(
|
||||
label='Text Encoder learning rate',
|
||||
value=1e-6,
|
||||
placeholder='Optional',
|
||||
)
|
||||
unet_lr = gr.Textbox(
|
||||
label='Unet learning rate', value=1e-4, placeholder='Optional'
|
||||
)
|
||||
# network_train = gr.Dropdown(
|
||||
# label='Network to train',
|
||||
# choices=[
|
||||
@ -738,7 +776,7 @@ def lora_tab(
|
||||
label='Network Dimension',
|
||||
value=4,
|
||||
step=1,
|
||||
interactive=True
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
train_batch_size_input = gr.Slider(
|
||||
@ -825,6 +863,15 @@ def lora_tab(
|
||||
save_state = gr.Checkbox(
|
||||
label='Save training state', value=False
|
||||
)
|
||||
color_aug = gr.Checkbox(
|
||||
label='Color augmentation', value=False
|
||||
)
|
||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
outputs=[cache_latent_input],
|
||||
)
|
||||
with gr.Row():
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
@ -836,7 +883,9 @@ def lora_tab(
|
||||
label='Prior loss weight', value=1.0
|
||||
)
|
||||
with gr.Tab('Tools'):
|
||||
gr.Markdown('This section provide Dreambooth tools to help setup your dataset...')
|
||||
gr.Markdown(
|
||||
'This section provide Dreambooth tools to help setup your dataset...'
|
||||
)
|
||||
gradio_dreambooth_folder_creation_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
@ -846,7 +895,7 @@ def lora_tab(
|
||||
gradio_dataset_balancing_tab()
|
||||
|
||||
button_run = gr.Button('Train model')
|
||||
|
||||
|
||||
settings_list = [
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
@ -879,7 +928,13 @@ def lora_tab(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight,
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
network_dim,
|
||||
lora_network_weights,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
@ -902,7 +957,7 @@ def lora_tab(
|
||||
|
||||
button_run.click(
|
||||
train_model,
|
||||
inputs=settings_list,
|
||||
inputs=settings_list,
|
||||
)
|
||||
|
||||
return (
|
||||
|
Loading…
x
Reference in New Issue
Block a user