WIP File Dialog Behavior

This commit is contained in:
JSTayco 2023-04-01 17:33:41 -07:00
parent eef5becab8
commit b6d3f10da7
3 changed files with 217 additions and 227 deletions

View File

@ -50,68 +50,68 @@ document_symbol = '\U0001F4C4' # 📄
def save_configuration( def save_configuration(
save_as, save_as,
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
v_parameterization, v_parameterization,
logging_dir, logging_dir,
train_data_dir, train_data_dir,
reg_data_dir, reg_data_dir,
output_dir, output_dir,
max_resolution, max_resolution,
learning_rate, learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
train_batch_size, train_batch_size,
epoch, epoch,
save_every_n_epochs, save_every_n_epochs,
mixed_precision, mixed_precision,
save_precision, save_precision,
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
cache_latents, cache_latents,
caption_extension, caption_extension,
enable_bucket, enable_bucket,
gradient_checkpointing, gradient_checkpointing,
full_fp16, full_fp16,
no_token_padding, no_token_padding,
stop_text_encoder_training, stop_text_encoder_training,
# use_8bit_adam, # use_8bit_adam,
xformers, xformers,
save_model_as, save_model_as,
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight,
color_aug, color_aug,
flip_aug, flip_aug,
clip_skip, clip_skip,
vae, vae,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_dropout_rate, caption_dropout_rate,
optimizer, optimizer,
optimizer_args, optimizer_args,
noise_offset, noise_offset,
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -125,12 +125,12 @@ def save_configuration(
file_path = get_saveasfile_path(file_path) file_path = get_saveasfile_path(file_path)
else: else:
print('Save...') print('Save...')
if file_path == None or file_path == '': if file_path is None or file_path == '':
file_path = get_saveasfile_path(file_path) file_path = get_saveasfile_path(file_path)
# print(file_path) # print(file_path)
if file_path == None or file_path == '': if file_path is None or file_path == '':
return original_file_path # In case a file_path was provided and the user decide to cancel the open action return original_file_path # In case a file_path was provided and the user decide to cancel the open action
# Return the values of the variables as a dictionary # Return the values of the variables as a dictionary
@ -159,69 +159,73 @@ def save_configuration(
def open_configuration( def open_configuration(
ask_for_file, ask_for_file,
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
v_parameterization, v_parameterization,
logging_dir, logging_dir,
train_data_dir, train_data_dir,
reg_data_dir, reg_data_dir,
output_dir, output_dir,
max_resolution, max_resolution,
learning_rate, learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
train_batch_size, train_batch_size,
epoch, epoch,
save_every_n_epochs, save_every_n_epochs,
mixed_precision, mixed_precision,
save_precision, save_precision,
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
cache_latents, cache_latents,
caption_extension, caption_extension,
enable_bucket, enable_bucket,
gradient_checkpointing, gradient_checkpointing,
full_fp16, full_fp16,
no_token_padding, no_token_padding,
stop_text_encoder_training, stop_text_encoder_training,
# use_8bit_adam, # use_8bit_adam,
xformers, xformers,
save_model_as, save_model_as,
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight,
color_aug, color_aug,
flip_aug, flip_aug,
clip_skip, clip_skip,
vae, vae,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_dropout_rate, caption_dropout_rate,
optimizer, optimizer,
optimizer_args, optimizer_args,
noise_offset, noise_offset,
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
): ):
print("open_configuration called")
print(f"locals length: {len(locals())}")
print(f"locals: {locals()}")
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -229,9 +233,12 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file and file_path is not None:
print(f"File path: {file_path}") print(f"File path: {file_path}")
file_path = get_file_path(file_path, filedialog_type="json") file_path, canceled = get_file_path(file_path=file_path, filedialog_type="json")
if canceled:
return (None,) + (None,) * (len(parameters) - 2)
if not file_path == '' and file_path is not None: if not file_path == '' and file_path is not None:
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
@ -252,70 +259,72 @@ def open_configuration(
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['ask_for_file', 'file_path']: if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value)) values.append(my_data.get(key, value))
# Print the number of returned values
print(f"Returning: {values}")
return tuple(values) return tuple(values)
def train_model( def train_model(
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
v_parameterization, v_parameterization,
logging_dir, logging_dir,
train_data_dir, train_data_dir,
reg_data_dir, reg_data_dir,
output_dir, output_dir,
max_resolution, max_resolution,
learning_rate, learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
train_batch_size, train_batch_size,
epoch, epoch,
save_every_n_epochs, save_every_n_epochs,
mixed_precision, mixed_precision,
save_precision, save_precision,
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
cache_latents, cache_latents,
caption_extension, caption_extension,
enable_bucket, enable_bucket,
gradient_checkpointing, gradient_checkpointing,
full_fp16, full_fp16,
no_token_padding, no_token_padding,
stop_text_encoder_training_pct, stop_text_encoder_training_pct,
# use_8bit_adam, # use_8bit_adam,
xformers, xformers,
save_model_as, save_model_as,
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, prior_loss_weight,
color_aug, color_aug,
flip_aug, flip_aug,
clip_skip, clip_skip,
vae, vae,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_dropout_rate, caption_dropout_rate,
optimizer, optimizer,
optimizer_args, optimizer_args,
noise_offset, noise_offset,
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
show_message_box('Source model information is missing') show_message_box('Source model information is missing')
@ -346,7 +355,7 @@ def train_model(
f f
for f in os.listdir(train_data_dir) for f in os.listdir(train_data_dir)
if os.path.isdir(os.path.join(train_data_dir, f)) if os.path.isdir(os.path.join(train_data_dir, f))
and not f.startswith('.') and not f.startswith('.')
] ]
# Check if subfolders are present. If not let the user know and return # Check if subfolders are present. If not let the user know and return
@ -378,11 +387,11 @@ def train_model(
[ [
f f
for f, lower_f in ( for f, lower_f in (
(file, file.lower()) (file, file.lower())
for file in os.listdir( for file in os.listdir(
os.path.join(train_data_dir, folder) os.path.join(train_data_dir, folder)
) )
) )
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
] ]
) )
@ -846,12 +855,13 @@ def dreambooth_tab(
) )
button_load_config.click( button_load_config.click(
lambda *args, **kwargs: open_configuration(*args, **kwargs), lambda *args, **kwargs: (print("Lambda called"), open_configuration(*args, **kwargs)),
inputs=[dummy_db_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )
# Print the number of expected outputs
print(f"Number of expected outputs: {len([config_file_name] + settings_list)}")
button_save_config.click( button_save_config.click(
save_configuration, save_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list, inputs=[dummy_db_false, config_file_name] + settings_list,

View File

@ -151,26 +151,8 @@ def update_my_data(my_data):
# # If no extension files were found, return False # # If no extension files were found, return False
# return False # return False
# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
# file_extension = os.path.splitext(file_path)[-1].lower()
#
# filetype_filters = {
# 'db': ['.db'],
# 'json': ['.json'],
# 'lora': ['.pt', '.ckpt', '.safetensors'],
# }
#
# # Find the appropriate filedialog_type based on the file extension
# filedialog_type = 'all'
# for key, extensions in filetype_filters.items():
# if file_extension in extensions:
# filedialog_type = key
# break
#
# return get_file_path(file_path, filedialog_type)
def get_file_path(file_path, initial_dir=None, initial_file=None, filedialog_type="lora"):
def get_file_path(file_path='', filedialog_type="lora"):
file_extension = os.path.splitext(file_path)[-1].lower() file_extension = os.path.splitext(file_path)[-1].lower()
# Find the appropriate filedialog_type based on the file extension # Find the appropriate filedialog_type based on the file extension
@ -181,16 +163,10 @@ def get_file_path(file_path='', filedialog_type="lora"):
current_file_path = file_path current_file_path = file_path
print(f"File type: {filedialog_type}")
initial_dir, initial_file = os.path.split(file_path) initial_dir, initial_file = os.path.split(file_path)
file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type) result = open_file_dialog(initial_dir=initial_dir, initial_file=initial_file, file_types=filedialog_type)
file_path, canceled = result[:2]
# If no file is selected, use the current file path return file_path, canceled
if not file_path:
file_path = current_file_path
current_file_path = file_path
return file_path
def get_any_file_path(file_path=''): def get_any_file_path(file_path=''):

View File

@ -13,7 +13,6 @@ class TkGui:
self.file_types = None self.file_types = None
def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"): def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
print(f"File types: {self.file_types}")
with tk_context(): with tk_context():
self.file_types = file_types self.file_types = file_types
if self.file_types in CommonUtilities.file_filters: if self.file_types in CommonUtilities.file_filters:
@ -22,9 +21,14 @@ class TkGui:
filters = CommonUtilities.file_filters["all"] filters = CommonUtilities.file_filters["all"]
if self.file_types == "directory": if self.file_types == "directory":
return filedialog.askdirectory(initialdir=initial_dir) result = filedialog.askdirectory(initialdir=initial_dir)
else: else:
return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) result = filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
# Return a tuple (file_path, canceled)
# file_path: the selected file path or an empty string if no file is selected
# canceled: True if the user pressed the cancel button, False otherwise
return result, result == ""
def save_file_dialog(self, initial_dir, initial_file, file_types="all"): def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
self.file_types = file_types self.file_types = file_types