Implement open and save config for LoRA

commit b44f075f60
parent 0f42ab78c4
library/common_gui.py

@@ -20,6 +20,21 @@ def get_file_path(file_path='', defaultextension='.json'):
     return file_path
 
 
+def get_any_file_path(file_path=''):
+    current_file_path = file_path
+    # print(f'current file path: {current_file_path}')
+
+    root = Tk()
+    root.wm_attributes('-topmost', 1)
+    root.withdraw()
+    file_path = filedialog.askopenfilename()
+    root.destroy()
+
+    if file_path == '':
+        file_path = current_file_path
+
+    return file_path
+
+
 def remove_doublequote(file_path):
     if file_path != None:
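Note: get_any_file_path falls back to the incoming value when the dialog is
cancelled (askopenfilename returns ''), so a wired-up button can never blank
out its textbox. A minimal usage sketch, assuming a desktop session where
Tkinter can open a dialog (the path below is illustrative):

    from library.common_gui import get_any_file_path

    previous = '/models/lora/last_run.safetensors'  # hypothetical path
    selected = get_any_file_path(previous)
    # `selected` is the newly chosen file, or `previous` if the dialog was cancelled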

lora_gui.py
@@ -15,6 +15,7 @@ from library.common_gui import (
     get_folder_path,
     remove_doublequote,
     get_file_path,
+    get_any_file_path,
     get_saveasfile_path,
 )
 from library.dreambooth_folder_creation_gui import (
@@ -64,7 +65,7 @@ def save_configuration(
     shuffle_caption,
     save_state,
     resume,
-    prior_loss_weight,
+    prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
 ):
     original_file_path = file_path
 
@@ -117,6 +118,10 @@ def save_configuration(
         'save_state': save_state,
         'resume': resume,
         'prior_loss_weight': prior_loss_weight,
+        'text_encoder_lr': text_encoder_lr,
+        'unet_lr': unet_lr,
+        'network_train': network_train,
+        'network_dim': network_dim
     }
 
     # Save the data to the selected file
@@ -159,7 +164,7 @@ def open_configuration(
     shuffle_caption,
     save_state,
     resume,
-    prior_loss_weight,
+    prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
 ):
 
     original_file_path = file_path
@@ -213,6 +218,10 @@ def open_configuration(
         my_data.get('save_state', save_state),
         my_data.get('resume', resume),
         my_data.get('prior_loss_weight', prior_loss_weight),
+        my_data.get('text_encoder_lr', text_encoder_lr),
+        my_data.get('unet_lr', unet_lr),
+        my_data.get('network_train', network_train),
+        my_data.get('network_dim', network_dim),
     )
 
 
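Note: the config round-trips through plain JSON, and open_configuration reads
each key with my_data.get(key, on_screen_value), so older config files that
predate the new LoRA keys simply keep whatever is currently in the UI. A
minimal sketch of that pattern (values are illustrative):

    import json

    # save_configuration side: dump the current UI values
    variables = {'text_encoder_lr': 5e-5, 'unet_lr': 1e-4, 'network_dim': 8}
    with open('config.json', 'w') as f:
        json.dump(variables, f)

    # open_configuration side: a missing key falls back to the on-screen value
    with open('config.json') as f:
        my_data = json.load(f)
    unet_lr = my_data.get('unet_lr', 1e-4)  # 1e-4 = current UI value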
@@ -548,10 +557,10 @@ def lora_tab(
                 label='Pretrained model name or path',
                 placeholder='enter the path to custom model or name of pretrained model',
             )
-            pretrained_model_name_or_path_fille = gr.Button(
+            pretrained_model_name_or_path_file = gr.Button(
                 document_symbol, elem_id='open_folder_small'
             )
-            pretrained_model_name_or_path_fille.click(
+            pretrained_model_name_or_path_file.click(
                 get_file_path,
                 inputs=[pretrained_model_name_or_path_input],
                 outputs=pretrained_model_name_or_path_input,
@@ -586,6 +595,19 @@ def lora_tab(
             ],
             value='same as source model',
         )
+        with gr.Row():
+            lora_network_weights = gr.Textbox(
+                label='LoRA network weights',
+                placeholder='(Optional) Path to existing LoRA network weights to resume training',
+            )
+            lora_network_weights_file = gr.Button(
+                document_symbol, elem_id='open_folder_small'
+            )
+            lora_network_weights_file.click(
+                get_any_file_path,
+                inputs=[lora_network_weights],
+                outputs=lora_network_weights,
+            )
         with gr.Row():
             v2_input = gr.Checkbox(label='v2', value=True)
             v_parameterization_input = gr.Checkbox(
@@ -812,200 +834,63 @@ def lora_tab(
         gradio_dataset_balancing_tab()
 
     button_run = gr.Button('Train model')
 
+    settings_list = [
+        pretrained_model_name_or_path_input,
+        v2_input,
+        v_parameterization_input,
+        logging_dir_input,
+        train_data_dir_input,
+        reg_data_dir_input,
+        output_dir_input,
+        max_resolution_input,
+        learning_rate_input,
+        lr_scheduler_input,
+        lr_warmup_input,
+        train_batch_size_input,
+        epoch_input,
+        save_every_n_epochs_input,
+        mixed_precision_input,
+        save_precision_input,
+        seed_input,
+        num_cpu_threads_per_process_input,
+        cache_latent_input,
+        caption_extention_input,
+        enable_bucket_input,
+        gradient_checkpointing_input,
+        full_fp16_input,
+        no_token_padding_input,
+        stop_text_encoder_training_input,
+        use_8bit_adam_input,
+        xformers_input,
+        save_model_as_dropdown,
+        shuffle_caption,
+        save_state,
+        resume,
+        prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
+    ]
+
     button_open_config.click(
         open_configuration,
-        inputs=[
-            config_file_name,
-            pretrained_model_name_or_path_input,
-            v2_input,
-            v_parameterization_input,
-            logging_dir_input,
-            train_data_dir_input,
-            reg_data_dir_input,
-            output_dir_input,
-            max_resolution_input,
-            learning_rate_input,
-            lr_scheduler_input,
-            lr_warmup_input,
-            train_batch_size_input,
-            epoch_input,
-            save_every_n_epochs_input,
-            mixed_precision_input,
-            save_precision_input,
-            seed_input,
-            num_cpu_threads_per_process_input,
-            cache_latent_input,
-            caption_extention_input,
-            enable_bucket_input,
-            gradient_checkpointing_input,
-            full_fp16_input,
-            no_token_padding_input,
-            stop_text_encoder_training_input,
-            use_8bit_adam_input,
-            xformers_input,
-            save_model_as_dropdown,
-            shuffle_caption,
-            save_state,
-            resume,
-            prior_loss_weight,
-        ],
-        outputs=[
-            config_file_name,
-            pretrained_model_name_or_path_input,
-            v2_input,
-            v_parameterization_input,
-            logging_dir_input,
-            train_data_dir_input,
-            reg_data_dir_input,
-            output_dir_input,
-            max_resolution_input,
-            learning_rate_input,
-            lr_scheduler_input,
-            lr_warmup_input,
-            train_batch_size_input,
-            epoch_input,
-            save_every_n_epochs_input,
-            mixed_precision_input,
-            save_precision_input,
-            seed_input,
-            num_cpu_threads_per_process_input,
-            cache_latent_input,
-            caption_extention_input,
-            enable_bucket_input,
-            gradient_checkpointing_input,
-            full_fp16_input,
-            no_token_padding_input,
-            stop_text_encoder_training_input,
-            use_8bit_adam_input,
-            xformers_input,
-            save_model_as_dropdown,
-            shuffle_caption,
-            save_state,
-            resume,
-            prior_loss_weight,
-        ],
+        inputs=[config_file_name] + settings_list,
+        outputs=[config_file_name] + settings_list,
     )
 
     button_save_config.click(
         save_configuration,
-        inputs=[
-            dummy_db_false,
-            config_file_name,
-            pretrained_model_name_or_path_input,
-            v2_input,
-            v_parameterization_input,
-            logging_dir_input,
-            train_data_dir_input,
-            reg_data_dir_input,
-            output_dir_input,
-            max_resolution_input,
-            learning_rate_input,
-            lr_scheduler_input,
-            lr_warmup_input,
-            train_batch_size_input,
-            epoch_input,
-            save_every_n_epochs_input,
-            mixed_precision_input,
-            save_precision_input,
-            seed_input,
-            num_cpu_threads_per_process_input,
-            cache_latent_input,
-            caption_extention_input,
-            enable_bucket_input,
-            gradient_checkpointing_input,
-            full_fp16_input,
-            no_token_padding_input,
-            stop_text_encoder_training_input,
-            use_8bit_adam_input,
-            xformers_input,
-            save_model_as_dropdown,
-            shuffle_caption,
-            save_state,
-            resume,
-            prior_loss_weight,
-        ],
+        inputs=[dummy_db_false, config_file_name] + settings_list,
         outputs=[config_file_name],
     )
 
     button_save_as_config.click(
         save_configuration,
-        inputs=[
-            dummy_db_true,
-            config_file_name,
-            pretrained_model_name_or_path_input,
-            v2_input,
-            v_parameterization_input,
-            logging_dir_input,
-            train_data_dir_input,
-            reg_data_dir_input,
-            output_dir_input,
-            max_resolution_input,
-            learning_rate_input,
-            lr_scheduler_input,
-            lr_warmup_input,
-            train_batch_size_input,
-            epoch_input,
-            save_every_n_epochs_input,
-            mixed_precision_input,
-            save_precision_input,
-            seed_input,
-            num_cpu_threads_per_process_input,
-            cache_latent_input,
-            caption_extention_input,
-            enable_bucket_input,
-            gradient_checkpointing_input,
-            full_fp16_input,
-            no_token_padding_input,
-            stop_text_encoder_training_input,
-            use_8bit_adam_input,
-            xformers_input,
-            save_model_as_dropdown,
-            shuffle_caption,
-            save_state,
-            resume,
-            prior_loss_weight,
-        ],
+        inputs=[dummy_db_true, config_file_name] + settings_list,
         outputs=[config_file_name],
     )
 
     button_run.click(
         train_model,
-        inputs=[
-            pretrained_model_name_or_path_input,
-            v2_input,
-            v_parameterization_input,
-            logging_dir_input,
-            train_data_dir_input,
-            reg_data_dir_input,
-            output_dir_input,
-            max_resolution_input,
-            learning_rate_input,
-            lr_scheduler_input,
-            lr_warmup_input,
-            train_batch_size_input,
-            epoch_input,
-            save_every_n_epochs_input,
-            mixed_precision_input,
-            save_precision_input,
-            seed_input,
-            num_cpu_threads_per_process_input,
-            cache_latent_input,
-            caption_extention_input,
-            enable_bucket_input,
-            gradient_checkpointing_input,
-            full_fp16_input,
-            no_token_padding_input,
-            stop_text_encoder_training_input,
-            use_8bit_adam_input,
-            xformers_input,
-            save_model_as_dropdown,
-            shuffle_caption,
-            save_state,
-            resume,
-            prior_loss_weight,
-            text_encoder_lr, unet_lr, network_train, network_dim
-        ],
+        inputs=settings_list,
     )
 
     return (
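Note: the refactor leans on gr.Button.click accepting ordinary Python lists of
components, so the shared settings_list can be concatenated with whatever extra
leading components each event needs, and new fields are added in one place. A
minimal sketch of the same pattern, assuming Gradio Blocks (component names
here are illustrative, not the GUI's):

    import gradio as gr

    def save(path, lr, dim):
        return f'saved lr={lr} dim={dim} to {path}'

    with gr.Blocks() as demo:
        path = gr.Textbox(label='Config file')
        lr = gr.Number(label='Learning rate', value=1e-4)
        dim = gr.Number(label='Network dim', value=8)
        status = gr.Textbox(label='Status')

        settings_list = [lr, dim]  # one list shared by every event
        gr.Button('Save').click(
            save, inputs=[path] + settings_list, outputs=[status]
        )

    # demo.launch()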