Add support for --keep_tokens option
parent bf371b49bf
commit 202923b3ce
@@ -132,6 +132,7 @@ Then redo the installation instruction within the kohya_ss venv.
 ## Change history

 * 2023/01/27 (v20.5.1):
     - Fix issue: https://github.com/bmaltais/kohya_ss/issues/70
     - Fix issue https://github.com/bmaltais/kohya_ss/issues/71
 * 2023/01/26 (v20.5.0):
     - Add new `Dreambooth TI` tab for training of Textual Inversion embeddings
@@ -82,7 +82,7 @@ def save_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list,
+    model_list, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
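
These signature-only hunks work because save_configuration and open_configuration immediately snapshot their arguments with locals(): adding keep_tokens to the parameter list is enough for the option to be written to and read back from the config file. A minimal, runnable sketch of that pattern (the function name, defaults, and JSON output here are illustrative, not taken from the repository):

```python
import json

def save_configuration_sketch(model_list='', keep_tokens=0):
    # Snapshot every parameter name/value pair, as the GUI code does;
    # a newly added option such as keep_tokens is picked up automatically.
    parameters = list(locals().items())
    return json.dumps(dict(parameters), indent=2)

print(save_configuration_sketch(model_list='some-model', keep_tokens=2))
```
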
@@ -166,7 +166,7 @@ def open_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list,
+    model_list, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -235,6 +235,7 @@ def train_model(
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list, # Keep this. Yes, it is unused here but required given the common list used
+    keep_tokens,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -396,6 +397,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )

     print(run_cmd)
@@ -602,6 +604,7 @@ def dreambooth_tab(
         max_token_length,
         max_train_epochs,
         max_data_loader_n_workers,
+        keep_tokens,
     ) = gradio_advanced_training()
     color_aug.change(
         color_aug_changed,
@@ -665,6 +668,7 @@ def dreambooth_tab(
         mem_eff_attn,
         gradient_accumulation_steps,
         model_list,
+        keep_tokens,
     ]

     button_open_config.click(
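
The keep_tokens component is also appended to the shared settings_list because Gradio passes `inputs` to the click handler positionally: the list order has to match the parameter order of train_model, save_configuration, and open_configuration above. A small sketch of that wiring, using hypothetical component and handler names rather than the repository's:

```python
import gradio as gr

def preview_cmd(model_list, keep_tokens):
    # Mirror the behaviour of the keep_tokens fragment for this single option.
    keep_tokens = int(keep_tokens)
    return f' --keep_tokens="{keep_tokens}"' if keep_tokens > 0 else ''

with gr.Blocks() as demo:
    model_list = gr.Textbox(label='Model', value='')
    keep_tokens = gr.Slider(
        label='Keep n tokens', value=0, minimum=0, maximum=32, step=1
    )
    output = gr.Textbox(label='Extra arguments')
    # inputs is positional: this list must stay in the same order as
    # preview_cmd's parameters, just like settings_list in the GUI modules.
    gr.Button('Preview').click(
        preview_cmd, inputs=[model_list, keep_tokens], outputs=output
    )

# demo.launch()  # uncomment to try the sketch locally
```
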
@@ -78,7 +78,7 @@ def save_configuration(
     color_aug,
     model_list,
     cache_latents,
-    use_latent_files,
+    use_latent_files, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -168,7 +168,7 @@ def open_config_file(
     color_aug,
     model_list,
     cache_latents,
-    use_latent_files,
+    use_latent_files, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -243,7 +243,7 @@ def train_model(
     color_aug,
     model_list, # Keep this. Yes, it is unused here but required given the common list used
     cache_latents,
-    use_latent_files,
+    use_latent_files, keep_tokens,
 ):
     # create caption json file
     if generate_caption_database:
@@ -381,6 +381,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )

     print(run_cmd)
@@ -585,6 +586,7 @@ def finetune_tab():
         max_token_length,
         max_train_epochs,
         max_data_loader_n_workers,
+        keep_tokens,
     ) = gradio_advanced_training()
     color_aug.change(
         color_aug_changed,
@@ -644,6 +646,7 @@ def finetune_tab():
         model_list,
         cache_latents,
         use_latent_files,
+        keep_tokens,
     ]

     button_run.click(train_model, inputs=settings_list)
@@ -519,6 +519,9 @@ def gradio_advanced_training():
         shuffle_caption = gr.Checkbox(
             label='Shuffle caption', value=False
         )
+        keep_tokens = gr.Slider(
+            label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
+        )
         use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
         xformers = gr.Checkbox(label='Use xformers', value=True)
     with gr.Row():
@@ -572,6 +575,7 @@ def gradio_advanced_training():
         max_token_length,
         max_train_epochs,
         max_data_loader_n_workers,
+        keep_tokens,
     )

 def run_cmd_advanced_training(**kwargs):
@@ -596,6 +600,10 @@ def run_cmd_advanced_training(**kwargs):
         if kwargs.get('resume')
         else '',
+        f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
+        if int(kwargs.get('keep_tokens', 0)) > 0
+        else '',
         ' --save_state' if kwargs.get('save_state') else '',
         ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
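
Taken on its own, the added fragment emits the flag only when the slider value is above zero, and is otherwise an empty string; the train_model hunks above suggest the result is concatenated into the run_cmd string. A standalone sketch of just that conditional, wrapped in a hypothetical helper name but using the same expression as the added lines:

```python
def keep_tokens_fragment(**kwargs):
    # Emit the flag only when keep_tokens is greater than zero,
    # matching the conditional added in the hunk above.
    return (
        f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
        if int(kwargs.get('keep_tokens', 0)) > 0
        else ''
    )

print(repr(keep_tokens_fragment(keep_tokens=2)))  # ' --keep_tokens="2"'
print(repr(keep_tokens_fragment()))               # ''
```
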
@@ -90,7 +90,7 @@ def save_configuration(
     max_train_epochs,
     max_data_loader_n_workers,
     network_alpha,
-    training_comment,
+    training_comment, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -179,7 +179,7 @@ def open_configuration(
     max_train_epochs,
     max_data_loader_n_workers,
     network_alpha,
-    training_comment,
+    training_comment, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -252,7 +252,7 @@ def train_model(
     max_train_epochs,
     max_data_loader_n_workers,
     network_alpha,
-    training_comment,
+    training_comment, keep_tokens,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -425,6 +425,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )

     print(run_cmd)
@@ -660,6 +661,7 @@ def lora_tab(
         max_token_length,
         max_train_epochs,
         max_data_loader_n_workers,
+        keep_tokens,
     ) = gradio_advanced_training()
     color_aug.change(
         color_aug_changed,
@@ -733,6 +735,7 @@ def lora_tab(
         max_data_loader_n_workers,
         network_alpha,
         training_comment,
+        keep_tokens,
     ]

     button_open_config.click(
@@ -82,7 +82,7 @@ def save_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -170,7 +170,7 @@ def open_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -239,7 +239,7 @@ def train_model(
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list, # Keep this. Yes, it is unused here but required given the common list used
-    token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+    token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -416,6 +416,7 @@ def train_model(
         full_fp16=full_fp16,
         xformers=xformers,
         use_8bit_adam=use_8bit_adam,
+        keep_tokens=keep_tokens,
     )
     run_cmd += f' --token_string={token_string}'
     run_cmd += f' --init_word={init_word}'
@@ -669,6 +670,7 @@ def ti_tab(
         max_token_length,
         max_train_epochs,
         max_data_loader_n_workers,
+        keep_tokens,
     ) = gradio_advanced_training()
     color_aug.change(
         color_aug_changed,
@@ -733,6 +735,7 @@ def ti_tab(
         gradient_accumulation_steps,
         model_list,
         token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+        keep_tokens,
     ]

     button_open_config.click(