Add support for max token length
This commit is contained in:
parent 42a3646d4a
commit 43116feda8
README.md
@@ -1,8 +1,6 @@
-# Kohya's dreambooth and finetuning
+# Kohya's GUI

This repository now includes the solutions provided by Kohya_ss in a single location. I have combined both solutions under one repository to align with the new official Kohya repository, where he will maintain his code from now on: https://github.com/kohya-ss/sd-scripts.

A note accompanying the release of his new repository can be found here: https://note.com/kohya_ss/n/nba4eceaa4594

This repository provides a Gradio GUI for Kohya's Stable Diffusion trainers, found here: https://github.com/kohya-ss/sd-scripts. The GUI allows you to set the training parameters and generate and run the required CLI command to train the model.

## Required Dependencies
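The README paragraph above describes the GUI's core mechanism: it assembles a trainer command line from the form fields, prints it, and executes it. As a rough illustration only — the script name, paths, and values below are hypothetical placeholders, not the GUI's exact output:

```python
# Rough sketch of the GUI's approach: build a command string from
# user-supplied fields, print it, then execute it.
# Script name, paths, and values are hypothetical placeholders.
import subprocess

run_cmd = 'accelerate launch "train_db.py"'
run_cmd += ' --pretrained_model_name_or_path="/path/to/model"'
run_cmd += ' --output_name="my_model"'

print(run_cmd)
# subprocess.run(run_cmd, shell=True)  # the real GUI runs the command it built
```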
@@ -101,6 +99,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i

## Change history

+* 2023/01/11 (v20.2.0):
+    - Add support for max token length
* 2023/01/10 (v20.1.1):
    - Fix issue with LoRA config loading
* 2023/01/10 (v20.1):
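A note on this entry: 75 tokens is the default chunk size of Stable Diffusion's CLIP text encoder, and the new option exposes the trainers' `--max_token_length` flag with the values 150 and 225 for longer captions, as the diffs below show.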
dreambooth_gui.py
@@ -73,6 +73,7 @@ def save_configuration(
    clip_skip,
    vae,
    output_name,
+    max_token_length,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -151,6 +152,7 @@ def open_configuration(
    clip_skip,
    vae,
    output_name,
+    max_token_length,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -213,6 +215,7 @@ def train_model(
    clip_skip,
    vae,
    output_name,
+    max_token_length,
):
    if pretrained_model_name_or_path == '':
        msgbox('Source model information is missing')
@@ -367,6 +370,8 @@ def train_model(
        run_cmd += f' --vae="{vae}"'
    if not output_name == '':
        run_cmd += f' --output_name="{output_name}"'
+    if int(max_token_length) > 75:
+        run_cmd += f' --max_token_length={max_token_length}'

    print(run_cmd)
    # Run the command
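The two added lines follow the file's existing pattern: the flag is appended only when the selected value exceeds the 75-token default, so configurations that keep the default produce an unchanged command. A self-contained sketch of just that guard (the base command is a hypothetical placeholder):

```python
# Sketch of the guard added above: dropdown values arrive as strings,
# and the flag is emitted only for values above the 75-token default.
base_cmd = 'accelerate launch "train_db.py"'  # hypothetical placeholder

for max_token_length in ('75', '150', '225'):
    run_cmd = base_cmd
    if int(max_token_length) > 75:
        run_cmd += f' --max_token_length={max_token_length}'
    print(run_cmd)
```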
@@ -694,6 +699,15 @@ def dreambooth_tab(
            )
            vae_button = gr.Button('📂', elem_id='open_folder_small')
            vae_button.click(get_any_file_path, outputs=vae)
+            max_token_length = gr.Dropdown(
+                label='Max Token Length',
+                choices=[
+                    '75',
+                    '150',
+                    '225',
+                ],
+                value='75',
+            )
        with gr.Tab('Tools'):
            gr.Markdown(
                'This section provide Dreambooth tools to help setup your dataset...'
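A design-choice note on the widget above: the dropdown's choices and default are strings, which is why train_model casts with int(max_token_length) before comparing against 75. The same nine-line widget is repeated in the finetuning and LoRA tabs below.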
@@ -745,6 +759,7 @@ def dreambooth_tab(
        clip_skip,
        vae,
        output_name,
+        max_token_length,
    ]

    button_open_config.click(
finetune_gui.py
@@ -66,6 +66,7 @@ def save_configuration(
    mem_eff_attn,
    shuffle_caption,
    output_name,
+    max_token_length,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -148,6 +149,7 @@ def open_config_file(
    mem_eff_attn,
    shuffle_caption,
    output_name,
+    max_token_length,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -215,6 +217,7 @@ def train_model(
    mem_eff_attn,
    shuffle_caption,
    output_name,
+    max_token_length,
):
    # create caption json file
    if generate_caption_database:
@@ -331,6 +334,8 @@ def train_model(
        run_cmd += f' --resume={resume}'
    if not output_name == '':
        run_cmd += f' --output_name="{output_name}"'
+    if int(max_token_length) > 75:
+        run_cmd += f' --max_token_length={max_token_length}'

    print(run_cmd)
    # Run the command
@@ -642,6 +647,15 @@ def finetune_tab():
            gradient_accumulation_steps = gr.Number(
                label='Gradient accumulate steps', value='1'
            )
+            max_token_length = gr.Dropdown(
+                label='Max Token Length',
+                choices=[
+                    '75',
+                    '150',
+                    '225',
+                ],
+                value='75',
+            )
    with gr.Box():
        with gr.Row():
            create_caption = gr.Checkbox(
@@ -695,6 +709,7 @@ def finetune_tab():
        mem_eff_attn,
        shuffle_caption,
        output_name,
+        max_token_length,
    ]

    button_run.click(train_model, inputs=settings_list)
lora_gui.py
@@ -80,6 +80,7 @@ def save_configuration(
    mem_eff_attn,
    output_name,
    model_list,
+    max_token_length,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -163,6 +164,7 @@ def open_configuration(
    mem_eff_attn,
    output_name,
    model_list,
+    max_token_length,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -229,7 +231,8 @@ def train_model(
    gradient_accumulation_steps,
    mem_eff_attn,
    output_name,
-    model_list,
+    model_list, # Keep this. Yes, it is unused here but required given the common list used
+    max_token_length,
):
    if pretrained_model_name_or_path == '':
        msgbox('Source model information is missing')
@@ -405,6 +408,8 @@ def train_model(
    # run_cmd += f' --vae="{vae}"'
    if not output_name == '':
        run_cmd += f' --output_name="{output_name}"'
+    if int(max_token_length) > 75:
+        run_cmd += f' --max_token_length={max_token_length}'

    print(run_cmd)
    # Run the command
@@ -781,6 +786,16 @@ def lora_tab(
            # )
            # vae_button = gr.Button('📂', elem_id='open_folder_small')
            # vae_button.click(get_any_file_path, outputs=vae)
+            max_token_length = gr.Dropdown(
+                label='Max Token Length',
+                choices=[
+                    '75',
+                    '150',
+                    '225',
+                ],
+                value='75',
+            )
+
        with gr.Tab('Tools'):
            gr.Markdown(
                'This section provide Dreambooth tools to help setup your dataset...'
@@ -839,6 +854,7 @@ def lora_tab(
        mem_eff_attn,
        output_name,
        model_list,
+        max_token_length
    ]

    button_open_config.click(