WIP File Dialog Behavior
parent eef5becab8
commit b6d3f10da7

@@ -50,68 +50,68 @@ document_symbol = '\U0001F4C4' # 📄
 def save_configuration(
     save_as,
     file_path,
     pretrained_model_name_or_path,
     v2,
     v_parameterization,
     logging_dir,
     train_data_dir,
     reg_data_dir,
     output_dir,
     max_resolution,
     learning_rate,
     lr_scheduler,
     lr_warmup,
     train_batch_size,
     epoch,
     save_every_n_epochs,
     mixed_precision,
     save_precision,
     seed,
     num_cpu_threads_per_process,
     cache_latents,
     caption_extension,
     enable_bucket,
     gradient_checkpointing,
     full_fp16,
     no_token_padding,
     stop_text_encoder_training,
     # use_8bit_adam,
     xformers,
     save_model_as,
     shuffle_caption,
     save_state,
     resume,
     prior_loss_weight,
     color_aug,
     flip_aug,
     clip_skip,
     vae,
     output_name,
     max_token_length,
     max_train_epochs,
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list,
     keep_tokens,
     persistent_data_loader_workers,
     bucket_no_upscale,
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
     optimizer_args,
     noise_offset,
     sample_every_n_steps,
     sample_every_n_epochs,
     sample_sampler,
     sample_prompts,
     additional_parameters,
     vae_batch_size,
     min_snr_gamma,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
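
Both save_configuration and open_configuration lean on the idiom visible at the end of the hunk above: calling locals() on the first line of the function body, before any other local is created, yields exactly the parameter names and values (in CPython, in definition order). A minimal sketch of the idiom, separate from the diff:

def f(a, b, c=3):
    # locals() at function entry holds only the parameters, in order
    parameters = list(locals().items())
    return parameters

print(f(1, 2))  # [('a', 1), ('b', 2), ('c', 3)]
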
@@ -125,12 +125,12 @@ def save_configuration(
         file_path = get_saveasfile_path(file_path)
     else:
         print('Save...')
-        if file_path == None or file_path == '':
+        if file_path is None or file_path == '':
             file_path = get_saveasfile_path(file_path)
 
     # print(file_path)
 
-    if file_path == None or file_path == '':
+    if file_path is None or file_path == '':
         return original_file_path  # In case a file_path was provided and the user decide to cancel the open action
 
     # Return the values of the variables as a dictionary
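
The == None to is None change above is more than style: == dispatches to the operand's __eq__, which a custom object may override, while is always tests identity against the None singleton. A small illustration, separate from the diff:

class AlwaysEqual:
    def __eq__(self, other):
        return True  # claims equality with everything, including None

x = AlwaysEqual()
print(x == None)  # True, misleading
print(x is None)  # False, the actual answer
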
@@ -159,69 +159,73 @@ def save_configuration(
 
 
 def open_configuration(
     ask_for_file,
     file_path,
     pretrained_model_name_or_path,
     v2,
     v_parameterization,
     logging_dir,
     train_data_dir,
     reg_data_dir,
     output_dir,
     max_resolution,
     learning_rate,
     lr_scheduler,
     lr_warmup,
     train_batch_size,
     epoch,
     save_every_n_epochs,
     mixed_precision,
     save_precision,
     seed,
     num_cpu_threads_per_process,
     cache_latents,
     caption_extension,
     enable_bucket,
     gradient_checkpointing,
     full_fp16,
     no_token_padding,
     stop_text_encoder_training,
     # use_8bit_adam,
     xformers,
     save_model_as,
     shuffle_caption,
     save_state,
     resume,
     prior_loss_weight,
     color_aug,
     flip_aug,
     clip_skip,
     vae,
     output_name,
     max_token_length,
     max_train_epochs,
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list,
     keep_tokens,
     persistent_data_loader_workers,
     bucket_no_upscale,
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
     optimizer_args,
     noise_offset,
     sample_every_n_steps,
     sample_every_n_epochs,
     sample_sampler,
     sample_prompts,
     additional_parameters,
     vae_batch_size,
     min_snr_gamma,
 ):
+    print("open_configuration called")
+    print(f"locals length: {len(locals())}")
+    print(f"locals: {locals()}")
+
     # Get list of function parameters and values
     parameters = list(locals().items())
 
@@ -229,9 +233,12 @@ def open_configuration(
 
     original_file_path = file_path
 
-    if ask_for_file:
+    if ask_for_file and file_path is not None:
         print(f"File path: {file_path}")
-        file_path = get_file_path(file_path, filedialog_type="json")
+        file_path, canceled = get_file_path(file_path=file_path, filedialog_type="json")
+
+        if canceled:
+            return (None,) + (None,) * (len(parameters) - 2)
 
     if not file_path == '' and file_path is not None:
         with open(file_path, 'r') as f:
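
The early return above pads the payload to one value per Gradio output component; the handler's outputs are [config_file_name] + settings_list. Since parameters was captured from locals() at entry, it counts ask_for_file and file_path along with every setting, so dropping the two control parameters and adding one slot back for the file-path component gives len(parameters) - 2 + 1 values. A sketch of the arithmetic (the sample names are illustrative, not from the commit):

def cancel_payload(parameters):
    # parameters mirrors list(locals().items()) captured at function entry
    return (None,) + (None,) * (len(parameters) - 2)

params = [('ask_for_file', True), ('file_path', ''),   # 2 control parameters
          ('v2', False), ('seed', 1234)]               # 2 settings
print(len(cancel_payload(params)))  # 3 == 1 file-path slot + 2 settings
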
@@ -252,70 +259,72 @@ def open_configuration(
         # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
         if not key in ['ask_for_file', 'file_path']:
             values.append(my_data.get(key, value))
+    # Print the number of returned values
+    print(f"Returning: {values}")
     return tuple(values)
 
 
 def train_model(
     pretrained_model_name_or_path,
     v2,
     v_parameterization,
     logging_dir,
     train_data_dir,
     reg_data_dir,
     output_dir,
     max_resolution,
     learning_rate,
     lr_scheduler,
     lr_warmup,
     train_batch_size,
     epoch,
     save_every_n_epochs,
     mixed_precision,
     save_precision,
     seed,
     num_cpu_threads_per_process,
     cache_latents,
     caption_extension,
     enable_bucket,
     gradient_checkpointing,
     full_fp16,
     no_token_padding,
     stop_text_encoder_training_pct,
     # use_8bit_adam,
     xformers,
     save_model_as,
     shuffle_caption,
     save_state,
     resume,
     prior_loss_weight,
     color_aug,
     flip_aug,
     clip_skip,
     vae,
     output_name,
     max_token_length,
     max_train_epochs,
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list, # Keep this. Yes, it is unused here but required given the common list used
     keep_tokens,
     persistent_data_loader_workers,
     bucket_no_upscale,
     random_crop,
     bucket_reso_steps,
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
     optimizer_args,
     noise_offset,
     sample_every_n_steps,
     sample_every_n_epochs,
     sample_sampler,
     sample_prompts,
     additional_parameters,
     vae_batch_size,
     min_snr_gamma,
 ):
     if pretrained_model_name_or_path == '':
         show_message_box('Source model information is missing')
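
The loop at the top of the hunk above rebuilds the output tuple by substituting any value found in the loaded JSON and keeping the GUI default otherwise, skipping the two control parameters so the result lines up positionally with the output components. A compact sketch of the pattern (sample data, not from the commit):

my_data = {'seed': 42}
parameters = [('ask_for_file', True), ('file_path', 'cfg.json'), ('seed', 1234)]

values = [my_data.get(key, value)
          for key, value in parameters
          if key not in ('ask_for_file', 'file_path')]
print(values)  # [42]
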
@@ -346,7 +355,7 @@ def train_model(
         f
         for f in os.listdir(train_data_dir)
         if os.path.isdir(os.path.join(train_data_dir, f))
         and not f.startswith('.')
     ]
 
     # Check if subfolders are present. If not let the user know and return
@@ -378,11 +387,11 @@ def train_model(
         [
             f
             for f, lower_f in (
                 (file, file.lower())
                 for file in os.listdir(
                     os.path.join(train_data_dir, folder)
                 )
             )
             if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
         ]
     )
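
The nested generator above pairs each filename with its lowercased form so the extension test is case-insensitive while the original casing is preserved. An equivalent, flattened sketch (the helper name is hypothetical, assuming the same extension set):

import os

def count_images(folder):
    # case-insensitive count of supported image files in one folder
    extensions = ('.jpg', '.jpeg', '.png', '.webp')
    return sum(1 for f in os.listdir(folder) if f.lower().endswith(extensions))
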
@@ -846,12 +855,13 @@ def dreambooth_tab(
     )
 
     button_load_config.click(
-        lambda *args, **kwargs: open_configuration(*args, **kwargs),
+        lambda *args, **kwargs: (print("Lambda called"), open_configuration(*args, **kwargs)),
         inputs=[dummy_db_true, config_file_name] + settings_list,
         outputs=[config_file_name] + settings_list,
         show_progress=False,
     )
-
+    # Print the number of expected outputs
+    print(f"Number of expected outputs: {len([config_file_name] + settings_list)}")
     button_save_config.click(
         save_configuration,
         inputs=[dummy_db_false, config_file_name] + settings_list,
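
One caution on the debug lambda above: (print("Lambda called"), open_configuration(*args, **kwargs)) evaluates to a 2-tuple whose first element is None (print's return value), so Gradio no longer receives one value per output component. A wrapper that logs the call but returns the original tuple unchanged would look like this sketch (hypothetical, not part of the commit):

def logged_open_configuration(*args, **kwargs):
    print("Lambda called")
    values = open_configuration(*args, **kwargs)
    print(f"Number of returned values: {len(values)}")
    return values
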
@@ -151,26 +151,8 @@ def update_my_data(my_data):
 # # If no extension files were found, return False
 # return False
 
-# def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
-#     file_extension = os.path.splitext(file_path)[-1].lower()
-#
-#     filetype_filters = {
-#         'db': ['.db'],
-#         'json': ['.json'],
-#         'lora': ['.pt', '.ckpt', '.safetensors'],
-#     }
-#
-#     # Find the appropriate filedialog_type based on the file extension
-#     filedialog_type = 'all'
-#     for key, extensions in filetype_filters.items():
-#         if file_extension in extensions:
-#             filedialog_type = key
-#             break
-#
-#     return get_file_path(file_path, filedialog_type)
 
-def get_file_path(file_path='', filedialog_type="lora"):
+def get_file_path(file_path, initial_dir=None, initial_file=None, filedialog_type="lora"):
     file_extension = os.path.splitext(file_path)[-1].lower()
 
     # Find the appropriate filedialog_type based on the file extension
@@ -181,16 +163,10 @@ def get_file_path(file_path='', filedialog_type="lora"):
 
     current_file_path = file_path
 
-    print(f"File type: {filedialog_type}")
     initial_dir, initial_file = os.path.split(file_path)
-    file_path = open_file_dialog(initial_dir, initial_file, file_types=filedialog_type)
-
-    # If no file is selected, use the current file path
-    if not file_path:
-        file_path = current_file_path
-    current_file_path = file_path
-
-    return file_path
+    result = open_file_dialog(initial_dir=initial_dir, initial_file=initial_file, file_types=filedialog_type)
+    file_path, canceled = result[:2]
+    return file_path, canceled
 
 
 def get_any_file_path(file_path=''):
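
With this rewrite get_file_path no longer silently restores the previous path on cancel; it reports the cancellation and leaves the fallback to each caller. (Note that in this WIP state the new initial_dir and initial_file parameters are accepted but the body still derives both from file_path via os.path.split.) A hypothetical caller under the new contract:

file_path, canceled = get_file_path(file_path=current_file_path, filedialog_type="json")
if canceled:
    file_path = current_file_path  # keep whatever was there before
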
@@ -13,7 +13,6 @@ class TkGui:
         self.file_types = None
 
     def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"):
-        print(f"File types: {self.file_types}")
         with tk_context():
             self.file_types = file_types
             if self.file_types in CommonUtilities.file_filters:
@@ -22,9 +21,14 @@ class TkGui:
                 filters = CommonUtilities.file_filters["all"]
 
             if self.file_types == "directory":
-                return filedialog.askdirectory(initialdir=initial_dir)
+                result = filedialog.askdirectory(initialdir=initial_dir)
             else:
-                return filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
+                result = filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters)
 
+            # Return a tuple (file_path, canceled)
+            # file_path: the selected file path or an empty string if no file is selected
+            # canceled: True if the user pressed the cancel button, False otherwise
+            return result, result == ""
 
     def save_file_dialog(self, initial_dir, initial_file, file_types="all"):
         self.file_types = file_types
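
The result == "" comparison detects the common cancel case: both askdirectory and askopenfilename return an empty string when the dialog is dismissed. On some platforms and Tk versions a canceled dialog can return an empty tuple instead, which this equality test would miss; a falsiness check covers both. A defensive sketch (an assumption about those platforms, not part of the commit):

def normalize_dialog_result(result):
    # '' and () are both falsy, so either counts as a cancel
    canceled = not result
    return ('' if canceled else result), canceled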