From 1d460a09fdb4ad5f1f7122b463a3b76382c308c5 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 1 Jan 2023 22:43:44 -0500 Subject: [PATCH] add support for color and flip augmentation to "Dreambooth LoRA" --- README.md | 2 ++ dreambooth_gui.py | 34 +++++++++++++++++++++++++--------- library/common_gui.py | 9 +++++++++ lora_gui.py | 24 ++++++++++++++++++++---- 4 files changed, 56 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index b0e1a56..78f8c1e 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i ## Change history +* 2023/01/01 (v19.2) update: + - add support for color and flip augmentation to "Dreambooth LoRA" * 2023/01/01 (v19.1) update: - merge kohys_ss upstream code updates - rework Dreambooth LoRA GUI diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 233e47d..f630d70 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -17,6 +17,7 @@ from library.common_gui import ( get_file_path, get_any_file_path, get_saveasfile_path, + color_aug_changed ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -65,7 +66,7 @@ def save_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ): original_file_path = file_path @@ -118,6 +119,8 @@ def save_configuration( 'save_state': save_state, 'resume': resume, 'prior_loss_weight': prior_loss_weight, + 'color_aug': color_aug, + 'flip_aug': flip_aug, } # Save the data to the selected file @@ -160,7 +163,7 @@ def open_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ): original_file_path = file_path @@ -214,6 +217,8 @@ def open_configuration( my_data.get('save_state', save_state), my_data.get('resume', resume), my_data.get('prior_loss_weight', prior_loss_weight), + my_data.get('color_aug', color_aug), + my_data.get('flip_aug', flip_aug), ) @@ -249,7 +254,7 @@ def train_model( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -377,6 +382,10 @@ def train_model( run_cmd += ' --shuffle_caption' if save_state: run_cmd += ' --save_state' + if color_aug: + run_cmd += ' --color_aug' + if flip_aug: + run_cmd += ' --flip_aug' run_cmd += ( f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' ) @@ -505,7 +514,7 @@ def UI(username, password): interface.launch(auth=(username, password)) else: interface.launch() - + def dreambooth_tab( train_data_dir_input=gr.Textbox(), @@ -762,6 +771,13 @@ def dreambooth_tab( save_state = gr.Checkbox( label='Save training state', value=False ) + color_aug = gr.Checkbox( + label='Color augmentation', value=False + ) + flip_aug = gr.Checkbox( + label='Flip augmentation', value=False + ) + color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input]) with gr.Row(): resume = gr.Textbox( label='Resume from saved training state', @@ -819,7 +835,7 @@ def dreambooth_tab( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ], outputs=[ config_file_name, @@ -854,7 +870,7 @@ def dreambooth_tab( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ], ) @@ -894,7 +910,7 @@ def dreambooth_tab( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ], outputs=[config_file_name], ) @@ -935,7 +951,7 @@ def dreambooth_tab( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ], outputs=[config_file_name], ) @@ -974,7 +990,7 @@ def dreambooth_tab( shuffle_caption, save_state, resume, - prior_loss_weight, + prior_loss_weight, color_aug, flip_aug ], ) diff --git a/library/common_gui.py b/library/common_gui.py index ae1e647..c30c0d3 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -1,5 +1,7 @@ from tkinter import filedialog, Tk import os +import gradio as gr +from easygui import msgbox def get_file_path(file_path='', defaultextension='.json'): @@ -107,3 +109,10 @@ def add_pre_postfix( f.seek(0, 0) f.write(f'{prefix}{content}{postfix}') f.close() + +def color_aug_changed(color_aug): + if color_aug: + msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...') + return gr.Checkbox.update(value=False, interactive=False) + else: + return gr.Checkbox.update(value=True, interactive=True) \ No newline at end of file diff --git a/lora_gui.py b/lora_gui.py index e8bed70..ef8a8cc 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -17,6 +17,7 @@ from library.common_gui import ( get_file_path, get_any_file_path, get_saveasfile_path, + color_aug_changed ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -64,7 +65,7 @@ def save_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug ): original_file_path = file_path @@ -120,6 +121,8 @@ def save_configuration( 'unet_lr': unet_lr, 'network_dim': network_dim, 'lora_network_weights': lora_network_weights, + 'color_aug': color_aug, + 'flip_aug': flip_aug, } # Save the data to the selected file @@ -161,7 +164,7 @@ def open_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug ): original_file_path = file_path @@ -218,6 +221,8 @@ def open_configuration( my_data.get('unet_lr', unet_lr), my_data.get('network_dim', network_dim), my_data.get('lora_network_weights', lora_network_weights), + my_data.get('color_aug', color_aug), + my_data.get('flip_aug', flip_aug), ) @@ -252,7 +257,7 @@ def train_model( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -388,6 +393,10 @@ def train_model( run_cmd += ' --shuffle_caption' if save_state: run_cmd += ' --save_state' + if color_aug: + run_cmd += ' --color_aug' + if flip_aug: + run_cmd += ' --flip_aug' run_cmd += ( f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' ) @@ -825,6 +834,13 @@ def lora_tab( save_state = gr.Checkbox( label='Save training state', value=False ) + color_aug = gr.Checkbox( + label='Color augmentation', value=False + ) + flip_aug = gr.Checkbox( + label='Flip augmentation', value=False + ) + color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input]) with gr.Row(): resume = gr.Textbox( label='Resume from saved training state', @@ -879,7 +895,7 @@ def lora_tab( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug ] button_open_config.click(