Add support for --keep_tokens option

bmaltais 2023-01-27 07:33:44 -05:00
parent bf371b49bf
commit 202923b3ce
6 changed files with 33 additions and 11 deletions
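For context on what the new option does: in the underlying training scripts, --keep_tokens works together with caption shuffling and keeps the first N comma-separated caption tokens fixed at the front of the caption, so trigger words are never shuffled away. A rough sketch of that behavior (the function name and example caption below are illustrative, not taken from the repository):

import random

def shuffle_caption_keep_tokens(caption: str, keep_tokens: int = 0) -> str:
    # Captions are treated as comma-separated tokens when shuffling.
    tokens = [t.strip() for t in caption.split(',')]
    fixed, rest = tokens[:keep_tokens], tokens[keep_tokens:]
    random.shuffle(rest)            # only the tail is shuffled
    return ', '.join(fixed + rest)  # the first keep_tokens tokens stay in place

print(shuffle_caption_keep_tokens('sks dog, sitting, park, sunny day', keep_tokens=1))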

View File

@ -132,6 +132,7 @@ Then redo the installation instruction within the kohya_ss venv.
## Change history
* 2023/01/27 (v20.5.1):
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/70
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/71
* 2023/01/26 (v20.5.0):
- Add new `Dreambooth TI` tab for training of Textual Inversion embeddings

View File

@ -82,7 +82,7 @@ def save_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list,
model_list, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -166,7 +166,7 @@ def open_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list,
model_list, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
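A note on why keep_tokens is appended to save_configuration, open_configuration, train_model, and each tab's settings_list in matching positions: these functions capture their arguments with list(locals().items()) and rely on declaration order lining up with the order of the Gradio components passed in, which is also why the otherwise unused model_list parameter has to stay. A simplified sketch of the pattern (parameter names and values are illustrative):

def save_configuration(pretrained_model_name_or_path, keep_tokens):
    # locals() here holds only the parameters, in declaration order,
    # so this order must match the order of components in settings_list.
    parameters = list(locals().items())
    return dict(parameters)

settings_list = ['some/base-model', 2]  # same order as the function signature
print(save_configuration(*settings_list))
# {'pretrained_model_name_or_path': 'some/base-model', 'keep_tokens': 2}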
@ -235,6 +235,7 @@ def train_model(
mem_eff_attn,
gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used
keep_tokens,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -396,6 +397,7 @@ def train_model(
full_fp16=full_fp16,
xformers=xformers,
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
)
print(run_cmd)
@ -602,6 +604,7 @@ def dreambooth_tab(
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@ -665,6 +668,7 @@ def dreambooth_tab(
mem_eff_attn,
gradient_accumulation_steps,
model_list,
keep_tokens,
]
button_open_config.click(

View File

@ -78,7 +78,7 @@ def save_configuration(
color_aug,
model_list,
cache_latents,
use_latent_files,
use_latent_files, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -168,7 +168,7 @@ def open_config_file(
color_aug,
model_list,
cache_latents,
use_latent_files,
use_latent_files, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -243,7 +243,7 @@ def train_model(
color_aug,
model_list, # Keep this. Yes, it is unused here but required given the common list used
cache_latents,
use_latent_files,
use_latent_files, keep_tokens,
):
# create caption json file
if generate_caption_database:
@ -381,6 +381,7 @@ def train_model(
full_fp16=full_fp16,
xformers=xformers,
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
)
print(run_cmd)
@ -585,6 +586,7 @@ def finetune_tab():
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@ -644,6 +646,7 @@ def finetune_tab():
model_list,
cache_latents,
use_latent_files,
keep_tokens,
]
button_run.click(train_model, inputs=settings_list)

View File

@ -519,6 +519,9 @@ def gradio_advanced_training():
shuffle_caption = gr.Checkbox(
label='Shuffle caption', value=False
)
keep_tokens = gr.Slider(
label='Keep n tokens', value=0, minimum=0, maximum=32, step=1
)
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Row():
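For readers unfamiliar with the Gradio side, a self-contained sketch of how a slider like the one added above can drive a handler; the real code instead appends keep_tokens to each tab's much longer settings_list, and the names below are illustrative:

import gradio as gr

def preview_flag(keep_tokens):
    # Same guard as run_cmd_advanced_training: emit the flag only when > 0.
    return f'--keep_tokens="{int(keep_tokens)}"' if int(keep_tokens) > 0 else '(flag omitted)'

with gr.Blocks() as demo:
    keep_tokens = gr.Slider(label='Keep n tokens', value=0, minimum=0, maximum=32, step=1)
    flag_box = gr.Textbox(label='Generated flag')
    gr.Button('Preview').click(preview_flag, inputs=[keep_tokens], outputs=flag_box)

demo.launch()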
@ -572,6 +575,7 @@ def gradio_advanced_training():
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
keep_tokens,
)
def run_cmd_advanced_training(**kwargs):
@ -596,6 +600,10 @@ def run_cmd_advanced_training(**kwargs):
if kwargs.get('resume')
else '',
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
if int(kwargs.get('keep_tokens', 0)) > 0
else '',
' --save_state' if kwargs.get('save_state') else '',
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
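A minimal sketch of how conditional fragments like the ones above are combined into the final training command; run_cmd_advanced_training concatenates many such fragments, and the helper and script names here are illustrative:

def build_advanced_flags(**kwargs):
    # Each entry is a flag string, or '' when the option is unset,
    # mirroring the pattern used in run_cmd_advanced_training.
    fragments = [
        f' --keep_tokens="{kwargs.get("keep_tokens", 0)}"'
        if int(kwargs.get('keep_tokens', 0)) > 0
        else '',
        ' --save_state' if kwargs.get('save_state') else '',
        ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
    ]
    return ''.join(fragments)

run_cmd = 'accelerate launch "train_db.py"' + build_advanced_flags(keep_tokens=2, mem_eff_attn=True)
print(run_cmd)  # accelerate launch "train_db.py" --keep_tokens="2" --mem_eff_attn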

View File

@ -90,7 +90,7 @@ def save_configuration(
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment,
training_comment, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -179,7 +179,7 @@ def open_configuration(
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment,
training_comment, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -252,7 +252,7 @@ def train_model(
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment,
training_comment, keep_tokens,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -425,6 +425,7 @@ def train_model(
full_fp16=full_fp16,
xformers=xformers,
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
)
print(run_cmd)
@ -660,6 +661,7 @@ def lora_tab(
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@ -733,6 +735,7 @@ def lora_tab(
max_data_loader_n_workers,
network_alpha,
training_comment,
keep_tokens,
]
button_open_config.click(

View File

@ -82,7 +82,7 @@ def save_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -170,7 +170,7 @@ def open_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -239,7 +239,7 @@ def train_model(
mem_eff_attn,
gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -416,6 +416,7 @@ def train_model(
full_fp16=full_fp16,
xformers=xformers,
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
)
run_cmd += f' --token_string={token_string}'
run_cmd += f' --init_word={init_word}'
@ -669,6 +670,7 @@ def ti_tab(
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@ -733,6 +735,7 @@ def ti_tab(
gradient_accumulation_steps,
model_list,
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
keep_tokens,
]
button_open_config.click(