From 202923b3ce8552b5de79da7d9303e5f8a646e3f5 Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Fri, 27 Jan 2023 07:33:44 -0500
Subject: [PATCH] Add support for --keep_token option

---
 README.md                | 1 +
 dreambooth_gui.py        | 8 ++++++--
 finetune_gui.py          | 9 ++++++---
 library/common_gui.py    | 8 ++++++++
 lora_gui.py              | 9 ++++++---
 textual_inversion_gui.py | 9 ++++++---
 6 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index e37f3a1..8204902 100644
--- a/README.md
+++ b/README.md
@@ -132,6 +132,7 @@ Then redo the installation instruction within the kohya_ss venv.
 ## Change history
 
 * 2023/01/27 (v20.5.1):
+    - Fix issue: https://github.com/bmaltais/kohya_ss/issues/70
     - Fix issue https://github.com/bmaltais/kohya_ss/issues/71
 * 2023/01/26 (v20.5.0):
     - Add new `Dreambooth TI` tab for training of Textual Inversion embeddings
diff --git a/dreambooth_gui.py b/dreambooth_gui.py
index 9fe6573..24ccfe7 100644
--- a/dreambooth_gui.py
+++ b/dreambooth_gui.py
@@ -82,7 +82,7 @@ def save_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list,
+    model_list, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -166,7 +166,7 @@ def open_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list,
+    model_list, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -235,6 +235,7 @@ def train_model(
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list, # Keep this. Yes, it is unused here but required given the common list used
+    keep_tokens,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -396,6 +397,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )
 
     print(run_cmd)
@@ -602,6 +604,7 @@ def dreambooth_tab(
             max_token_length,
             max_train_epochs,
             max_data_loader_n_workers,
+            keep_tokens,
         ) = gradio_advanced_training()
         color_aug.change(
             color_aug_changed,
@@ -665,6 +668,7 @@ def dreambooth_tab(
         mem_eff_attn,
         gradient_accumulation_steps,
         model_list,
+        keep_tokens,
     ]
 
     button_open_config.click(
diff --git a/finetune_gui.py b/finetune_gui.py
index e0e918c..4597649 100644
--- a/finetune_gui.py
+++ b/finetune_gui.py
@@ -78,7 +78,7 @@ def save_configuration(
     color_aug,
     model_list,
     cache_latents,
-    use_latent_files,
+    use_latent_files, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -168,7 +168,7 @@ def open_config_file(
     color_aug,
     model_list,
     cache_latents,
-    use_latent_files,
+    use_latent_files, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -243,7 +243,7 @@ def train_model(
     color_aug,
     model_list, # Keep this. Yes, it is unused here but required given the common list used
     cache_latents,
-    use_latent_files,
+    use_latent_files, keep_tokens,
 ):
     # create caption json file
     if generate_caption_database:
@@ -381,6 +381,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )
 
     print(run_cmd)
@@ -585,6 +586,7 @@ def finetune_tab():
             max_token_length,
             max_train_epochs,
             max_data_loader_n_workers,
+            keep_tokens,
         ) = gradio_advanced_training()
         color_aug.change(
             color_aug_changed,
@@ -644,6 +646,7 @@ def finetune_tab():
         model_list,
         cache_latents,
         use_latent_files,
+        keep_tokens,
     ]
 
     button_run.click(train_model, inputs=settings_list)
diff --git a/library/common_gui.py b/library/common_gui.py
index 4b5e8dc..816fd9d 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -519,6 +519,9 @@ def gradio_advanced_training():
         shuffle_caption = gr.Checkbox(
             label='Shuffle caption', value=False
         )
+        keep_tokens = gr.Slider(
+            label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
+        )
         use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
         xformers = gr.Checkbox(label='Use xformers', value=True)
     with gr.Row():
@@ -572,6 +575,7 @@ def gradio_advanced_training():
         max_token_length,
         max_train_epochs,
         max_data_loader_n_workers,
+        keep_tokens,
     )
 
 def run_cmd_advanced_training(**kwargs):
@@ -596,6 +600,10 @@ def run_cmd_advanced_training(**kwargs):
         if kwargs.get('resume')
         else '',
 
+        f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
+        if int(kwargs.get('keep_tokens', 0)) > 0
+        else '',
+
         ' --save_state' if kwargs.get('save_state') else '',
         ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
diff --git a/lora_gui.py b/lora_gui.py
index 725d4ef..49359e9 100644
--- a/lora_gui.py
+++ b/lora_gui.py
@@ -90,7 +90,7 @@ def save_configuration(
     max_train_epochs,
     max_data_loader_n_workers,
     network_alpha,
-    training_comment,
+    training_comment, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -179,7 +179,7 @@ def open_configuration(
     max_train_epochs,
     max_data_loader_n_workers,
     network_alpha,
-    training_comment,
+    training_comment, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -252,7 +252,7 @@ def train_model(
     max_train_epochs,
     max_data_loader_n_workers,
     network_alpha,
-    training_comment,
+    training_comment, keep_tokens,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -425,6 +425,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )
 
     print(run_cmd)
@@ -660,6 +661,7 @@ def lora_tab(
             max_token_length,
             max_train_epochs,
             max_data_loader_n_workers,
+            keep_tokens,
         ) = gradio_advanced_training()
         color_aug.change(
             color_aug_changed,
@@ -733,6 +735,7 @@ def lora_tab(
         max_data_loader_n_workers,
         network_alpha,
         training_comment,
+        keep_tokens,
     ]
 
     button_open_config.click(
diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py
index 3380d60..5dd6015 100644
--- a/textual_inversion_gui.py
+++ b/textual_inversion_gui.py
@@ -82,7 +82,7 @@ def save_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -170,7 +170,7 @@ def open_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -239,7 +239,7 @@ def train_model(
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list, # Keep this. Yes, it is unused here but required given the common list used
-    token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+    token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -416,6 +416,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )
     run_cmd += f' --token_string={token_string}'
     run_cmd += f' --init_word={init_word}'
@@ -669,6 +670,7 @@ def ti_tab(
             max_token_length,
             max_train_epochs,
             max_data_loader_n_workers,
+            keep_tokens,
         ) = gradio_advanced_training()
         color_aug.change(
             color_aug_changed,
@@ -733,6 +735,7 @@ def ti_tab(
         gradient_accumulation_steps,
         model_list,
         token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+        keep_tokens,
     ]
 
     button_open_config.click(
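Note: as I understand the underlying kohya-ss sd-scripts trainers, `--keep_tokens` is meant to keep the leading N caption tokens in place when caption shuffling (the `Shuffle caption` option) is enabled. The sketch below is a simplified, hypothetical stand-in for the keep_tokens handling added to `run_cmd_advanced_training()` in library/common_gui.py above; `build_keep_tokens_arg` is an illustrative helper name, not a function from the repo. It only shows how the slider value is expected to become a CLI argument, with the flag omitted when the value is 0.

    # Simplified stand-in for the keep_tokens handling added to
    # run_cmd_advanced_training() in this patch. Illustration only.
    def build_keep_tokens_arg(keep_tokens=0):
        # The GUI may pass the value as a string, so coerce before comparing.
        if int(keep_tokens) > 0:
            return f' --keep_tokens="{keep_tokens}"'
        # A value of 0 leaves the training command unchanged.
        return ''

    # Example: a slider value of 2 appends the flag; 0 omits it.
    print(build_keep_tokens_arg(2))    # -> ' --keep_tokens="2"'
    print(build_keep_tokens_arg('0'))  # -> ''

Because the patch also threads `keep_tokens` through each tab's parameter and settings_list entries, the same slider value reaches the Dreambooth, fine-tuning, LoRA, and Textual Inversion training commands.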