Add support to load a config without opening the UI to get the file name

This commit is contained in:
bmaltais 2023-03-10 20:05:38 -05:00
parent d1962d7240
commit a65555ea67
5 changed files with 73 additions and 18 deletions

View File

@ -152,6 +152,7 @@ def save_configuration(
def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
@ -213,9 +214,13 @@ def open_configuration(
):
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None:
# load variables from JSON file
@ -231,7 +236,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
return tuple(values)
@ -506,6 +511,7 @@ def dreambooth_tab(
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()
(
@ -775,7 +781,14 @@ def dreambooth_tab(
button_open_config.click(
open_configuration,
inputs=[config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)

View File

@ -149,7 +149,8 @@ def save_configuration(
return file_path
def open_config_file(
def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
@ -217,9 +218,13 @@ def open_config_file(
):
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None:
# load variables from JSON file
@ -235,7 +240,7 @@ def open_config_file(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
return tuple(values)
@ -492,8 +497,8 @@ def remove_doublequote(file_path):
def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False)
dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False)
gr.Markdown('Train a custom model using kohya finetune python code...')
(
@ -501,6 +506,7 @@ def finetune_tab():
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()
(
@ -770,22 +776,29 @@ def finetune_tab():
button_run.click(train_model, inputs=settings_list)
button_open_config.click(
open_config_file,
inputs=[config_file_name] + settings_list,
open_configuration,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_save_config.click(
save_configuration,
inputs=[dummy_ft_false, config_file_name] + settings_list,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name],
show_progress=False,
)
button_save_as_config.click(
save_configuration,
inputs=[dummy_ft_true, config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name],
show_progress=False,
)

View File

@ -405,11 +405,14 @@ def gradio_config():
placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=True,
)
button_load_config = gr.Button('Load 💾', elem_id='open_folder')
config_file_name.change(remove_doublequote, inputs=[config_file_name], outputs=[config_file_name])
return (
button_open_config,
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
)

View File

@ -168,6 +168,7 @@ def save_configuration(
def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
@ -239,9 +240,13 @@ def open_configuration(
):
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None:
# load variables from JSON file
@ -257,7 +262,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
# This next section is about making the LoCon parameters visible if LoRA_type = 'Standard'
@ -610,6 +615,7 @@ def lora_tab(
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()
(
@ -974,7 +980,14 @@ def lora_tab(
button_open_config.click(
open_configuration,
inputs=[config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False,
)

View File

@ -158,6 +158,7 @@ def save_configuration(
def open_configuration(
ask_for_file,
file_path,
pretrained_model_name_or_path,
v2,
@ -225,9 +226,13 @@ def open_configuration(
):
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None:
# load variables from JSON file
@ -243,7 +248,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value))
return tuple(values)
@ -548,6 +553,7 @@ def ti_tab(
button_save_config,
button_save_as_config,
config_file_name,
button_load_config,
) = gradio_config()
(
@ -865,7 +871,14 @@ def ti_tab(
button_open_config.click(
open_configuration,
inputs=[config_file_name] + settings_list,
inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)