add support for color and flip augmentation to "Dreambooth LoRA"

This commit is contained in:
bmaltais 2023-01-01 22:43:44 -05:00
parent bfa590b313
commit 1d460a09fd
4 changed files with 56 additions and 13 deletions

View File

@ -30,6 +30,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
## Change history
* 2023/01/01 (v19.2) update:
- add support for color and flip augmentation to "Dreambooth LoRA"
* 2023/01/01 (v19.1) update:
- merge kohys_ss upstream code updates
- rework Dreambooth LoRA GUI

View File

@ -17,6 +17,7 @@ from library.common_gui import (
get_file_path,
get_any_file_path,
get_saveasfile_path,
color_aug_changed
)
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
@ -65,7 +66,7 @@ def save_configuration(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
):
original_file_path = file_path
@ -118,6 +119,8 @@ def save_configuration(
'save_state': save_state,
'resume': resume,
'prior_loss_weight': prior_loss_weight,
'color_aug': color_aug,
'flip_aug': flip_aug,
}
# Save the data to the selected file
@ -160,7 +163,7 @@ def open_configuration(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
):
original_file_path = file_path
@ -214,6 +217,8 @@ def open_configuration(
my_data.get('save_state', save_state),
my_data.get('resume', resume),
my_data.get('prior_loss_weight', prior_loss_weight),
my_data.get('color_aug', color_aug),
my_data.get('flip_aug', flip_aug),
)
@ -249,7 +254,7 @@ def train_model(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
@ -377,6 +382,10 @@ def train_model(
run_cmd += ' --shuffle_caption'
if save_state:
run_cmd += ' --save_state'
if color_aug:
run_cmd += ' --color_aug'
if flip_aug:
run_cmd += ' --flip_aug'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
)
@ -505,7 +514,7 @@ def UI(username, password):
interface.launch(auth=(username, password))
else:
interface.launch()
def dreambooth_tab(
train_data_dir_input=gr.Textbox(),
@ -762,6 +771,13 @@ def dreambooth_tab(
save_state = gr.Checkbox(
label='Save training state', value=False
)
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
flip_aug = gr.Checkbox(
label='Flip augmentation', value=False
)
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
with gr.Row():
resume = gr.Textbox(
label='Resume from saved training state',
@ -819,7 +835,7 @@ def dreambooth_tab(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
],
outputs=[
config_file_name,
@ -854,7 +870,7 @@ def dreambooth_tab(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
],
)
@ -894,7 +910,7 @@ def dreambooth_tab(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
],
outputs=[config_file_name],
)
@ -935,7 +951,7 @@ def dreambooth_tab(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
],
outputs=[config_file_name],
)
@ -974,7 +990,7 @@ def dreambooth_tab(
shuffle_caption,
save_state,
resume,
prior_loss_weight,
prior_loss_weight, color_aug, flip_aug
],
)

View File

@ -1,5 +1,7 @@
from tkinter import filedialog, Tk
import os
import gradio as gr
from easygui import msgbox
def get_file_path(file_path='', defaultextension='.json'):
@ -107,3 +109,10 @@ def add_pre_postfix(
f.seek(0, 0)
f.write(f'{prefix}{content}{postfix}')
f.close()
def color_aug_changed(color_aug):
if color_aug:
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
return gr.Checkbox.update(value=False, interactive=False)
else:
return gr.Checkbox.update(value=True, interactive=True)

View File

@ -17,6 +17,7 @@ from library.common_gui import (
get_file_path,
get_any_file_path,
get_saveasfile_path,
color_aug_changed
)
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
@ -64,7 +65,7 @@ def save_configuration(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
):
original_file_path = file_path
@ -120,6 +121,8 @@ def save_configuration(
'unet_lr': unet_lr,
'network_dim': network_dim,
'lora_network_weights': lora_network_weights,
'color_aug': color_aug,
'flip_aug': flip_aug,
}
# Save the data to the selected file
@ -161,7 +164,7 @@ def open_configuration(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
):
original_file_path = file_path
@ -218,6 +221,8 @@ def open_configuration(
my_data.get('unet_lr', unet_lr),
my_data.get('network_dim', network_dim),
my_data.get('lora_network_weights', lora_network_weights),
my_data.get('color_aug', color_aug),
my_data.get('flip_aug', flip_aug),
)
@ -252,7 +257,7 @@ def train_model(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
@ -388,6 +393,10 @@ def train_model(
run_cmd += ' --shuffle_caption'
if save_state:
run_cmd += ' --save_state'
if color_aug:
run_cmd += ' --color_aug'
if flip_aug:
run_cmd += ' --flip_aug'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
)
@ -825,6 +834,13 @@ def lora_tab(
save_state = gr.Checkbox(
label='Save training state', value=False
)
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
flip_aug = gr.Checkbox(
label='Flip augmentation', value=False
)
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
with gr.Row():
resume = gr.Textbox(
label='Resume from saved training state',
@ -879,7 +895,7 @@ def lora_tab(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
]
button_open_config.click(