Removed one warning dealing with get_file_path()

Using lambdas now to pass in variable amount of arguments from components. This works right now with a few open windows, but saving and possibly loading will be broken right now. They need the lambda treatment next.

I also split the JSON validation placeholder to library/common_utilities.py.
This commit is contained in:
JSTayco 2023-03-30 13:13:25 -07:00
parent 160e371be3
commit e5b83df675
13 changed files with 105 additions and 77 deletions

View File

@ -28,8 +28,9 @@ from library.common_gui import (
gradio_source_model, gradio_source_model,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, is_valid_config, show_message_box, check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
) )
from library.common_utilities import is_valid_config
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
) )
@ -228,7 +229,8 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path(file_path, filedialog_type="json") print(f"File path: {file_path}")
file_path = get_file_path_gradio_wrapper(file_path)
if not file_path == '' and file_path is not None: if not file_path == '' and file_path is not None:
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
@ -836,15 +838,15 @@ def dreambooth_tab(
] ]
button_open_config.click( button_open_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(*args),
inputs=[dummy_db_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )
button_load_config.click( button_load_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(*args),
inputs=[dummy_db_false, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )

View File

@ -20,7 +20,7 @@ from library.common_gui import (
run_cmd_training, run_cmd_training,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, check_if_model_exist, get_file_path_gradio_wrapper,
) )
from library.tensorboard_gui import ( from library.tensorboard_gui import (
gradio_tensorboard, gradio_tensorboard,
@ -231,9 +231,9 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path(file_path) file_path = get_file_path_gradio_wrapper(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and file_path is not None:
# load variables from JSON file # load variables from JSON file
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
my_data = json.load(f) my_data = json.load(f)
@ -799,14 +799,14 @@ def finetune_tab():
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)
button_open_config.click( button_open_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(),
inputs=[dummy_db_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )
button_load_config.click( button_load_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(),
inputs=[dummy_db_false, config_file_name] + settings_list, inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,

View File

@ -89,23 +89,6 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
return False return False
def is_valid_config(data):
# Check if the data is a dictionary
if not isinstance(data, dict):
return False
# Add checks for expected keys and valid values
# For example, check if 'use_8bit_adam' is a boolean
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
return False
# Add more checks for other keys as needed
# If all checks pass, return True
return True
def update_my_data(my_data): def update_my_data(my_data):
# Update the optimizer based on the use_8bit_adam flag # Update the optimizer based on the use_8bit_adam flag
use_8bit_adam = my_data.get('use_8bit_adam', False) use_8bit_adam = my_data.get('use_8bit_adam', False)
@ -155,6 +138,24 @@ def update_my_data(my_data):
# # If no extension files were found, return False # # If no extension files were found, return False
# return False # return False
def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
file_extension = os.path.splitext(file_path)[-1].lower()
filetype_filters = {
'db': ['.db'],
'json': ['.json'],
'lora': ['.pt', '.ckpt', '.safetensors'],
}
# Find the appropriate filedialog_type based on the file extension
filedialog_type = 'all'
for key, extensions in filetype_filters.items():
if file_extension in extensions:
filedialog_type = key
break
return get_file_path(file_path, filedialog_type)
def get_file_path(file_path='', filedialog_type="lora"): def get_file_path(file_path='', filedialog_type="lora"):
current_file_path = file_path current_file_path = file_path

View File

@ -0,0 +1,14 @@
def is_valid_config(data):
# Check if the data is a dictionary
if not isinstance(data, dict):
return False
# Add checks for expected keys and valid values
# For example, check if 'use_8bit_adam' is a boolean
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
return False
# Add more checks for other keys as needed
# If all checks pass, return True
return True

View File

@ -4,22 +4,22 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_folder_path, get_file_path from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾 save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄 document_symbol = '\U0001F4C4' # 📄
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
def convert_model( def convert_model(
source_model_input, source_model_input,
source_model_type, source_model_type,
target_model_folder_input, target_model_folder_input,
target_model_name_input, target_model_name_input,
target_model_type, target_model_type,
target_save_precision_type, target_save_precision_type,
): ):
# Check for caption_text_input # Check for caption_text_input
if source_model_type == '': if source_model_type == '':
@ -61,8 +61,8 @@ def convert_model(
run_cmd += f' --{target_save_precision_type}' run_cmd += f' --{target_save_precision_type}'
if ( if (
target_model_type == 'diffuser' target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors' or target_model_type == 'diffuser_safetensors'
): ):
run_cmd += f' --reference_model="{source_model_type}"' run_cmd += f' --reference_model="{source_model_type}"'
@ -72,8 +72,8 @@ def convert_model(
run_cmd += f' "{source_model_input}"' run_cmd += f' "{source_model_input}"'
if ( if (
target_model_type == 'diffuser' target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors' or target_model_type == 'diffuser_safetensors'
): ):
target_model_path = os.path.join( target_model_path = os.path.join(
target_model_folder_input, target_model_name_input target_model_folder_input, target_model_name_input
@ -95,8 +95,8 @@ def convert_model(
subprocess.run(run_cmd) subprocess.run(run_cmd)
if ( if (
not target_model_type == 'diffuser' not target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors' or target_model_type == 'diffuser_safetensors'
): ):
v2_models = [ v2_models = [
@ -180,7 +180,8 @@ def gradio_convert_model_tab():
document_symbol, elem_id='open_folder_small' document_symbol, elem_id='open_folder_small'
) )
button_source_model_file.click( button_source_model_file.click(
get_file_path, lambda input1, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)),
inputs=[source_model_input], inputs=[source_model_input],
outputs=source_model_input, outputs=source_model_input,
show_progress=False, show_progress=False,

View File

@ -4,25 +4,25 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui import (
get_file_path, get_saveasfile_path, get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾 save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄 document_symbol = '\U0001F4C4' # 📄
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
def extract_lora( def extract_lora(
model_tuned, model_tuned,
model_org, model_org,
save_to, save_to,
save_precision, save_precision,
dim, dim,
v2, v2,
conv_dim, conv_dim,
device, device,
): ):
# Check for caption_text_input # Check for caption_text_input
if model_tuned == '': if model_tuned == '':
@ -43,7 +43,7 @@ def extract_lora(
return return
run_cmd = ( run_cmd = (
f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"' f'{PYTHON} "{os.path.join("networks", "extract_lora_from_models.py")}"'
) )
run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --save_to "{save_to}"' run_cmd += f' --save_to "{save_to}"'
@ -90,7 +90,8 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_model_tuned_file.click( button_model_tuned_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[model_tuned, model_ext, model_ext_name], inputs=[model_tuned, model_ext, model_ext_name],
outputs=model_tuned, outputs=model_tuned,
show_progress=False, show_progress=False,
@ -105,7 +106,8 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_model_org_file.click( button_model_org_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[model_org, model_ext, model_ext_name], inputs=[model_org, model_ext, model_ext_name],
outputs=model_org, outputs=model_org,
show_progress=False, show_progress=False,

View File

@ -4,7 +4,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui import (
get_file_path, get_saveasfile_path, get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -136,7 +136,8 @@ def gradio_extract_lycoris_locon_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_db_model_file.click( button_db_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[db_model, model_ext, model_ext_name], inputs=[db_model, model_ext, model_ext_name],
outputs=db_model, outputs=db_model,
show_progress=False, show_progress=False,
@ -151,7 +152,8 @@ def gradio_extract_lycoris_locon_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_base_model_file.click( button_base_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[base_model, model_ext, model_ext_name], inputs=[base_model, model_ext, model_ext_name],
outputs=base_model, outputs=base_model,
show_progress=False, show_progress=False,

View File

@ -4,7 +4,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui import (
get_file_path, get_saveasfile_path, get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -81,7 +81,8 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_a_model_file.click( button_lora_a_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_a_model, lora_ext, lora_ext_name], inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model, outputs=lora_a_model,
show_progress=False, show_progress=False,
@ -96,7 +97,8 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_b_model_file.click( button_lora_b_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_b_model, lora_ext, lora_ext_name], inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model, outputs=lora_b_model,
show_progress=False, show_progress=False,

View File

@ -3,7 +3,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import get_file_path, get_saveasfile_path from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -96,7 +96,8 @@ def gradio_resize_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_a_model_file.click( button_lora_a_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[model, lora_ext, lora_ext_name], inputs=[model, lora_ext, lora_ext_name],
outputs=model, outputs=model,
show_progress=False, show_progress=False,

View File

@ -4,7 +4,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui import (
get_file_path, get_saveasfile_path, get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
@ -87,7 +87,8 @@ def gradio_svd_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_a_model_file.click( button_lora_a_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_a_model, lora_ext, lora_ext_name], inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model, outputs=lora_a_model,
show_progress=False, show_progress=False,
@ -102,7 +103,8 @@ def gradio_svd_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_b_model_file.click( button_lora_b_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_b_model, lora_ext, lora_ext_name], inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model, outputs=lora_b_model,
show_progress=False, show_progress=False,

View File

@ -4,7 +4,7 @@ import subprocess
import gradio as gr import gradio as gr
from .common_gui import ( from .common_gui import (
get_file_path, get_file_path, get_file_path_gradio_wrapper,
) )
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
@ -68,7 +68,8 @@ def gradio_verify_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_lora_model_file.click( button_lora_model_file.click(
get_file_path, lambda input1, input2, input3, *args, **kwargs:
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
inputs=[lora_model, lora_ext, lora_ext_name], inputs=[lora_model, lora_ext, lora_ext_name],
outputs=lora_model, outputs=lora_model,
show_progress=False, show_progress=False,

View File

@ -28,7 +28,7 @@ from library.common_gui import (
run_cmd_training, run_cmd_training,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, show_message_box, check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
) )
from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
@ -254,7 +254,7 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path(file_path) file_path = get_file_path_gradio_wrapper(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
@ -1031,14 +1031,14 @@ def lora_tab(
] ]
button_open_config.click( button_open_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(),
inputs=[dummy_db_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row], outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False, show_progress=False,
) )
button_load_config.click( button_load_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(),
inputs=[dummy_db_false, config_file_name] + settings_list, inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list + [LoCon_row], outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False, show_progress=False,

View File

@ -28,7 +28,7 @@ from library.common_gui import (
gradio_source_model, gradio_source_model,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist, check_if_model_exist, get_file_path_gradio_wrapper,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -240,7 +240,7 @@ def open_configuration(
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path(file_path) file_path = get_file_path_gradio_wrapper(file_path)
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
@ -673,7 +673,7 @@ def ti_tab(
) )
weights_file_input = gr.Button('📂', elem_id='open_folder_small') weights_file_input = gr.Button('📂', elem_id='open_folder_small')
weights_file_input.click( weights_file_input.click(
get_file_path, lambda *args, **kwargs: get_file_path_gradio_wrapper,
outputs=weights, outputs=weights,
show_progress=False, show_progress=False,
) )
@ -899,14 +899,14 @@ def ti_tab(
] ]
button_open_config.click( button_open_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(),
inputs=[dummy_db_true, config_file_name] + settings_list, inputs=[dummy_db_true, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )
button_load_config.click( button_load_config.click(
open_configuration, lambda *args, **kwargs: open_configuration(),
inputs=[dummy_db_false, config_file_name] + settings_list, inputs=[dummy_db_false, config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,