Add support for max token

This commit is contained in:
bmaltais 2023-01-10 09:38:32 -05:00
parent 42a3646d4a
commit 43116feda8
4 changed files with 51 additions and 5 deletions

View File

@ -1,8 +1,6 @@
# Kohya's dreambooth and finetuning
# Kohya's GUI
This repository now includes the solutions provided by Kohya_ss in a single location. I have combined both solutions under one repository to align with the new official Kohya repository where he will maintain his code from now on: https://github.com/kohya-ss/sd-scripts.
A note accompanying the release of his new repository can be found here: https://note.com/kohya_ss/n/nba4eceaa4594
This repository provides a Gradio GUI for kohya's Stable Diffusion trainers found here: https://github.com/kohya-ss/sd-scripts. The GUI allows you to set the training parameters and generate and run the required CLI command to train the model.
## Required Dependencies
@ -101,6 +99,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
## Change history
* 2023/01/11 (v20.2.0):
- Add support for max token length
* 2023/01/10 (v20.1.1):
- Fix issue with LoRA config loading
* 2023/01/10 (v20.1):

View File

@ -73,6 +73,7 @@ def save_configuration(
clip_skip,
vae,
output_name,
max_token_length,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -151,6 +152,7 @@ def open_configuration(
clip_skip,
vae,
output_name,
max_token_length,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -213,6 +215,7 @@ def train_model(
clip_skip,
vae,
output_name,
max_token_length,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -367,6 +370,8 @@ def train_model(
run_cmd += f' --vae="{vae}"'
if not output_name == '':
run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}'
print(run_cmd)
# Run the command
@ -694,6 +699,15 @@ def dreambooth_tab(
)
vae_button = gr.Button('📂', elem_id='open_folder_small')
vae_button.click(get_any_file_path, outputs=vae)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
with gr.Tab('Tools'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
@ -745,6 +759,7 @@ def dreambooth_tab(
clip_skip,
vae,
output_name,
max_token_length,
]
button_open_config.click(

View File

@ -66,6 +66,7 @@ def save_configuration(
mem_eff_attn,
shuffle_caption,
output_name,
max_token_length,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -148,6 +149,7 @@ def open_config_file(
mem_eff_attn,
shuffle_caption,
output_name,
max_token_length,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -215,6 +217,7 @@ def train_model(
mem_eff_attn,
shuffle_caption,
output_name,
max_token_length,
):
# create caption json file
if generate_caption_database:
@ -331,6 +334,8 @@ def train_model(
run_cmd += f' --resume={resume}'
if not output_name == '':
run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}'
print(run_cmd)
# Run the command
@ -642,6 +647,15 @@ def finetune_tab():
gradient_accumulation_steps = gr.Number(
label='Gradient accumulate steps', value='1'
)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
with gr.Box():
with gr.Row():
create_caption = gr.Checkbox(
@ -695,6 +709,7 @@ def finetune_tab():
mem_eff_attn,
shuffle_caption,
output_name,
max_token_length,
]
button_run.click(train_model, inputs=settings_list)

View File

@ -80,6 +80,7 @@ def save_configuration(
mem_eff_attn,
output_name,
model_list,
max_token_length,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -163,6 +164,7 @@ def open_configuration(
mem_eff_attn,
output_name,
model_list,
max_token_length,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -229,7 +231,8 @@ def train_model(
gradient_accumulation_steps,
mem_eff_attn,
output_name,
model_list,
model_list, # Keep this. Yes, it is unused here but required given the common list used
max_token_length,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -405,6 +408,8 @@ def train_model(
# run_cmd += f' --vae="{vae}"'
if not output_name == '':
run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}'
print(run_cmd)
# Run the command
@ -781,6 +786,16 @@ def lora_tab(
# )
# vae_button = gr.Button('📂', elem_id='open_folder_small')
# vae_button.click(get_any_file_path, outputs=vae)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
with gr.Tab('Tools'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
@ -839,6 +854,7 @@ def lora_tab(
mem_eff_attn,
output_name,
model_list,
max_token_length
]
button_open_config.click(