commit
fdd1b02a26
13
README.md
13
README.md
@ -2,6 +2,16 @@
|
|||||||
|
|
||||||
This repository repository is providing a Gradio GUI for kohya's Stable Diffusion trainers found here: https://github.com/kohya-ss/sd-scripts. The GUI allow you to set the training parameters and generate and run the required CLI command to train the model.
|
This repository repository is providing a Gradio GUI for kohya's Stable Diffusion trainers found here: https://github.com/kohya-ss/sd-scripts. The GUI allow you to set the training parameters and generate and run the required CLI command to train the model.
|
||||||
|
|
||||||
|
## Tutorials
|
||||||
|
|
||||||
|
How to create a LoRA part 1, dataset preparation:
|
||||||
|
|
||||||
|
[![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/N4_-fB62Hwk/0.jpg)](https://www.youtube.com/watch?v=N4_-fB62Hwk)
|
||||||
|
|
||||||
|
How to create a LoRA part 2, training the model:
|
||||||
|
|
||||||
|
[![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/k5imq01uvUY/0.jpg)](https://www.youtube.com/watch?v=k5imq01uvUY)
|
||||||
|
|
||||||
## Required Dependencies
|
## Required Dependencies
|
||||||
|
|
||||||
Python 3.10.6+ and Git:
|
Python 3.10.6+ and Git:
|
||||||
@ -131,6 +141,9 @@ Then redo the installation instruction within the kohya_ss venv.
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 2023/01/27 (v20.5.1):
|
||||||
|
- Fix issue: https://github.com/bmaltais/kohya_ss/issues/70
|
||||||
|
- Fix issue https://github.com/bmaltais/kohya_ss/issues/71
|
||||||
* 2023/01/26 (v20.5.0):
|
* 2023/01/26 (v20.5.0):
|
||||||
- Add new `Dreambooth TI` tab for training of Textual Inversion embeddings
|
- Add new `Dreambooth TI` tab for training of Textual Inversion embeddings
|
||||||
- Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
|
- Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
|
||||||
|
@ -82,7 +82,7 @@ def save_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -166,7 +166,7 @@ def open_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -235,6 +235,7 @@ def train_model(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
|
keep_tokens,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -396,6 +397,7 @@ def train_model(
|
|||||||
full_fp16=full_fp16,
|
full_fp16=full_fp16,
|
||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
|
keep_tokens=keep_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -602,6 +604,7 @@ def dreambooth_tab(
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
keep_tokens,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -665,6 +668,7 @@ def dreambooth_tab(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list,
|
||||||
|
keep_tokens,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -78,7 +78,7 @@ def save_configuration(
|
|||||||
color_aug,
|
color_aug,
|
||||||
model_list,
|
model_list,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files,
|
use_latent_files, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -168,7 +168,7 @@ def open_config_file(
|
|||||||
color_aug,
|
color_aug,
|
||||||
model_list,
|
model_list,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files,
|
use_latent_files, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -243,7 +243,7 @@ def train_model(
|
|||||||
color_aug,
|
color_aug,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files,
|
use_latent_files, keep_tokens,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -381,6 +381,7 @@ def train_model(
|
|||||||
full_fp16=full_fp16,
|
full_fp16=full_fp16,
|
||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
|
keep_tokens=keep_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -585,6 +586,7 @@ def finetune_tab():
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
keep_tokens,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -644,6 +646,7 @@ def finetune_tab():
|
|||||||
model_list,
|
model_list,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files,
|
use_latent_files,
|
||||||
|
keep_tokens,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
@ -519,6 +519,9 @@ def gradio_advanced_training():
|
|||||||
shuffle_caption = gr.Checkbox(
|
shuffle_caption = gr.Checkbox(
|
||||||
label='Shuffle caption', value=False
|
label='Shuffle caption', value=False
|
||||||
)
|
)
|
||||||
|
keep_tokens = gr.Slider(
|
||||||
|
label='Keen n tokens', value='0', minimum=0, maximum=32, step=1
|
||||||
|
)
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -572,6 +575,7 @@ def gradio_advanced_training():
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
keep_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_cmd_advanced_training(**kwargs):
|
def run_cmd_advanced_training(**kwargs):
|
||||||
@ -596,6 +600,10 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
if kwargs.get('resume')
|
if kwargs.get('resume')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
|
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
|
||||||
|
if int(kwargs.get('keep_tokens', 0)) > 0
|
||||||
|
else '',
|
||||||
|
|
||||||
' --save_state' if kwargs.get('save_state') else '',
|
' --save_state' if kwargs.get('save_state') else '',
|
||||||
|
|
||||||
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
|
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
|
||||||
|
11
lora_gui.py
11
lora_gui.py
@ -90,7 +90,7 @@ def save_configuration(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment,
|
training_comment, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -159,7 +159,7 @@ def open_configuration(
|
|||||||
stop_text_encoder_training,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
save_model_as_dropdown,
|
save_model_as,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
@ -179,7 +179,7 @@ def open_configuration(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment,
|
training_comment, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -252,7 +252,7 @@ def train_model(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment,
|
training_comment, keep_tokens,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -425,6 +425,7 @@ def train_model(
|
|||||||
full_fp16=full_fp16,
|
full_fp16=full_fp16,
|
||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
|
keep_tokens=keep_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -660,6 +661,7 @@ def lora_tab(
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
keep_tokens,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -733,6 +735,7 @@ def lora_tab(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment,
|
training_comment,
|
||||||
|
keep_tokens,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -82,7 +82,7 @@ def save_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -104,6 +104,10 @@ def save_configuration(
|
|||||||
if file_path == None or file_path == '':
|
if file_path == None or file_path == '':
|
||||||
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||||
|
|
||||||
|
directory = os.path.dirname(file_path)
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
os.makedirs(directory)
|
||||||
|
|
||||||
# Return the values of the variables as a dictionary
|
# Return the values of the variables as a dictionary
|
||||||
variables = {
|
variables = {
|
||||||
name: value
|
name: value
|
||||||
@ -166,7 +170,7 @@ def open_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -235,7 +239,7 @@ def train_model(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -412,6 +416,7 @@ def train_model(
|
|||||||
full_fp16=full_fp16,
|
full_fp16=full_fp16,
|
||||||
xformers=xformers,
|
xformers=xformers,
|
||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
|
keep_tokens=keep_tokens,
|
||||||
)
|
)
|
||||||
run_cmd += f' --token_string={token_string}'
|
run_cmd += f' --token_string={token_string}'
|
||||||
run_cmd += f' --init_word={init_word}'
|
run_cmd += f' --init_word={init_word}'
|
||||||
@ -665,6 +670,7 @@ def ti_tab(
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
keep_tokens,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -729,6 +735,7 @@ def ti_tab(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list,
|
||||||
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
||||||
|
keep_tokens,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
Loading…
Reference in New Issue
Block a user