Add support for max token
This commit is contained in:
parent
42a3646d4a
commit
43116feda8
@ -1,8 +1,6 @@
|
|||||||
# Kohya's dreambooth and finetuning
|
# Kohya's GUI
|
||||||
|
|
||||||
This repository now includes the solutions provided by Kohya_ss in a single location. I have combined both solutions under one repository to align with the new official Kohya repository where he will maintain his code from now on: https://github.com/kohya-ss/sd-scripts.
|
This repository repository is providing a Gradio GUI for kohya's Stable Diffusion trainers found here: https://github.com/kohya-ss/sd-scripts. The GUI allow you to set the training parameters and generate and run the required CLI command to train the model.
|
||||||
|
|
||||||
A note accompanying the release of his new repository can be found here: https://note.com/kohya_ss/n/nba4eceaa4594
|
|
||||||
|
|
||||||
## Required Dependencies
|
## Required Dependencies
|
||||||
|
|
||||||
@ -101,6 +99,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 2023/01/11 (v20.2.0):
|
||||||
|
- Add support for max token lenght
|
||||||
* 2023/01/10 (v20.1.1):
|
* 2023/01/10 (v20.1.1):
|
||||||
- Fix issue with LoRA config loading
|
- Fix issue with LoRA config loading
|
||||||
* 2023/01/10 (v20.1):
|
* 2023/01/10 (v20.1):
|
||||||
|
@ -73,6 +73,7 @@ def save_configuration(
|
|||||||
clip_skip,
|
clip_skip,
|
||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -151,6 +152,7 @@ def open_configuration(
|
|||||||
clip_skip,
|
clip_skip,
|
||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -213,6 +215,7 @@ def train_model(
|
|||||||
clip_skip,
|
clip_skip,
|
||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -367,6 +370,8 @@ def train_model(
|
|||||||
run_cmd += f' --vae="{vae}"'
|
run_cmd += f' --vae="{vae}"'
|
||||||
if not output_name == '':
|
if not output_name == '':
|
||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
|
if (int(max_token_length) > 75):
|
||||||
|
run_cmd += f' --max_token_length={max_token_length}'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -694,6 +699,15 @@ def dreambooth_tab(
|
|||||||
)
|
)
|
||||||
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||||
vae_button.click(get_any_file_path, outputs=vae)
|
vae_button.click(get_any_file_path, outputs=vae)
|
||||||
|
max_token_length = gr.Dropdown(
|
||||||
|
label='Max Token Length',
|
||||||
|
choices=[
|
||||||
|
'75',
|
||||||
|
'150',
|
||||||
|
'225',
|
||||||
|
],
|
||||||
|
value='75',
|
||||||
|
)
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This section provide Dreambooth tools to help setup your dataset...'
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
@ -745,6 +759,7 @@ def dreambooth_tab(
|
|||||||
clip_skip,
|
clip_skip,
|
||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -66,6 +66,7 @@ def save_configuration(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -148,6 +149,7 @@ def open_config_file(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -215,6 +217,7 @@ def train_model(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -331,6 +334,8 @@ def train_model(
|
|||||||
run_cmd += f' --resume={resume}'
|
run_cmd += f' --resume={resume}'
|
||||||
if not output_name == '':
|
if not output_name == '':
|
||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
|
if (int(max_token_length) > 75):
|
||||||
|
run_cmd += f' --max_token_length={max_token_length}'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -642,6 +647,15 @@ def finetune_tab():
|
|||||||
gradient_accumulation_steps = gr.Number(
|
gradient_accumulation_steps = gr.Number(
|
||||||
label='Gradient accumulate steps', value='1'
|
label='Gradient accumulate steps', value='1'
|
||||||
)
|
)
|
||||||
|
max_token_length = gr.Dropdown(
|
||||||
|
label='Max Token Length',
|
||||||
|
choices=[
|
||||||
|
'75',
|
||||||
|
'150',
|
||||||
|
'225',
|
||||||
|
],
|
||||||
|
value='75',
|
||||||
|
)
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
create_caption = gr.Checkbox(
|
create_caption = gr.Checkbox(
|
||||||
@ -695,6 +709,7 @@ def finetune_tab():
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
|
max_token_length,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
18
lora_gui.py
18
lora_gui.py
@ -80,6 +80,7 @@ def save_configuration(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
model_list,
|
model_list,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -163,6 +164,7 @@ def open_configuration(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
model_list,
|
model_list,
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -229,7 +231,8 @@ def train_model(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
model_list,
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
|
max_token_length,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -405,6 +408,8 @@ def train_model(
|
|||||||
# run_cmd += f' --vae="{vae}"'
|
# run_cmd += f' --vae="{vae}"'
|
||||||
if not output_name == '':
|
if not output_name == '':
|
||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
|
if (int(max_token_length) > 75):
|
||||||
|
run_cmd += f' --max_token_length={max_token_length}'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -781,6 +786,16 @@ def lora_tab(
|
|||||||
# )
|
# )
|
||||||
# vae_button = gr.Button('📂', elem_id='open_folder_small')
|
# vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||||
# vae_button.click(get_any_file_path, outputs=vae)
|
# vae_button.click(get_any_file_path, outputs=vae)
|
||||||
|
max_token_length = gr.Dropdown(
|
||||||
|
label='Max Token Length',
|
||||||
|
choices=[
|
||||||
|
'75',
|
||||||
|
'150',
|
||||||
|
'225',
|
||||||
|
],
|
||||||
|
value='75',
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This section provide Dreambooth tools to help setup your dataset...'
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
@ -839,6 +854,7 @@ def lora_tab(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
model_list,
|
model_list,
|
||||||
|
max_token_length
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
Loading…
Reference in New Issue
Block a user