From 43116feda80dcdd63249a403e9e7e3fb10325d2a Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Tue, 10 Jan 2023 09:38:32 -0500
Subject: [PATCH] Add support for max token length

---
 README.md         |  8 ++++----
 dreambooth_gui.py | 15 +++++++++++++++
 finetune_gui.py   | 15 +++++++++++++++
 lora_gui.py       | 18 +++++++++++++++++-
 4 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 30c4445..742bd9a 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,6 @@
-# Kohya's dreambooth and finetuning
+# Kohya's GUI
 
-This repository now includes the solutions provided by Kohya_ss in a single location. I have combined both solutions under one repository to align with the new official Kohya repository where he will maintain his code from now on: https://github.com/kohya-ss/sd-scripts.
-
-A note accompanying the release of his new repository can be found here: https://note.com/kohya_ss/n/nba4eceaa4594
+This repository provides a Gradio GUI for Kohya's Stable Diffusion trainers, found here: https://github.com/kohya-ss/sd-scripts. The GUI allows you to set the training parameters and then generate and run the required CLI command to train the model.
 
 ## Required Dependencies
@@ -101,6 +99,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
 
 ## Change history
 
+* 2023/01/11 (v20.2.0):
+    - Add support for max token length
 * 2023/01/10 (v20.1.1):
     - Fix issue with LoRA config loading
 * 2023/01/10 (v20.1):
diff --git a/dreambooth_gui.py b/dreambooth_gui.py
index f5a262f..f5ea216 100644
--- a/dreambooth_gui.py
+++ b/dreambooth_gui.py
@@ -73,6 +73,7 @@ def save_configuration(
     clip_skip,
     vae,
     output_name,
+    max_token_length,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -151,6 +152,7 @@ def open_configuration(
     clip_skip,
     vae,
     output_name,
+    max_token_length,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -213,6 +215,7 @@ def train_model(
     clip_skip,
     vae,
     output_name,
+    max_token_length,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -367,6 +370,8 @@ def train_model(
         run_cmd += f' --vae="{vae}"'
     if not output_name == '':
         run_cmd += f' --output_name="{output_name}"'
+    if int(max_token_length) > 75:
+        run_cmd += f' --max_token_length={max_token_length}'
 
     print(run_cmd)
     # Run the command
@@ -694,6 +699,15 @@ def dreambooth_tab(
                     )
                     vae_button = gr.Button('📂', elem_id='open_folder_small')
                     vae_button.click(get_any_file_path, outputs=vae)
+            max_token_length = gr.Dropdown(
+                label='Max Token Length',
+                choices=[
+                    '75',
+                    '150',
+                    '225',
+                ],
+                value='75',
+            )
         with gr.Tab('Tools'):
             gr.Markdown(
                 'This section provide Dreambooth tools to help setup your dataset...'
@@ -745,6 +759,7 @@ def dreambooth_tab(
         clip_skip,
         vae,
         output_name,
+        max_token_length,
     ]
 
     button_open_config.click(
diff --git a/finetune_gui.py b/finetune_gui.py
index 168b87e..6abadbf 100644
--- a/finetune_gui.py
+++ b/finetune_gui.py
@@ -66,6 +66,7 @@ def save_configuration(
     mem_eff_attn,
     shuffle_caption,
     output_name,
+    max_token_length,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -148,6 +149,7 @@ def open_config_file(
     mem_eff_attn,
     shuffle_caption,
     output_name,
+    max_token_length,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -215,6 +217,7 @@ def train_model(
     mem_eff_attn,
     shuffle_caption,
     output_name,
+    max_token_length,
 ):
     # create caption json file
     if generate_caption_database:
@@ -331,6 +334,8 @@ def train_model(
         run_cmd += f' --resume={resume}'
     if not output_name == '':
         run_cmd += f' --output_name="{output_name}"'
+    if int(max_token_length) > 75:
+        run_cmd += f' --max_token_length={max_token_length}'
 
     print(run_cmd)
     # Run the command
@@ -642,6 +647,15 @@ def finetune_tab():
                 gradient_accumulation_steps = gr.Number(
                     label='Gradient accumulate steps', value='1'
                 )
+            max_token_length = gr.Dropdown(
+                label='Max Token Length',
+                choices=[
+                    '75',
+                    '150',
+                    '225',
+                ],
+                value='75',
+            )
         with gr.Box():
             with gr.Row():
                 create_caption = gr.Checkbox(
@@ -695,6 +709,7 @@ def finetune_tab():
         mem_eff_attn,
         shuffle_caption,
         output_name,
+        max_token_length,
     ]
 
     button_run.click(train_model, inputs=settings_list)
diff --git a/lora_gui.py b/lora_gui.py
index 2f83d18..4c0c51f 100644
--- a/lora_gui.py
+++ b/lora_gui.py
@@ -80,6 +80,7 @@ def save_configuration(
     mem_eff_attn,
     output_name,
     model_list,
+    max_token_length,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -163,6 +164,7 @@ def open_configuration(
     mem_eff_attn,
     output_name,
     model_list,
+    max_token_length,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -229,7 +231,8 @@ def train_model(
     gradient_accumulation_steps,
     mem_eff_attn,
     output_name,
-    model_list,
+    model_list,  # Keep this; it is unused here but required by the common settings list
+    max_token_length,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -405,6 +408,8 @@ def train_model(
     # run_cmd += f' --vae="{vae}"'
     if not output_name == '':
         run_cmd += f' --output_name="{output_name}"'
+    if int(max_token_length) > 75:
+        run_cmd += f' --max_token_length={max_token_length}'
 
     print(run_cmd)
     # Run the command
@@ -781,6 +786,16 @@ def lora_tab(
                     # )
                     # vae_button = gr.Button('📂', elem_id='open_folder_small')
                     # vae_button.click(get_any_file_path, outputs=vae)
+            max_token_length = gr.Dropdown(
+                label='Max Token Length',
+                choices=[
+                    '75',
+                    '150',
+                    '225',
+                ],
+                value='75',
+            )
+
         with gr.Tab('Tools'):
             gr.Markdown(
                 'This section provide Dreambooth tools to help setup your dataset...'
@@ -839,6 +854,7 @@ def lora_tab(
         mem_eff_attn,
         output_name,
         model_list,
+        max_token_length,
     ]
 
     button_open_config.click(
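
For reference, the flag logic this patch adds identically to all three GUIs can be sketched in isolation. The helper below is hypothetical (it exists neither in the patch nor in the repository); it only mirrors the added behaviour: --max_token_length is appended to the trainer command only when the selected dropdown value exceeds the Stable Diffusion default of 75 tokens.

    # Hypothetical helper, for illustration only; mirrors the lines this
    # patch adds to dreambooth_gui.py, finetune_gui.py and lora_gui.py.
    def build_token_length_flag(max_token_length: str) -> str:
        """Return the CLI fragment for --max_token_length, or '' for the default."""
        # The GUI dropdown offers '75', '150' and '225' (values arrive as
        # strings from Gradio); 75 is the trainer default, so the flag is
        # only emitted for the longer settings.
        if int(max_token_length) > 75:
            return f' --max_token_length={max_token_length}'
        return ''

    assert build_token_length_flag('75') == ''
    assert build_token_length_flag('225') == ' --max_token_length=225'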