From 123cf4e3c5558e118e010b9bcf2b5db8fcbe7af0 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 15 Jan 2023 15:03:04 -0500 Subject: [PATCH] Move cxommon adv train parm to common gui --- dreambooth_gui.py | 88 +++++++++++++++++++++---------------------- finetune_gui.py | 74 ++++++++++++++++++------------------ library/common_gui.py | 70 ++++++++++++++++++++++++++++++---- lora_gui.py | 71 +++++++++++----------------------- 4 files changed, 166 insertions(+), 137 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index a7b67ae..393447c 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -77,7 +77,8 @@ def save_configuration( output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,mem_eff_attn, + gradient_accumulation_steps, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -158,7 +159,8 @@ def open_configuration( output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,mem_eff_attn, + gradient_accumulation_steps, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -223,7 +225,8 @@ def train_model( output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,mem_eff_attn, + gradient_accumulation_steps, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -322,24 +325,12 @@ def train_model( run_cmd += ' --cache_latents' if enable_bucket: run_cmd += ' --enable_bucket' - if gradient_checkpointing: - run_cmd += ' --gradient_checkpointing' - if full_fp16: - run_cmd += ' --full_fp16' if no_token_padding: run_cmd += ' --no_token_padding' if use_8bit_adam: run_cmd += ' --use_8bit_adam' if xformers: run_cmd += ' --xformers' - if shuffle_caption: - run_cmd += ' --shuffle_caption' - # if save_state: - # run_cmd += ' --save_state' - if color_aug: - run_cmd += ' --color_aug' - if flip_aug: - run_cmd += ' --flip_aug' run_cmd += ( f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' ) @@ -353,8 +344,6 @@ def train_model( run_cmd += f' --lr_scheduler={lr_scheduler}' run_cmd += f' --lr_warmup_steps={lr_warmup_steps}' run_cmd += f' --max_train_steps={max_train_steps}' - run_cmd += f' --use_8bit_adam' - run_cmd += f' --xformers' run_cmd += f' --mixed_precision={mixed_precision}' run_cmd += f' --save_every_n_epochs={save_every_n_epochs}' run_cmd += f' --seed={seed}' @@ -372,8 +361,6 @@ def train_model( # run_cmd += f' --resume={resume}' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' - if int(clip_skip) > 1: - run_cmd += f' --clip_skip={str(clip_skip)}' if not vae == '': run_cmd += f' --vae="{vae}"' if not output_name == '': @@ -384,12 +371,23 @@ def train_model( run_cmd += f' --max_train_epochs="{max_train_epochs}"' if not max_data_loader_n_workers == '': run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' + if int(gradient_accumulation_steps) > 1: + run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' run_cmd += run_cmd_advanced_training( max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers, max_token_length=max_token_length, resume=resume, save_state=save_state, + mem_eff_attn=mem_eff_attn, + clip_skip=clip_skip, + flip_aug=flip_aug, + color_aug=color_aug, + shuffle_caption=shuffle_caption, + gradient_checkpointing=gradient_checkpointing, + full_fp16=full_fp16, + xformers=xformers, + 
use_8bit_adam=use_8bit_adam, ) print(run_cmd) @@ -668,36 +666,13 @@ def dreambooth_tab( with gr.Row(): enable_bucket = gr.Checkbox(label='Enable buckets', value=True) cache_latent = gr.Checkbox(label='Cache latent', value=True) - use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) - xformers = gr.Checkbox(label='Use xformers', value=True) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(): - full_fp16 = gr.Checkbox( - label='Full fp16 training (experimental)', value=False - ) no_token_padding = gr.Checkbox( label='No token padding', value=False ) - - gradient_checkpointing = gr.Checkbox( - label='Gradient checkpointing', value=False - ) - - shuffle_caption = gr.Checkbox( - label='Shuffle caption', value=False - ) - with gr.Row(): - color_aug = gr.Checkbox( - label='Color augmentation', value=False - ) - flip_aug = gr.Checkbox(label='Flip augmentation', value=False) - color_aug.change( - color_aug_changed, - inputs=[color_aug], - outputs=[cache_latent], - ) - clip_skip = gr.Slider( - label='Clip skip', value='1', minimum=1, maximum=12, step=1 + gradient_accumulation_steps = gr.Number( + label='Gradient accumulate steps', value='1' ) with gr.Row(): prior_loss_weight = gr.Number( @@ -709,7 +684,27 @@ def dreambooth_tab( ) vae_button = gr.Button('📂', elem_id='open_folder_small') vae_button.click(get_any_file_path, outputs=vae) - save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training() + ( + use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + ) = gradio_advanced_training() + color_aug.change( + color_aug_changed, + inputs=[color_aug], + outputs=[cache_latent], + ) with gr.Tab('Tools'): gr.Markdown( 'This section provide Dreambooth tools to help setup your dataset...' 
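The hunks above and below all follow the same pattern: the per-GUI advanced-configuration widgets and the hand-written "if option: run_cmd += ' --flag'" branches are removed, the widgets are created once by gradio_advanced_training() in library/common_gui.py (whose return tuple each tab unpacks), and each train_model() forwards the chosen values to run_cmd_advanced_training(), which builds the flag string in one place. A trimmed sketch of that helper and a hypothetical call site follow; only a few of the flags handled by the patch are reproduced here, and the base command string is illustrative rather than taken from the repository.

def run_cmd_advanced_training(**kwargs):
    # Each advanced option contributes its CLI flag only when it is set;
    # unset or default-valued options contribute an empty string.
    options = [
        f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
        if kwargs.get('max_train_epochs')
        else '',
        f' --clip_skip={kwargs.get("clip_skip", "")}'
        if int(kwargs.get('clip_skip', 1)) > 1
        else '',
        ' --xformers' if kwargs.get('xformers') else '',
        ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
        ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '',
    ]
    return ''.join(options)

# Hypothetical call site inside a GUI's train_model(), mirroring the calls added in this patch:
run_cmd = 'accelerate launch "train_db.py"'  # illustrative base command, not taken from the repo
run_cmd += run_cmd_advanced_training(
    max_train_epochs='',        # empty string -> flag omitted
    clip_skip=2,                # greater than 1 -> flag emitted
    xformers=True,
    use_8bit_adam=True,
    gradient_checkpointing=False,
)
print(run_cmd)
# accelerate launch "train_db.py" --clip_skip=2 --xformers --use_8bit_adam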
@@ -763,7 +758,8 @@ def dreambooth_tab( output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,mem_eff_attn, + gradient_accumulation_steps, ] button_open_config.click( diff --git a/finetune_gui.py b/finetune_gui.py index 8fae590..adae1d8 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -13,7 +13,8 @@ from library.common_gui import ( get_saveasfile_path, save_inference_file, set_pretrained_model_name_or_path_input, - gradio_advanced_training,run_cmd_advanced_training + gradio_advanced_training,run_cmd_advanced_training, + color_aug_changed, ) from library.utilities import utilities_tab @@ -69,7 +70,7 @@ def save_configuration( output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,full_fp16,color_aug, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -154,7 +155,7 @@ def open_config_file( output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,full_fp16,color_aug, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -224,7 +225,7 @@ def train_model( output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,full_fp16,color_aug, ): # create caption json file if generate_caption_database: @@ -262,8 +263,8 @@ def train_model( run_cmd += f' --min_bucket_reso={min_bucket_reso}' run_cmd += f' --max_bucket_reso={max_bucket_reso}' run_cmd += f' --mixed_precision={mixed_precision}' - if flip_aug: - run_cmd += f' --flip_aug' + # if flip_aug: + # run_cmd += f' --flip_aug' if full_path: run_cmd += f' --full_path' @@ -301,16 +302,6 @@ def train_model( run_cmd += ' --v_parameterization' if train_text_encoder: run_cmd += ' --train_text_encoder' - if use_8bit_adam: - run_cmd += f' --use_8bit_adam' - if xformers: - run_cmd += f' --xformers' - if gradient_checkpointing: - run_cmd += ' --gradient_checkpointing' - if mem_eff_attn: - run_cmd += ' --mem_eff_attn' - if shuffle_caption: - run_cmd += ' --shuffle_caption' run_cmd += ( f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' ) @@ -331,8 +322,6 @@ def train_model( run_cmd += f' --save_precision={save_precision}' if not save_model_as == 'same as source model': run_cmd += f' --save_model_as={save_model_as}' - if int(clip_skip) > 1: - run_cmd += f' --clip_skip={str(clip_skip)}' if int(gradient_accumulation_steps) > 1: run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' # if save_state: @@ -349,6 +338,15 @@ def train_model( max_token_length=max_token_length, resume=resume, save_state=save_state, + mem_eff_attn=mem_eff_attn, + clip_skip=clip_skip, + flip_aug=flip_aug, + color_aug=color_aug, + shuffle_caption=shuffle_caption, + gradient_checkpointing=gradient_checkpointing, + full_fp16=full_fp16, + xformers=xformers, + use_8bit_adam=use_8bit_adam, ) print(run_cmd) @@ -565,7 +563,6 @@ def finetune_tab(): label='Latent metadata filename', value='meta_lat.json' ) full_path = gr.Checkbox(label='Use full path', value=True) - flip_aug = gr.Checkbox(label='Flip augmentation', value=False) with gr.Tab('Training parameters'): with gr.Row(): learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6) @@ -634,25 +631,30 @@ def finetune_tab(): ) with gr.Accordion('Advanced parameters', open=False): with gr.Row(): - use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) - xformers = gr.Checkbox(label='Use xformers', value=True) - clip_skip = 
gr.Slider( - label='Clip skip', value='1', minimum=1, maximum=12, step=1 - ) - mem_eff_attn = gr.Checkbox( - label='Memory efficient attention', value=False - ) - shuffle_caption = gr.Checkbox( - label='Shuffle caption', value=False - ) - with gr.Row(): - gradient_checkpointing = gr.Checkbox( - label='Gradient checkpointing', value=False - ) gradient_accumulation_steps = gr.Number( label='Gradient accumulate steps', value='1' ) - save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training() + ( + use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + ) = gradio_advanced_training() + # color_aug.change( + # color_aug_changed, + # inputs=[color_aug], + # # outputs=[cache_latent], # Not applicable to fine_tune.py + # ) with gr.Box(): with gr.Row(): create_caption = gr.Checkbox( @@ -708,7 +710,7 @@ def finetune_tab(): output_name, max_token_length, max_train_epochs, - max_data_loader_n_workers, + max_data_loader_n_workers,full_fp16,color_aug, ] button_run.click(train_model, inputs=settings_list) diff --git a/library/common_gui.py b/library/common_gui.py index 1d193d5..a7b798b 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -304,13 +304,28 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): def gradio_advanced_training(): with gr.Row(): - save_state = gr.Checkbox(label='Save training state', value=False) - resume = gr.Textbox( - label='Resume from saved training state', - placeholder='path to "last-state" state folder to resume from', + full_fp16 = gr.Checkbox( + label='Full fp16 training (experimental)', value=False + ) + gradient_checkpointing = gr.Checkbox( + label='Gradient checkpointing', value=False + ) + shuffle_caption = gr.Checkbox( + label='Shuffle caption', value=False + ) + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) + xformers = gr.Checkbox(label='Use xformers', value=True) + with gr.Row(): + color_aug = gr.Checkbox( + label='Color augmentation', value=False + ) + flip_aug = gr.Checkbox(label='Flip augmentation', value=False) + clip_skip = gr.Slider( + label='Clip skip', value='1', minimum=1, maximum=12, step=1 + ) + mem_eff_attn = gr.Checkbox( + label='Memory efficient attention', value=False ) - resume_button = gr.Button('📂', elem_id='open_folder_small') - resume_button.click(get_folder_path, outputs=resume) max_token_length = gr.Dropdown( label='Max Token Length', choices=[ @@ -321,6 +336,13 @@ def gradio_advanced_training(): value='75', ) with gr.Row(): + save_state = gr.Checkbox(label='Save training state', value=False) + resume = gr.Textbox( + label='Resume from saved training state', + placeholder='path to "last-state" state folder to resume from', + ) + resume_button = gr.Button('📂', elem_id='open_folder_small') + resume_button.click(get_folder_path, outputs=resume) max_train_epochs = gr.Textbox( label='Max train epoch', placeholder='(Optional) Override number of epoch', @@ -330,6 +352,15 @@ def gradio_advanced_training(): placeholder='(Optional) Override number of epoch. 
Default: 8', ) return ( + use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, save_state, resume, max_token_length, @@ -343,16 +374,41 @@ def run_cmd_advanced_training(**kwargs): f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"' if kwargs.get('max_train_epochs') else '', + f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"' if kwargs.get('max_data_loader_n_workers') else '', + f' --max_token_length={kwargs.get("max_token_length", "")}' - if int(kwargs.get('max_token_length', 0)) > 75 + if int(kwargs.get('max_token_length', 75)) > 75 else '', + + f' --clip_skip={kwargs.get("clip_skip", "")}' + if int(kwargs.get('clip_skip', 1)) > 1 + else '', + f' --resume="{kwargs.get("resume", "")}"' if kwargs.get('resume') else '', + ' --save_state' if kwargs.get('save_state') else '', + + ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '', + + ' --color_aug' if kwargs.get('color_aug') else '', + + ' --flip_aug' if kwargs.get('flip_aug') else '', + + ' --shuffle_caption' if kwargs.get('shuffle_caption') else '', + + ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '', + + ' --full_fp16' if kwargs.get('full_fp16') else '', + + ' --xformers' if kwargs.get('xformers') else '', + + ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '', + ] run_cmd = ''.join(options) return run_cmd diff --git a/lora_gui.py b/lora_gui.py index 2432b7a..80b196e 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -340,26 +340,8 @@ def train_model( run_cmd += ' --cache_latents' if enable_bucket: run_cmd += ' --enable_bucket' - if gradient_checkpointing: - run_cmd += ' --gradient_checkpointing' - if full_fp16: - run_cmd += ' --full_fp16' if no_token_padding: run_cmd += ' --no_token_padding' - if use_8bit_adam: - run_cmd += ' --use_8bit_adam' - if xformers: - run_cmd += ' --xformers' - if shuffle_caption: - run_cmd += ' --shuffle_caption' - # if save_state: - # run_cmd += ' --save_state' - if color_aug: - run_cmd += ' --color_aug' - if flip_aug: - run_cmd += ' --flip_aug' - if mem_eff_attn: - run_cmd += ' --mem_eff_attn' run_cmd += ( f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' ) @@ -408,8 +390,6 @@ def train_model( run_cmd += f' --network_dim={network_dim}' if not lora_network_weights == '': run_cmd += f' --network_weights="{lora_network_weights}"' - if int(clip_skip) > 1: - run_cmd += f' --clip_skip={str(clip_skip)}' if int(gradient_accumulation_steps) > 1: run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' # if not vae == '': @@ -424,6 +404,15 @@ def train_model( max_token_length=max_token_length, resume=resume, save_state=save_state, + mem_eff_attn=mem_eff_attn, + clip_skip=clip_skip, + flip_aug=flip_aug, + color_aug=color_aug, + shuffle_caption=shuffle_caption, + gradient_checkpointing=gradient_checkpointing, + full_fp16=full_fp16, + xformers=xformers, + use_8bit_adam=use_8bit_adam, ) print(run_cmd) @@ -729,53 +718,39 @@ def lora_tab( with gr.Row(): enable_bucket = gr.Checkbox(label='Enable buckets', value=True) cache_latent = gr.Checkbox(label='Cache latent', value=True) - use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) - xformers = gr.Checkbox(label='Use xformers', value=True) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(): - full_fp16 = gr.Checkbox( - label='Full fp16 training (experimental)', value=False - ) no_token_padding = gr.Checkbox( label='No token padding', value=False ) - 
- gradient_checkpointing = gr.Checkbox( - label='Gradient checkpointing', value=False - ) gradient_accumulation_steps = gr.Number( label='Gradient accumulate steps', value='1' ) - - shuffle_caption = gr.Checkbox( - label='Shuffle caption', value=False - ) with gr.Row(): prior_loss_weight = gr.Number( label='Prior loss weight', value=1.0 ) - color_aug = gr.Checkbox( - label='Color augmentation', value=False - ) - flip_aug = gr.Checkbox(label='Flip augmentation', value=False) - color_aug.change( - color_aug_changed, - inputs=[color_aug], - outputs=[cache_latent], - ) - clip_skip = gr.Slider( - label='Clip skip', value='1', minimum=1, maximum=12, step=1 - ) - mem_eff_attn = gr.Checkbox( - label='Memory efficient attention', value=False - ) ( + use_8bit_adam, + xformers, + full_fp16, + gradient_checkpointing, + shuffle_caption, + color_aug, + flip_aug, + clip_skip, + mem_eff_attn, save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers, ) = gradio_advanced_training() + color_aug.change( + color_aug_changed, + inputs=[color_aug], + outputs=[cache_latent], + ) with gr.Tab('Tools'): gr.Markdown(