diff --git a/README.md b/README.md index b0e1a56..1673f24 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,11 @@ Once you have created the LoRA network you can generate images via auto1111 by i ## Change history +* 2023/01/02 (v19.2) update: + - Finetune, add xformers, 8bit adam, min bucket, max bucket, batch size and flip augmentation support for dataset preparation + - Finetune, add "Dataset preparation" tab to group task specific options +* 2023/01/01 (v19.2) update: + - add support for color and flip augmentation to "Dreambooth LoRA" * 2023/01/01 (v19.1) update: - merge kohys_ss upstream code updates - rework Dreambooth LoRA GUI diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 233e47d..0bfed1e 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -17,6 +17,7 @@ from library.common_gui import ( get_file_path, get_any_file_path, get_saveasfile_path, + color_aug_changed, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -66,6 +67,8 @@ def save_configuration( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ): original_file_path = file_path @@ -118,6 +121,8 @@ def save_configuration( 'save_state': save_state, 'resume': resume, 'prior_loss_weight': prior_loss_weight, + 'color_aug': color_aug, + 'flip_aug': flip_aug, } # Save the data to the selected file @@ -161,6 +166,8 @@ def open_configuration( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ): original_file_path = file_path @@ -214,6 +221,8 @@ def open_configuration( my_data.get('save_state', save_state), my_data.get('resume', resume), my_data.get('prior_loss_weight', prior_loss_weight), + my_data.get('color_aug', color_aug), + my_data.get('flip_aug', flip_aug), ) @@ -250,6 +259,8 @@ def train_model( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -377,6 +388,10 @@ def train_model( run_cmd += ' --shuffle_caption' if save_state: run_cmd += ' --save_state' + if color_aug: + run_cmd += ' --color_aug' + if flip_aug: + run_cmd += ' --flip_aug' run_cmd += ( f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' ) @@ -762,6 +777,15 @@ def dreambooth_tab( save_state = gr.Checkbox( label='Save training state', value=False ) + color_aug = gr.Checkbox( + label='Color augmentation', value=False + ) + flip_aug = gr.Checkbox(label='Flip augmentation', value=False) + color_aug.change( + color_aug_changed, + inputs=[color_aug], + outputs=[cache_latent_input], + ) with gr.Row(): resume = gr.Textbox( label='Resume from saved training state', @@ -773,7 +797,9 @@ def dreambooth_tab( label='Prior loss weight', value=1.0 ) with gr.Tab('Tools'): - gr.Markdown('This section provide Dreambooth tools to help setup your dataset...') + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' + ) gradio_dreambooth_folder_creation_tab( train_data_dir_input=train_data_dir_input, reg_data_dir_input=reg_data_dir_input, @@ -820,6 +846,8 @@ def dreambooth_tab( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ], outputs=[ config_file_name, @@ -855,6 +883,8 @@ def dreambooth_tab( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ], ) @@ -895,6 +925,8 @@ def dreambooth_tab( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ], outputs=[config_file_name], ) @@ -936,6 +968,8 @@ def dreambooth_tab( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ], outputs=[config_file_name], ) @@ -975,6 +1009,8 @@ def dreambooth_tab( save_state, resume, prior_loss_weight, + color_aug, + flip_aug, ], ) diff --git a/finetune_gui.py b/finetune_gui.py index 3ae8685..67ad39b 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -31,6 +31,13 @@ def save_configuration( output_dir, logging_dir, max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, learning_rate, lr_scheduler, lr_warmup, @@ -43,10 +50,12 @@ def save_configuration( seed, num_cpu_threads_per_process, train_text_encoder, - create_buckets, create_caption, + create_buckets, save_model_as, caption_extension, + use_8bit_adam, + xformers, ): original_file_path = file_path @@ -75,6 +84,13 @@ def save_configuration( 'output_dir': output_dir, 'logging_dir': logging_dir, 'max_resolution': max_resolution, + 'min_bucket_reso': min_bucket_reso, + 'max_bucket_reso': max_bucket_reso, + 'batch_size': batch_size, + 'flip_aug': flip_aug, + 'caption_metadata_filename': caption_metadata_filename, + 'latent_metadata_filename': latent_metadata_filename, + 'full_path': full_path, 'learning_rate': learning_rate, 'lr_scheduler': lr_scheduler, 'lr_warmup': lr_warmup, @@ -91,6 +107,8 @@ def save_configuration( 'create_caption': create_caption, 'save_model_as': save_model_as, 'caption_extension': caption_extension, + 'use_8bit_adam': use_8bit_adam, + 'xformers': xformers, } # Save the data to the selected file @@ -110,6 +128,13 @@ def open_config_file( output_dir, logging_dir, max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, learning_rate, lr_scheduler, lr_warmup, @@ -122,10 +147,12 @@ def open_config_file( seed, num_cpu_threads_per_process, train_text_encoder, - create_buckets, create_caption, + create_buckets, save_model_as, caption_extension, + use_8bit_adam, + xformers, ): original_file_path = file_path file_path = get_file_path(file_path) @@ -152,6 +179,13 @@ def open_config_file( my_data.get('output_dir', output_dir), my_data.get('logging_dir', logging_dir), my_data.get('max_resolution', max_resolution), + my_data.get('min_bucket_reso', min_bucket_reso), + my_data.get('max_bucket_reso', max_bucket_reso), + my_data.get('batch_size', batch_size), + my_data.get('flip_aug', flip_aug), + my_data.get('caption_metadata_filename', caption_metadata_filename), + my_data.get('latent_metadata_filename', latent_metadata_filename), + my_data.get('full_path', full_path), my_data.get('learning_rate', learning_rate), my_data.get('lr_scheduler', lr_scheduler), my_data.get('lr_warmup', lr_warmup), @@ -170,12 +204,12 @@ def open_config_file( my_data.get('create_caption', create_caption), my_data.get('save_model_as', save_model_as), my_data.get('caption_extension', caption_extension), + my_data.get('use_8bit_adam', use_8bit_adam), + my_data.get('xformers', xformers), ) def train_model( - generate_caption_database, - generate_image_buckets, pretrained_model_name_or_path, v2, v_parameterization, @@ -184,6 +218,13 @@ def train_model( output_dir, logging_dir, max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, learning_rate, lr_scheduler, lr_warmup, @@ -196,8 +237,12 @@ def train_model( seed, num_cpu_threads_per_process, train_text_encoder, + generate_caption_database, + generate_image_buckets, save_model_as, caption_extension, + use_8bit_adam, + xformers, ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -227,8 +272,9 @@ def train_model( else: run_cmd += f' --caption_extension={caption_extension}' run_cmd += f' {image_folder}' - run_cmd += f' {train_dir}/meta_cap.json' - run_cmd += f' --full_path' + run_cmd += f' {train_dir}/{caption_metadata_filename}' + if full_path: + run_cmd += f' --full_path' print(run_cmd) @@ -237,26 +283,27 @@ def train_model( # create images buckets if generate_image_buckets: - command = [ - './venv/Scripts/python.exe', - 'finetune/prepare_buckets_latents.py', - image_folder, - '{}/meta_cap.json'.format(train_dir), - '{}/meta_lat.json'.format(train_dir), - pretrained_model_name_or_path, - '--batch_size', - '4', - '--max_resolution', - max_resolution, - '--mixed_precision', - mixed_precision, - '--full_path', - ] + run_cmd = ( + f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py' + ) + run_cmd += f' {image_folder}' + run_cmd += f' {train_dir}/{caption_metadata_filename}' + run_cmd += f' {train_dir}/{latent_metadata_filename}' + run_cmd += f' {pretrained_model_name_or_path}' + run_cmd += f' --batch_size={batch_size}' + run_cmd += f' --max_resolution={max_resolution}' + run_cmd += f' --min_bucket_reso={min_bucket_reso}' + run_cmd += f' --max_bucket_reso={max_bucket_reso}' + run_cmd += f' --mixed_precision={mixed_precision}' + if flip_aug: + run_cmd += f' --flip_aug' + if full_path: + run_cmd += f' --full_path' - print(command) + print(run_cmd) # Run the command - subprocess.run(command) + subprocess.run(run_cmd) image_num = len( [f for f in os.listdir(image_folder) if f.endswith('.npz')] @@ -270,11 +317,14 @@ def train_model( max_train_steps = int( math.ceil(float(repeats) / int(train_batch_size) * int(epoch)) ) + + # Divide by two because flip augmentation create two copied of the source images + if flip_aug: + max_train_steps = int(math.ceil(float(max_train_steps) / 2)) + print(f'max_train_steps = {max_train_steps}') - lr_warmup_steps = round( - float(int(lr_warmup) * int(max_train_steps) / 100) - ) + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) print(f'lr_warmup_steps = {lr_warmup_steps}') run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"' @@ -284,10 +334,14 @@ def train_model( run_cmd += ' --v_parameterization' if train_text_encoder: run_cmd += ' --train_text_encoder' + if use_8bit_adam: + run_cmd += f' --use_8bit_adam' + if xformers: + run_cmd += f' --xformers' run_cmd += ( f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' ) - run_cmd += f' --in_json={train_dir}/meta_lat.json' + run_cmd += f' --in_json={train_dir}/{latent_metadata_filename}' run_cmd += f' --train_data_dir={image_folder}' run_cmd += f' --output_dir={output_dir}' if not logging_dir == '': @@ -298,8 +352,6 @@ def train_model( run_cmd += f' --lr_scheduler={lr_scheduler}' run_cmd += f' --lr_warmup_steps={lr_warmup_steps}' run_cmd += f' --max_train_steps={max_train_steps}' - run_cmd += f' --use_8bit_adam' - run_cmd += f' --xformers' run_cmd += f' --mixed_precision={mixed_precision}' run_cmd += f' --save_every_n_epochs={save_every_n_epochs}' run_cmd += f' --seed={seed}' @@ -389,9 +441,9 @@ def UI(username, password): interface = gr.Blocks(css=css) with interface: - with gr.Tab("Finetune"): + with gr.Tab('Finetune'): finetune_tab() - with gr.Tab("Utilities"): + with gr.Tab('Utilities'): utilities_tab(enable_dreambooth_tab=False) # Show the interface @@ -400,12 +452,11 @@ def UI(username, password): else: interface.launch() + def finetune_tab(): dummy_ft_true = gr.Label(value=True, visible=False) dummy_ft_false = gr.Label(value=False, visible=False) - gr.Markdown( - 'Train a custom model using kohya finetune python code...' - ) + gr.Markdown('Train a custom model using kohya finetune python code...') with gr.Accordion('Configuration file', open=False): with gr.Row(): button_open_config = gr.Button( @@ -496,9 +547,7 @@ def finetune_tab(): train_dir_folder = gr.Button( folder_symbol, elem_id='open_folder_small' ) - train_dir_folder.click( - get_folder_path, outputs=train_dir_input - ) + train_dir_folder.click(get_folder_path, outputs=train_dir_input) image_folder_input = gr.Textbox( label='Training Image folder', @@ -547,11 +596,31 @@ def finetune_tab(): inputs=[output_dir_input], outputs=[output_dir_input], ) + with gr.Tab('Dataset preparation'): + with gr.Row(): + max_resolution_input = gr.Textbox( + label='Resolution (width,height)', value='512,512' + ) + min_bucket_reso = gr.Textbox( + label='Min bucket resolution', value='256' + ) + max_bucket_reso = gr.Textbox( + label='Max bucket resolution', value='1024' + ) + batch_size = gr.Textbox(label='Batch size', value='1') + with gr.Accordion('Advanced parameters', open=False): + with gr.Row(): + caption_metadata_filename = gr.Textbox( + label='Caption metadata filename', value='meta_cap.json' + ) + latent_metadata_filename = gr.Textbox( + label='Latent metadata filename', value='meta_lat.json' + ) + full_path = gr.Checkbox(label='Use full path', value=True) + flip_aug = gr.Checkbox(label='Flip augmentation', value=False) with gr.Tab('Training parameters'): with gr.Row(): - learning_rate_input = gr.Textbox( - label='Learning rate', value=1e-6 - ) + learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6) lr_scheduler_input = gr.Dropdown( label='LR Scheduler', choices=[ @@ -606,11 +675,7 @@ def finetune_tab(): label='Number of CPU threads per process', value=os.cpu_count(), ) - with gr.Row(): seed_input = gr.Textbox(label='Seed', value=1234) - max_resolution_input = gr.Textbox( - label='Max resolution', value='512,512' - ) with gr.Row(): caption_extention_input = gr.Textbox( label='Caption Extension', @@ -619,168 +684,74 @@ def finetune_tab(): train_text_encoder_input = gr.Checkbox( label='Train text encoder', value=True ) + with gr.Accordion('Advanced parameters', open=False): + with gr.Row(): + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) + xformers = gr.Checkbox(label='Use xformers', value=True) with gr.Box(): with gr.Row(): create_caption = gr.Checkbox( - label='Generate caption database', value=True + label='Generate caption metadata', value=True ) create_buckets = gr.Checkbox( - label='Generate image buckets', value=True + label='Generate image buckets metadata', value=True ) button_run = gr.Button('Train model') - button_run.click( - train_model, - inputs=[ - create_caption, - create_buckets, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - train_dir_input, - image_folder_input, - output_dir_input, - logging_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - dataset_repeats_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - train_text_encoder_input, - save_model_as_dropdown, - caption_extention_input, - ], - ) + settings_list = [ + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + train_dir_input, + image_folder_input, + output_dir_input, + logging_dir_input, + max_resolution_input, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate_input, + lr_scheduler_input, + lr_warmup_input, + dataset_repeats_input, + train_batch_size_input, + epoch_input, + save_every_n_epochs_input, + mixed_precision_input, + save_precision_input, + seed_input, + num_cpu_threads_per_process_input, + train_text_encoder_input, + create_caption, + create_buckets, + save_model_as_dropdown, + caption_extention_input, + use_8bit_adam, + xformers, + ] + + button_run.click(train_model, inputs=settings_list) button_open_config.click( open_config_file, - inputs=[ - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - train_dir_input, - image_folder_input, - output_dir_input, - logging_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - dataset_repeats_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - train_text_encoder_input, - create_buckets, - create_caption, - save_model_as_dropdown, - caption_extention_input, - ], - outputs=[ - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - train_dir_input, - image_folder_input, - output_dir_input, - logging_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - dataset_repeats_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - train_text_encoder_input, - create_buckets, - create_caption, - save_model_as_dropdown, - caption_extention_input, - ], + inputs=[config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, ) button_save_config.click( save_configuration, - inputs=[ - dummy_ft_false, - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - train_dir_input, - image_folder_input, - output_dir_input, - logging_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - dataset_repeats_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - train_text_encoder_input, - create_buckets, - create_caption, - save_model_as_dropdown, - caption_extention_input, - ], + inputs=[dummy_ft_false, config_file_name] + settings_list, outputs=[config_file_name], ) button_save_as_config.click( save_configuration, - inputs=[ - dummy_ft_true, - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - train_dir_input, - image_folder_input, - output_dir_input, - logging_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - dataset_repeats_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - train_text_encoder_input, - create_buckets, - create_caption, - save_model_as_dropdown, - caption_extention_input, - ], + inputs=[dummy_ft_true, config_file_name] + settings_list, outputs=[config_file_name], ) diff --git a/library/common_gui.py b/library/common_gui.py index ae1e647..c30c0d3 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -1,5 +1,7 @@ from tkinter import filedialog, Tk import os +import gradio as gr +from easygui import msgbox def get_file_path(file_path='', defaultextension='.json'): @@ -107,3 +109,10 @@ def add_pre_postfix( f.seek(0, 0) f.write(f'{prefix}{content}{postfix}') f.close() + +def color_aug_changed(color_aug): + if color_aug: + msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...') + return gr.Checkbox.update(value=False, interactive=False) + else: + return gr.Checkbox.update(value=True, interactive=True) \ No newline at end of file diff --git a/lora_gui.py b/lora_gui.py index e8bed70..e33c7da 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -17,6 +17,7 @@ from library.common_gui import ( get_file_path, get_any_file_path, get_saveasfile_path, + color_aug_changed, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -64,7 +65,13 @@ def save_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, ): original_file_path = file_path @@ -120,6 +127,8 @@ def save_configuration( 'unet_lr': unet_lr, 'network_dim': network_dim, 'lora_network_weights': lora_network_weights, + 'color_aug': color_aug, + 'flip_aug': flip_aug, } # Save the data to the selected file @@ -161,7 +170,13 @@ def open_configuration( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, ): original_file_path = file_path @@ -218,6 +233,8 @@ def open_configuration( my_data.get('unet_lr', unet_lr), my_data.get('network_dim', network_dim), my_data.get('lora_network_weights', lora_network_weights), + my_data.get('color_aug', color_aug), + my_data.get('flip_aug', flip_aug), ) @@ -252,7 +269,13 @@ def train_model( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -289,13 +312,17 @@ def train_model( if output_dir == '': msgbox('Output folder path is missing') return - + # If string is empty set string to 0. - if text_encoder_lr == '': text_encoder_lr = 0 - if unet_lr == '': unet_lr = 0 - + if text_encoder_lr == '': + text_encoder_lr = 0 + if unet_lr == '': + unet_lr = 0 + if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0): - msgbox('At least one Learning Rate value for "Text encoder" or "Unet" need to be provided') + msgbox( + 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided' + ) return # Get a list of all subfolders in train_data_dir @@ -388,6 +415,10 @@ def train_model( run_cmd += ' --shuffle_caption' if save_state: run_cmd += ' --save_state' + if color_aug: + run_cmd += ' --color_aug' + if flip_aug: + run_cmd += ' --flip_aug' run_cmd += ( f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' ) @@ -436,7 +467,6 @@ def train_model( run_cmd += f' --network_dim={network_dim}' if not lora_network_weights == '': run_cmd += f' --network_weights={lora_network_weights}' - print(run_cmd) # Run the command @@ -543,7 +573,9 @@ def lora_tab( ): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) - gr.Markdown('Train a custom model using kohya train network LoRA python code...') + gr.Markdown( + 'Train a custom model using kohya train network LoRA python code...' + ) with gr.Accordion('Configuration file', open=False): with gr.Row(): button_open_config = gr.Button('Open 📂', elem_id='open_folder') @@ -606,7 +638,7 @@ def lora_tab( ], value='same as source model', ) - + with gr.Row(): v2_input = gr.Checkbox(label='v2', value=True) v_parameterization_input = gr.Checkbox( @@ -720,8 +752,14 @@ def lora_tab( ) lr_warmup_input = gr.Textbox(label='LR warmup', value=0) with gr.Row(): - text_encoder_lr = gr.Textbox(label='Text Encoder learning rate', value=1e-6, placeholder='Optional') - unet_lr = gr.Textbox(label='Unet learning rate', value=1e-4, placeholder='Optional') + text_encoder_lr = gr.Textbox( + label='Text Encoder learning rate', + value=1e-6, + placeholder='Optional', + ) + unet_lr = gr.Textbox( + label='Unet learning rate', value=1e-4, placeholder='Optional' + ) # network_train = gr.Dropdown( # label='Network to train', # choices=[ @@ -738,7 +776,7 @@ def lora_tab( label='Network Dimension', value=4, step=1, - interactive=True + interactive=True, ) with gr.Row(): train_batch_size_input = gr.Slider( @@ -825,6 +863,15 @@ def lora_tab( save_state = gr.Checkbox( label='Save training state', value=False ) + color_aug = gr.Checkbox( + label='Color augmentation', value=False + ) + flip_aug = gr.Checkbox(label='Flip augmentation', value=False) + color_aug.change( + color_aug_changed, + inputs=[color_aug], + outputs=[cache_latent_input], + ) with gr.Row(): resume = gr.Textbox( label='Resume from saved training state', @@ -836,7 +883,9 @@ def lora_tab( label='Prior loss weight', value=1.0 ) with gr.Tab('Tools'): - gr.Markdown('This section provide Dreambooth tools to help setup your dataset...') + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' + ) gradio_dreambooth_folder_creation_tab( train_data_dir_input=train_data_dir_input, reg_data_dir_input=reg_data_dir_input, @@ -846,7 +895,7 @@ def lora_tab( gradio_dataset_balancing_tab() button_run = gr.Button('Train model') - + settings_list = [ pretrained_model_name_or_path_input, v2_input, @@ -879,7 +928,13 @@ def lora_tab( shuffle_caption, save_state, resume, - prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + color_aug, + flip_aug, ] button_open_config.click( @@ -902,7 +957,7 @@ def lora_tab( button_run.click( train_model, - inputs=settings_list, + inputs=settings_list, ) return (