From b6d3f10da787f7ea4bb6717743fff8f23ef573c8 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Sat, 1 Apr 2023 17:33:41 -0700 Subject: [PATCH] WIP File Dialog Behavior --- dreambooth_gui.py | 402 ++++++++++++++++---------------- library/common_gui_functions.py | 32 +-- library/gui_subprocesses.py | 10 +- 3 files changed, 217 insertions(+), 227 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index e54704f..c2185c1 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -50,68 +50,68 @@ document_symbol = '\U0001F4C4' # 📄 def save_configuration( - save_as, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -125,12 +125,12 @@ def save_configuration( file_path = get_saveasfile_path(file_path) else: print('Save...') - if file_path == None or file_path == '': + if file_path is None or file_path == '': file_path = get_saveasfile_path(file_path) # print(file_path) - if file_path == None or file_path == '': + if file_path is None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open action # Return the values of the variables as a dictionary @@ -159,69 +159,73 @@ def save_configuration( def open_configuration( - ask_for_file, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): + print("open_configuration called") + print(f"locals length: {len(locals())}") + print(f"locals: {locals()}") + # Get list of function parameters and values parameters = list(locals().items()) @@ -229,9 +233,12 @@ def open_configuration( original_file_path = file_path - if ask_for_file: + if ask_for_file and file_path is not None: print(f"File path: {file_path}") - file_path = get_file_path(file_path, filedialog_type="json") + file_path, canceled = get_file_path(file_path=file_path, filedialog_type="json") + + if canceled: + return (None,) + (None,) * (len(parameters) - 2) if not file_path == '' and file_path is not None: with open(file_path, 'r') as f: @@ -252,70 +259,72 @@ def open_configuration( # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) + # Print the number of returned values + print(f"Returning: {values}") return tuple(values) def train_model( - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training_pct, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, # Keep this. Yes, it is unused here but required given the common list used - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training_pct, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. Yes, it is unused here but required given the common list used + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): if pretrained_model_name_or_path == '': show_message_box('Source model information is missing') @@ -346,7 +355,7 @@ def train_model( f for f in os.listdir(train_data_dir) if os.path.isdir(os.path.join(train_data_dir, f)) - and not f.startswith('.') + and not f.startswith('.') ] # Check if subfolders are present. If not let the user know and return @@ -378,11 +387,11 @@ def train_model( [ f for f, lower_f in ( - (file, file.lower()) - for file in os.listdir( - os.path.join(train_data_dir, folder) - ) - ) + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) + ) if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) ] ) @@ -846,12 +855,13 @@ def dreambooth_tab( ) button_load_config.click( - lambda *args, **kwargs: open_configuration(*args, **kwargs), + lambda *args, **kwargs: (print("Lambda called"), open_configuration(*args, **kwargs)), inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) - + # Print the number of expected outputs + print(f"Number of expected outputs: {len([config_file_name] + settings_list)}") button_save_config.click( save_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, diff --git a/library/common_gui_functions.py b/library/common_gui_functions.py index 77f4923..73eb18d 100644 --- a/library/common_gui_functions.py +++ b/library/common_gui_functions.py @@ -151,26 +151,8 @@ def update_my_data(my_data): # # If no extension files were found, return False # return False -# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"): -# file_extension = os.path.splitext(file_path)[-1].lower() -# -# filetype_filters = { -# 'db': ['.db'], -# 'json': ['.json'], -# 'lora': ['.pt', '.ckpt', '.safetensors'], -# } -# -# # Find the appropriate filedialog_type based on the file extension -# filedialog_type = 'all' -# for key, extensions in filetype_filters.items(): -# if file_extension in extensions: -# filedialog_type = key -# break -# -# return get_file_path(file_path, filedialog_type) - -def get_file_path(file_path='', filedialog_type="lora"): +def get_file_path(file_path, initial_dir=None, initial_file=None, filedialog_type="lora"): file_extension = os.path.splitext(file_path)[-1].lower() # Find the appropriate filedialog_type based on the file extension @@ -181,16 +163,10 @@ def get_file_path(file_path='', filedialog_type="lora"): current_file_path = file_path - print(f"File type: {filedialog_type}") initial_dir, initial_file = os.path.split(file_path) - file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type) - - # If no file is selected, use the current file path - if not file_path: - file_path = current_file_path - current_file_path = file_path - - return file_path + result = open_file_dialog(initial_dir=initial_dir, initial_file=initial_file, file_types=filedialog_type) + file_path, canceled = result[:2] + return file_path, canceled def get_any_file_path(file_path=''): diff --git a/library/gui_subprocesses.py b/library/gui_subprocesses.py index 2cbdaf2..2e45b8e 100644 --- a/library/gui_subprocesses.py +++ b/library/gui_subprocesses.py @@ -13,7 +13,6 @@ class TkGui: self.file_types = None def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"): - print(f"File types: {self.file_types}") with tk_context(): self.file_types = file_types if self.file_types in CommonUtilities.file_filters: @@ -22,9 +21,14 @@ class TkGui: filters = CommonUtilities.file_filters["all"] if self.file_types == "directory": - return filedialog.askdirectory(initialdir=initial_dir) + result = filedialog.askdirectory(initialdir=initial_dir) else: - return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) + result = filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) + + # Return a tuple (file_path, canceled) + # file_path: the selected file path or an empty string if no file is selected + # canceled: True if the user pressed the cancel button, False otherwise + return result, result == "" def save_file_dialog(self, initial_dir, initial_file, file_types="all"): self.file_types = file_types