Add support to load a config without opening the UI to get the file name

This commit is contained in:
bmaltais 2023-03-10 20:05:38 -05:00
parent d1962d7240
commit a65555ea67
5 changed files with 73 additions and 18 deletions

View File

@ -152,6 +152,7 @@ def save_configuration(
def open_configuration( def open_configuration(
ask_for_file,
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
@ -214,8 +215,12 @@ def open_configuration(
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
@ -231,7 +236,7 @@ def open_configuration(
values = [file_path] values = [file_path]
for key, value in parameters: for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']: if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value)) values.append(my_data.get(key, value))
return tuple(values) return tuple(values)
@ -506,6 +511,7 @@ def dreambooth_tab(
button_save_config, button_save_config,
button_save_as_config, button_save_as_config,
config_file_name, config_file_name,
button_load_config,
) = gradio_config() ) = gradio_config()
( (
@ -775,7 +781,14 @@ def dreambooth_tab(
button_open_config.click( button_open_config.click(
open_configuration, open_configuration,
inputs=[config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )

View File

@ -149,7 +149,8 @@ def save_configuration(
return file_path return file_path
def open_config_file( def open_configuration(
ask_for_file,
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
@ -218,8 +219,12 @@ def open_config_file(
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
@ -235,7 +240,7 @@ def open_config_file(
values = [file_path] values = [file_path]
for key, value in parameters: for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']: if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value)) values.append(my_data.get(key, value))
return tuple(values) return tuple(values)
@ -492,8 +497,8 @@ def remove_doublequote(file_path):
def finetune_tab(): def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False) dummy_db_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False) dummy_db_false = gr.Label(value=False, visible=False)
gr.Markdown('Train a custom model using kohya finetune python code...') gr.Markdown('Train a custom model using kohya finetune python code...')
( (
@ -501,6 +506,7 @@ def finetune_tab():
button_save_config, button_save_config,
button_save_as_config, button_save_as_config,
config_file_name, config_file_name,
button_load_config,
) = gradio_config() ) = gradio_config()
( (
@ -770,22 +776,29 @@ def finetune_tab():
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)
button_open_config.click( button_open_config.click(
open_config_file, open_configuration,
inputs=[config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )
button_save_config.click( button_save_config.click(
save_configuration, save_configuration,
inputs=[dummy_ft_false, config_file_name] + settings_list, inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name], outputs=[config_file_name],
show_progress=False, show_progress=False,
) )
button_save_as_config.click( button_save_as_config.click(
save_configuration, save_configuration,
inputs=[dummy_ft_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name], outputs=[config_file_name],
show_progress=False, show_progress=False,
) )

View File

@ -405,11 +405,14 @@ def gradio_config():
placeholder="type the configuration file path or use the 'Open' button above to select it...", placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=True, interactive=True,
) )
button_load_config = gr.Button('Load 💾', elem_id='open_folder')
config_file_name.change(remove_doublequote, inputs=[config_file_name], outputs=[config_file_name])
return ( return (
button_open_config, button_open_config,
button_save_config, button_save_config,
button_save_as_config, button_save_as_config,
config_file_name, config_file_name,
button_load_config,
) )

View File

@ -168,6 +168,7 @@ def save_configuration(
def open_configuration( def open_configuration(
ask_for_file,
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
@ -240,8 +241,12 @@ def open_configuration(
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
@ -257,7 +262,7 @@ def open_configuration(
values = [file_path] values = [file_path]
for key, value in parameters: for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']: if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value)) values.append(my_data.get(key, value))
# This next section is about making the LoCon parameters visible if LoRA_type = 'Standard' # This next section is about making the LoCon parameters visible if LoRA_type = 'Standard'
@ -610,6 +615,7 @@ def lora_tab(
button_save_config, button_save_config,
button_save_as_config, button_save_as_config,
config_file_name, config_file_name,
button_load_config,
) = gradio_config() ) = gradio_config()
( (
@ -974,7 +980,14 @@ def lora_tab(
button_open_config.click( button_open_config.click(
open_configuration, open_configuration,
inputs=[config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row], outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False, show_progress=False,
) )

View File

@ -158,6 +158,7 @@ def save_configuration(
def open_configuration( def open_configuration(
ask_for_file,
file_path, file_path,
pretrained_model_name_or_path, pretrained_model_name_or_path,
v2, v2,
@ -226,8 +227,12 @@ def open_configuration(
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path original_file_path = file_path
file_path = get_file_path(file_path)
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
@ -243,7 +248,7 @@ def open_configuration(
values = [file_path] values = [file_path]
for key, value in parameters: for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']: if not key in ['ask_for_file', 'file_path']:
values.append(my_data.get(key, value)) values.append(my_data.get(key, value))
return tuple(values) return tuple(values)
@ -548,6 +553,7 @@ def ti_tab(
button_save_config, button_save_config,
button_save_as_config, button_save_as_config,
config_file_name, config_file_name,
button_load_config,
) = gradio_config() ) = gradio_config()
( (
@ -865,7 +871,14 @@ def ti_tab(
button_open_config.click( button_open_config.click(
open_configuration, open_configuration,
inputs=[config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )