Merge pull request #78 from bmaltais/dev

Dev
This commit is contained in:
bmaltais 2023-01-27 19:46:55 -05:00 committed by GitHub
commit fdd1b02a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 12 deletions

View File

@ -2,6 +2,16 @@
This repository provides a Gradio GUI for kohya's Stable Diffusion trainers found here: https://github.com/kohya-ss/sd-scripts. The GUI allows you to set the training parameters and generate and run the required CLI command to train the model. This repository provides a Gradio GUI for kohya's Stable Diffusion trainers found here: https://github.com/kohya-ss/sd-scripts. The GUI allows you to set the training parameters and generate and run the required CLI command to train the model.
## Tutorials
How to create a LoRA part 1, dataset preparation:
[![How to create a LoRA part 1, dataset preparation](https://img.youtube.com/vi/N4_-fB62Hwk/0.jpg)](https://www.youtube.com/watch?v=N4_-fB62Hwk)
How to create a LoRA part 2, training the model:
[![How to create a LoRA part 2, training the model](https://img.youtube.com/vi/k5imq01uvUY/0.jpg)](https://www.youtube.com/watch?v=k5imq01uvUY)
## Required Dependencies ## Required Dependencies
Python 3.10.6+ and Git: Python 3.10.6+ and Git:
@ -131,6 +141,9 @@ Then redo the installation instruction within the kohya_ss venv.
## Change history ## Change history
* 2023/01/27 (v20.5.1):
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/70
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/71
* 2023/01/26 (v20.5.0): * 2023/01/26 (v20.5.0):
- Add new `Dreambooth TI` tab for training of Textual Inversion embeddings - Add new `Dreambooth TI` tab for training of Textual Inversion embeddings
- Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.) - Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)

View File

@ -82,7 +82,7 @@ def save_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -166,7 +166,7 @@ def open_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -235,6 +235,7 @@ def train_model(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
keep_tokens,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -396,6 +397,7 @@ def train_model(
full_fp16=full_fp16, full_fp16=full_fp16,
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
) )
print(run_cmd) print(run_cmd)
@ -602,6 +604,7 @@ def dreambooth_tab(
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -665,6 +668,7 @@ def dreambooth_tab(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list,
keep_tokens,
] ]
button_open_config.click( button_open_config.click(

View File

@ -78,7 +78,7 @@ def save_configuration(
color_aug, color_aug,
model_list, model_list,
cache_latents, cache_latents,
use_latent_files, use_latent_files, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -168,7 +168,7 @@ def open_config_file(
color_aug, color_aug,
model_list, model_list,
cache_latents, cache_latents,
use_latent_files, use_latent_files, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -243,7 +243,7 @@ def train_model(
color_aug, color_aug,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
cache_latents, cache_latents,
use_latent_files, use_latent_files, keep_tokens,
): ):
# create caption json file # create caption json file
if generate_caption_database: if generate_caption_database:
@ -381,6 +381,7 @@ def train_model(
full_fp16=full_fp16, full_fp16=full_fp16,
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
) )
print(run_cmd) print(run_cmd)
@ -585,6 +586,7 @@ def finetune_tab():
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -644,6 +646,7 @@ def finetune_tab():
model_list, model_list,
cache_latents, cache_latents,
use_latent_files, use_latent_files,
keep_tokens,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

View File

@ -519,6 +519,9 @@ def gradio_advanced_training():
shuffle_caption = gr.Checkbox( shuffle_caption = gr.Checkbox(
label='Shuffle caption', value=False label='Shuffle caption', value=False
) )
keep_tokens = gr.Slider(
label='Keen n tokens', value='0', minimum=0, maximum=32, step=1
)
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True) xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Row(): with gr.Row():
@ -572,6 +575,7 @@ def gradio_advanced_training():
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens,
) )
def run_cmd_advanced_training(**kwargs): def run_cmd_advanced_training(**kwargs):
@ -596,6 +600,10 @@ def run_cmd_advanced_training(**kwargs):
if kwargs.get('resume') if kwargs.get('resume')
else '', else '',
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
if int(kwargs.get('keep_tokens', 0)) > 0
else '',
' --save_state' if kwargs.get('save_state') else '', ' --save_state' if kwargs.get('save_state') else '',
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '', ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',

View File

@ -90,7 +90,7 @@ def save_configuration(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, training_comment, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -159,7 +159,7 @@ def open_configuration(
stop_text_encoder_training, stop_text_encoder_training,
use_8bit_adam, use_8bit_adam,
xformers, xformers,
save_model_as_dropdown, save_model_as,
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
@ -179,7 +179,7 @@ def open_configuration(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, training_comment, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -252,7 +252,7 @@ def train_model(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, training_comment, keep_tokens,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -425,6 +425,7 @@ def train_model(
full_fp16=full_fp16, full_fp16=full_fp16,
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
) )
print(run_cmd) print(run_cmd)
@ -660,6 +661,7 @@ def lora_tab(
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -733,6 +735,7 @@ def lora_tab(
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, training_comment,
keep_tokens,
] ]
button_open_config.click( button_open_config.click(

View File

@ -82,7 +82,7 @@ def save_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -104,6 +104,10 @@ def save_configuration(
if file_path == None or file_path == '': if file_path == None or file_path == '':
return original_file_path # In case a file_path was provided and the user decide to cancel the open action return original_file_path # In case a file_path was provided and the user decide to cancel the open action
directory = os.path.dirname(file_path)
if not os.path.exists(directory):
os.makedirs(directory)
# Return the values of the variables as a dictionary # Return the values of the variables as a dictionary
variables = { variables = {
name: value name: value
@ -166,7 +170,7 @@ def open_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -235,7 +239,7 @@ def train_model(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -412,6 +416,7 @@ def train_model(
full_fp16=full_fp16, full_fp16=full_fp16,
xformers=xformers, xformers=xformers,
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
) )
run_cmd += f' --token_string={token_string}' run_cmd += f' --token_string={token_string}'
run_cmd += f' --init_word={init_word}' run_cmd += f' --init_word={init_word}'
@ -665,6 +670,7 @@ def ti_tab(
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -729,6 +735,7 @@ def ti_tab(
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list,
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
keep_tokens,
] ]
button_open_config.click( button_open_config.click(