Implement open and save config for LoRA

This commit is contained in:
bmaltais 2022-12-29 14:00:02 -05:00
parent 0f42ab78c4
commit b44f075f60
3 changed files with 82 additions and 182 deletions

View File

@ -20,6 +20,21 @@ def get_file_path(file_path='', defaultextension='.json'):
return file_path return file_path
def get_any_file_path(file_path=''):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.askopenfilename()
root.destroy()
if file_path == '':
file_path = current_file_path
return file_path
def remove_doublequote(file_path): def remove_doublequote(file_path):
if file_path != None: if file_path != None:

View File

@ -15,6 +15,7 @@ from library.common_gui import (
get_folder_path, get_folder_path,
remove_doublequote, remove_doublequote,
get_file_path, get_file_path,
get_any_file_path,
get_saveasfile_path, get_saveasfile_path,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
@ -64,7 +65,7 @@ def save_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
): ):
original_file_path = file_path original_file_path = file_path
@ -117,6 +118,10 @@ def save_configuration(
'save_state': save_state, 'save_state': save_state,
'resume': resume, 'resume': resume,
'prior_loss_weight': prior_loss_weight, 'prior_loss_weight': prior_loss_weight,
'text_encoder_lr': text_encoder_lr,
'unet_lr': unet_lr,
'network_train': network_train,
'network_dim': network_dim
} }
# Save the data to the selected file # Save the data to the selected file
@ -159,7 +164,7 @@ def open_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
): ):
original_file_path = file_path original_file_path = file_path
@ -213,6 +218,10 @@ def open_configuration(
my_data.get('save_state', save_state), my_data.get('save_state', save_state),
my_data.get('resume', resume), my_data.get('resume', resume),
my_data.get('prior_loss_weight', prior_loss_weight), my_data.get('prior_loss_weight', prior_loss_weight),
my_data.get('text_encoder_lr', text_encoder_lr),
my_data.get('unet_lr', unet_lr),
my_data.get('network_train', network_train),
my_data.get('network_dim', network_dim),
) )
@ -548,10 +557,10 @@ def lora_tab(
label='Pretrained model name or path', label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model', placeholder='enter the path to custom model or name of pretrained model',
) )
pretrained_model_name_or_path_fille = gr.Button( pretrained_model_name_or_path_file = gr.Button(
document_symbol, elem_id='open_folder_small' document_symbol, elem_id='open_folder_small'
) )
pretrained_model_name_or_path_fille.click( pretrained_model_name_or_path_file.click(
get_file_path, get_file_path,
inputs=[pretrained_model_name_or_path_input], inputs=[pretrained_model_name_or_path_input],
outputs=pretrained_model_name_or_path_input, outputs=pretrained_model_name_or_path_input,
@ -586,6 +595,19 @@ def lora_tab(
], ],
value='same as source model', value='same as source model',
) )
with gr.Row():
lora_network_weights = gr.Textbox(
label='LoRA network weights',
placeholder='{Optional) Path to existing LoRA network weights to resume training}',
)
lora_network_weights_file = gr.Button(
document_symbol, elem_id='open_folder_small'
)
lora_network_weights_file.click(
get_any_file_path,
inputs=[lora_network_weights],
outputs=lora_network_weights,
)
with gr.Row(): with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True) v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox( v_parameterization_input = gr.Checkbox(
@ -813,199 +835,62 @@ def lora_tab(
button_run = gr.Button('Train model') button_run = gr.Button('Train model')
settings_list = [
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
cache_latent_input,
caption_extention_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
]
button_open_config.click( button_open_config.click(
open_configuration, open_configuration,
inputs=[ inputs=[config_file_name] + settings_list,
config_file_name, outputs=[config_file_name] + settings_list,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
cache_latent_input,
caption_extention_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
cache_latent_input,
caption_extention_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
) )
button_save_config.click( button_save_config.click(
save_configuration, save_configuration,
inputs=[ inputs=[dummy_db_false, config_file_name] + settings_list,
dummy_db_false,
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
cache_latent_input,
caption_extention_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[config_file_name], outputs=[config_file_name],
) )
button_save_as_config.click( button_save_as_config.click(
save_configuration, save_configuration,
inputs=[ inputs=[dummy_db_true, config_file_name] + settings_list,
dummy_db_true,
config_file_name,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
cache_latent_input,
caption_extention_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[config_file_name], outputs=[config_file_name],
) )
button_run.click( button_run.click(
train_model, train_model,
inputs=[ inputs=settings_list,
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
logging_dir_input,
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
max_resolution_input,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
cache_latent_input,
caption_extention_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
text_encoder_lr, unet_lr, network_train, network_dim
],
) )
return ( return (

View File

@ -1,3 +1,3 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
setup(name = "library", version="1.0.0", packages = find_packages()) setup(name = "library", version="1.0.1", packages = find_packages())