add support for color and flip augmentation to "Dreambooth LoRA"

This commit is contained in:
bmaltais 2023-01-01 22:43:44 -05:00
parent bfa590b313
commit 1d460a09fd
4 changed files with 56 additions and 13 deletions

View File

@ -30,6 +30,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
## Change history ## Change history
* 2023/01/01 (v19.2) update:
- add support for color and flip augmentation to "Dreambooth LoRA"
* 2023/01/01 (v19.1) update: * 2023/01/01 (v19.1) update:
- merge kohys_ss upstream code updates - merge kohys_ss upstream code updates
- rework Dreambooth LoRA GUI - rework Dreambooth LoRA GUI

View File

@ -17,6 +17,7 @@ from library.common_gui import (
get_file_path, get_file_path,
get_any_file_path, get_any_file_path,
get_saveasfile_path, get_saveasfile_path,
color_aug_changed
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -65,7 +66,7 @@ def save_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
): ):
original_file_path = file_path original_file_path = file_path
@ -118,6 +119,8 @@ def save_configuration(
'save_state': save_state, 'save_state': save_state,
'resume': resume, 'resume': resume,
'prior_loss_weight': prior_loss_weight, 'prior_loss_weight': prior_loss_weight,
'color_aug': color_aug,
'flip_aug': flip_aug,
} }
# Save the data to the selected file # Save the data to the selected file
@ -160,7 +163,7 @@ def open_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
): ):
original_file_path = file_path original_file_path = file_path
@ -214,6 +217,8 @@ def open_configuration(
my_data.get('save_state', save_state), my_data.get('save_state', save_state),
my_data.get('resume', resume), my_data.get('resume', resume),
my_data.get('prior_loss_weight', prior_loss_weight), my_data.get('prior_loss_weight', prior_loss_weight),
my_data.get('color_aug', color_aug),
my_data.get('flip_aug', flip_aug),
) )
@ -249,7 +254,7 @@ def train_model(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
): ):
def save_inference_file(output_dir, v2, v_parameterization): def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required # Copy inference model for v2 if required
@ -377,6 +382,10 @@ def train_model(
run_cmd += ' --shuffle_caption' run_cmd += ' --shuffle_caption'
if save_state: if save_state:
run_cmd += ' --save_state' run_cmd += ' --save_state'
if color_aug:
run_cmd += ' --color_aug'
if flip_aug:
run_cmd += ' --flip_aug'
run_cmd += ( run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
) )
@ -762,6 +771,13 @@ def dreambooth_tab(
save_state = gr.Checkbox( save_state = gr.Checkbox(
label='Save training state', value=False label='Save training state', value=False
) )
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
flip_aug = gr.Checkbox(
label='Flip augmentation', value=False
)
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
with gr.Row(): with gr.Row():
resume = gr.Textbox( resume = gr.Textbox(
label='Resume from saved training state', label='Resume from saved training state',
@ -819,7 +835,7 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
], ],
outputs=[ outputs=[
config_file_name, config_file_name,
@ -854,7 +870,7 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
], ],
) )
@ -894,7 +910,7 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
], ],
outputs=[config_file_name], outputs=[config_file_name],
) )
@ -935,7 +951,7 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
], ],
outputs=[config_file_name], outputs=[config_file_name],
) )
@ -974,7 +990,7 @@ def dreambooth_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, color_aug, flip_aug
], ],
) )

View File

@ -1,5 +1,7 @@
from tkinter import filedialog, Tk from tkinter import filedialog, Tk
import os import os
import gradio as gr
from easygui import msgbox
def get_file_path(file_path='', defaultextension='.json'): def get_file_path(file_path='', defaultextension='.json'):
@ -107,3 +109,10 @@ def add_pre_postfix(
f.seek(0, 0) f.seek(0, 0)
f.write(f'{prefix}{content}{postfix}') f.write(f'{prefix}{content}{postfix}')
f.close() f.close()
def color_aug_changed(color_aug):
if color_aug:
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
return gr.Checkbox.update(value=False, interactive=False)
else:
return gr.Checkbox.update(value=True, interactive=True)

View File

@ -17,6 +17,7 @@ from library.common_gui import (
get_file_path, get_file_path,
get_any_file_path, get_any_file_path,
get_saveasfile_path, get_saveasfile_path,
color_aug_changed
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -64,7 +65,7 @@ def save_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
): ):
original_file_path = file_path original_file_path = file_path
@ -120,6 +121,8 @@ def save_configuration(
'unet_lr': unet_lr, 'unet_lr': unet_lr,
'network_dim': network_dim, 'network_dim': network_dim,
'lora_network_weights': lora_network_weights, 'lora_network_weights': lora_network_weights,
'color_aug': color_aug,
'flip_aug': flip_aug,
} }
# Save the data to the selected file # Save the data to the selected file
@ -161,7 +164,7 @@ def open_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
): ):
original_file_path = file_path original_file_path = file_path
@ -218,6 +221,8 @@ def open_configuration(
my_data.get('unet_lr', unet_lr), my_data.get('unet_lr', unet_lr),
my_data.get('network_dim', network_dim), my_data.get('network_dim', network_dim),
my_data.get('lora_network_weights', lora_network_weights), my_data.get('lora_network_weights', lora_network_weights),
my_data.get('color_aug', color_aug),
my_data.get('flip_aug', flip_aug),
) )
@ -252,7 +257,7 @@ def train_model(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
): ):
def save_inference_file(output_dir, v2, v_parameterization): def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required # Copy inference model for v2 if required
@ -388,6 +393,10 @@ def train_model(
run_cmd += ' --shuffle_caption' run_cmd += ' --shuffle_caption'
if save_state: if save_state:
run_cmd += ' --save_state' run_cmd += ' --save_state'
if color_aug:
run_cmd += ' --color_aug'
if flip_aug:
run_cmd += ' --flip_aug'
run_cmd += ( run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
) )
@ -825,6 +834,13 @@ def lora_tab(
save_state = gr.Checkbox( save_state = gr.Checkbox(
label='Save training state', value=False label='Save training state', value=False
) )
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
flip_aug = gr.Checkbox(
label='Flip augmentation', value=False
)
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
with gr.Row(): with gr.Row():
resume = gr.Textbox( resume = gr.Textbox(
label='Resume from saved training state', label='Resume from saved training state',
@ -879,7 +895,7 @@ def lora_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
] ]
button_open_config.click( button_open_config.click(