Implement open and save config for LoRA
This commit is contained in:
parent
0f42ab78c4
commit
b44f075f60
@ -20,6 +20,21 @@ def get_file_path(file_path='', defaultextension='.json'):
|
|||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
def get_any_file_path(file_path=''):
|
||||||
|
current_file_path = file_path
|
||||||
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
|
root = Tk()
|
||||||
|
root.wm_attributes('-topmost', 1)
|
||||||
|
root.withdraw()
|
||||||
|
file_path = filedialog.askopenfilename()
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
|
if file_path == '':
|
||||||
|
file_path = current_file_path
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
def remove_doublequote(file_path):
|
def remove_doublequote(file_path):
|
||||||
if file_path != None:
|
if file_path != None:
|
||||||
|
247
lora_gui.py
247
lora_gui.py
@ -15,6 +15,7 @@ from library.common_gui import (
|
|||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
|
get_any_file_path,
|
||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
@ -64,7 +65,7 @@ def save_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -117,6 +118,10 @@ def save_configuration(
|
|||||||
'save_state': save_state,
|
'save_state': save_state,
|
||||||
'resume': resume,
|
'resume': resume,
|
||||||
'prior_loss_weight': prior_loss_weight,
|
'prior_loss_weight': prior_loss_weight,
|
||||||
|
'text_encoder_lr': text_encoder_lr,
|
||||||
|
'unet_lr': unet_lr,
|
||||||
|
'network_train': network_train,
|
||||||
|
'network_dim': network_dim
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the data to the selected file
|
# Save the data to the selected file
|
||||||
@ -159,7 +164,7 @@ def open_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
|
||||||
):
|
):
|
||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
@ -213,6 +218,10 @@ def open_configuration(
|
|||||||
my_data.get('save_state', save_state),
|
my_data.get('save_state', save_state),
|
||||||
my_data.get('resume', resume),
|
my_data.get('resume', resume),
|
||||||
my_data.get('prior_loss_weight', prior_loss_weight),
|
my_data.get('prior_loss_weight', prior_loss_weight),
|
||||||
|
my_data.get('text_encoder_lr', text_encoder_lr),
|
||||||
|
my_data.get('unet_lr', unet_lr),
|
||||||
|
my_data.get('network_train', network_train),
|
||||||
|
my_data.get('network_dim', network_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -548,10 +557,10 @@ def lora_tab(
|
|||||||
label='Pretrained model name or path',
|
label='Pretrained model name or path',
|
||||||
placeholder='enter the path to custom model or name of pretrained model',
|
placeholder='enter the path to custom model or name of pretrained model',
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_fille = gr.Button(
|
pretrained_model_name_or_path_file = gr.Button(
|
||||||
document_symbol, elem_id='open_folder_small'
|
document_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_fille.click(
|
pretrained_model_name_or_path_file.click(
|
||||||
get_file_path,
|
get_file_path,
|
||||||
inputs=[pretrained_model_name_or_path_input],
|
inputs=[pretrained_model_name_or_path_input],
|
||||||
outputs=pretrained_model_name_or_path_input,
|
outputs=pretrained_model_name_or_path_input,
|
||||||
@ -586,6 +595,19 @@ def lora_tab(
|
|||||||
],
|
],
|
||||||
value='same as source model',
|
value='same as source model',
|
||||||
)
|
)
|
||||||
|
with gr.Row():
|
||||||
|
lora_network_weights = gr.Textbox(
|
||||||
|
label='LoRA network weights',
|
||||||
|
placeholder='{Optional) Path to existing LoRA network weights to resume training}',
|
||||||
|
)
|
||||||
|
lora_network_weights_file = gr.Button(
|
||||||
|
document_symbol, elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
lora_network_weights_file.click(
|
||||||
|
get_any_file_path,
|
||||||
|
inputs=[lora_network_weights],
|
||||||
|
outputs=lora_network_weights,
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v2_input = gr.Checkbox(label='v2', value=True)
|
v2_input = gr.Checkbox(label='v2', value=True)
|
||||||
v_parameterization_input = gr.Checkbox(
|
v_parameterization_input = gr.Checkbox(
|
||||||
@ -812,200 +834,63 @@ def lora_tab(
|
|||||||
gradio_dataset_balancing_tab()
|
gradio_dataset_balancing_tab()
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
|
settings_list = [
|
||||||
|
pretrained_model_name_or_path_input,
|
||||||
|
v2_input,
|
||||||
|
v_parameterization_input,
|
||||||
|
logging_dir_input,
|
||||||
|
train_data_dir_input,
|
||||||
|
reg_data_dir_input,
|
||||||
|
output_dir_input,
|
||||||
|
max_resolution_input,
|
||||||
|
learning_rate_input,
|
||||||
|
lr_scheduler_input,
|
||||||
|
lr_warmup_input,
|
||||||
|
train_batch_size_input,
|
||||||
|
epoch_input,
|
||||||
|
save_every_n_epochs_input,
|
||||||
|
mixed_precision_input,
|
||||||
|
save_precision_input,
|
||||||
|
seed_input,
|
||||||
|
num_cpu_threads_per_process_input,
|
||||||
|
cache_latent_input,
|
||||||
|
caption_extention_input,
|
||||||
|
enable_bucket_input,
|
||||||
|
gradient_checkpointing_input,
|
||||||
|
full_fp16_input,
|
||||||
|
no_token_padding_input,
|
||||||
|
stop_text_encoder_training_input,
|
||||||
|
use_8bit_adam_input,
|
||||||
|
xformers_input,
|
||||||
|
save_model_as_dropdown,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
|
||||||
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_configuration,
|
open_configuration,
|
||||||
inputs=[
|
inputs=[config_file_name] + settings_list,
|
||||||
config_file_name,
|
outputs=[config_file_name] + settings_list,
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
logging_dir_input,
|
|
||||||
train_data_dir_input,
|
|
||||||
reg_data_dir_input,
|
|
||||||
output_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
cache_latent_input,
|
|
||||||
caption_extention_input,
|
|
||||||
enable_bucket_input,
|
|
||||||
gradient_checkpointing_input,
|
|
||||||
full_fp16_input,
|
|
||||||
no_token_padding_input,
|
|
||||||
stop_text_encoder_training_input,
|
|
||||||
use_8bit_adam_input,
|
|
||||||
xformers_input,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
shuffle_caption,
|
|
||||||
save_state,
|
|
||||||
resume,
|
|
||||||
prior_loss_weight,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
config_file_name,
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
logging_dir_input,
|
|
||||||
train_data_dir_input,
|
|
||||||
reg_data_dir_input,
|
|
||||||
output_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
cache_latent_input,
|
|
||||||
caption_extention_input,
|
|
||||||
enable_bucket_input,
|
|
||||||
gradient_checkpointing_input,
|
|
||||||
full_fp16_input,
|
|
||||||
no_token_padding_input,
|
|
||||||
stop_text_encoder_training_input,
|
|
||||||
use_8bit_adam_input,
|
|
||||||
xformers_input,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
shuffle_caption,
|
|
||||||
save_state,
|
|
||||||
resume,
|
|
||||||
prior_loss_weight,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
button_save_config.click(
|
button_save_config.click(
|
||||||
save_configuration,
|
save_configuration,
|
||||||
inputs=[
|
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||||
dummy_db_false,
|
|
||||||
config_file_name,
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
logging_dir_input,
|
|
||||||
train_data_dir_input,
|
|
||||||
reg_data_dir_input,
|
|
||||||
output_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
cache_latent_input,
|
|
||||||
caption_extention_input,
|
|
||||||
enable_bucket_input,
|
|
||||||
gradient_checkpointing_input,
|
|
||||||
full_fp16_input,
|
|
||||||
no_token_padding_input,
|
|
||||||
stop_text_encoder_training_input,
|
|
||||||
use_8bit_adam_input,
|
|
||||||
xformers_input,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
shuffle_caption,
|
|
||||||
save_state,
|
|
||||||
resume,
|
|
||||||
prior_loss_weight,
|
|
||||||
],
|
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
|
|
||||||
button_save_as_config.click(
|
button_save_as_config.click(
|
||||||
save_configuration,
|
save_configuration,
|
||||||
inputs=[
|
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||||
dummy_db_true,
|
|
||||||
config_file_name,
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
logging_dir_input,
|
|
||||||
train_data_dir_input,
|
|
||||||
reg_data_dir_input,
|
|
||||||
output_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
cache_latent_input,
|
|
||||||
caption_extention_input,
|
|
||||||
enable_bucket_input,
|
|
||||||
gradient_checkpointing_input,
|
|
||||||
full_fp16_input,
|
|
||||||
no_token_padding_input,
|
|
||||||
stop_text_encoder_training_input,
|
|
||||||
use_8bit_adam_input,
|
|
||||||
xformers_input,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
shuffle_caption,
|
|
||||||
save_state,
|
|
||||||
resume,
|
|
||||||
prior_loss_weight,
|
|
||||||
],
|
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
|
|
||||||
button_run.click(
|
button_run.click(
|
||||||
train_model,
|
train_model,
|
||||||
inputs=[
|
inputs=settings_list,
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
logging_dir_input,
|
|
||||||
train_data_dir_input,
|
|
||||||
reg_data_dir_input,
|
|
||||||
output_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
cache_latent_input,
|
|
||||||
caption_extention_input,
|
|
||||||
enable_bucket_input,
|
|
||||||
gradient_checkpointing_input,
|
|
||||||
full_fp16_input,
|
|
||||||
no_token_padding_input,
|
|
||||||
stop_text_encoder_training_input,
|
|
||||||
use_8bit_adam_input,
|
|
||||||
xformers_input,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
shuffle_caption,
|
|
||||||
save_state,
|
|
||||||
resume,
|
|
||||||
prior_loss_weight,
|
|
||||||
text_encoder_lr, unet_lr, network_train, network_dim
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
Loading…
Reference in New Issue
Block a user