Removed one warning dealing with get_file_path()
Using lambdas now to pass in variable amount of arguments from components. This works right now with a few open windows, but saving and possibly loading will be broken right now. They need the lambda treatment next. I also split the JSON validation placeholder to library/common_utilities.py.
This commit is contained in:
parent
160e371be3
commit
e5b83df675
@ -28,8 +28,9 @@ from library.common_gui import (
|
||||
gradio_source_model,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist, is_valid_config, show_message_box,
|
||||
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
|
||||
)
|
||||
from library.common_utilities import is_valid_config
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
)
|
||||
@ -228,7 +229,8 @@ def open_configuration(
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path(file_path, filedialog_type="json")
|
||||
print(f"File path: {file_path}")
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
|
||||
if not file_path == '' and file_path is not None:
|
||||
with open(file_path, 'r') as f:
|
||||
@ -836,15 +838,15 @@ def dreambooth_tab(
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
open_configuration,
|
||||
lambda *args, **kwargs: open_configuration(*args),
|
||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_load_config.click(
|
||||
open_configuration,
|
||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||
lambda *args, **kwargs: open_configuration(*args),
|
||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ from library.common_gui import (
|
||||
run_cmd_training,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist,
|
||||
check_if_model_exist, get_file_path_gradio_wrapper,
|
||||
)
|
||||
from library.tensorboard_gui import (
|
||||
gradio_tensorboard,
|
||||
@ -231,9 +231,9 @@ def open_configuration(
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path(file_path)
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
if not file_path == '' and file_path is not None:
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
my_data = json.load(f)
|
||||
@ -799,14 +799,14 @@ def finetune_tab():
|
||||
button_run.click(train_model, inputs=settings_list)
|
||||
|
||||
button_open_config.click(
|
||||
open_configuration,
|
||||
lambda *args, **kwargs: open_configuration(),
|
||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_load_config.click(
|
||||
open_configuration,
|
||||
lambda *args, **kwargs: open_configuration(),
|
||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
|
@ -89,23 +89,6 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_valid_config(data):
|
||||
# Check if the data is a dictionary
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
||||
# Add checks for expected keys and valid values
|
||||
# For example, check if 'use_8bit_adam' is a boolean
|
||||
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
|
||||
return False
|
||||
|
||||
# Add more checks for other keys as needed
|
||||
|
||||
# If all checks pass, return True
|
||||
return True
|
||||
|
||||
|
||||
def update_my_data(my_data):
|
||||
# Update the optimizer based on the use_8bit_adam flag
|
||||
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
||||
@ -155,6 +138,24 @@ def update_my_data(my_data):
|
||||
# # If no extension files were found, return False
|
||||
# return False
|
||||
|
||||
def get_file_path_gradio_wrapper(file_path, filedialog_type="all"):
|
||||
file_extension = os.path.splitext(file_path)[-1].lower()
|
||||
|
||||
filetype_filters = {
|
||||
'db': ['.db'],
|
||||
'json': ['.json'],
|
||||
'lora': ['.pt', '.ckpt', '.safetensors'],
|
||||
}
|
||||
|
||||
# Find the appropriate filedialog_type based on the file extension
|
||||
filedialog_type = 'all'
|
||||
for key, extensions in filetype_filters.items():
|
||||
if file_extension in extensions:
|
||||
filedialog_type = key
|
||||
break
|
||||
|
||||
return get_file_path(file_path, filedialog_type)
|
||||
|
||||
|
||||
def get_file_path(file_path='', filedialog_type="lora"):
|
||||
current_file_path = file_path
|
||||
|
14
library/common_utilities.py
Normal file
14
library/common_utilities.py
Normal file
@ -0,0 +1,14 @@
|
||||
def is_valid_config(data):
|
||||
# Check if the data is a dictionary
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
||||
# Add checks for expected keys and valid values
|
||||
# For example, check if 'use_8bit_adam' is a boolean
|
||||
if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool):
|
||||
return False
|
||||
|
||||
# Add more checks for other keys as needed
|
||||
|
||||
# If all checks pass, return True
|
||||
return True
|
@ -4,22 +4,22 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_folder_path, get_file_path
|
||||
from .common_gui import get_folder_path, get_file_path, get_file_path_gradio_wrapper
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
|
||||
|
||||
def convert_model(
|
||||
source_model_input,
|
||||
source_model_type,
|
||||
target_model_folder_input,
|
||||
target_model_name_input,
|
||||
target_model_type,
|
||||
target_save_precision_type,
|
||||
source_model_input,
|
||||
source_model_type,
|
||||
target_model_folder_input,
|
||||
target_model_name_input,
|
||||
target_model_type,
|
||||
target_save_precision_type,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if source_model_type == '':
|
||||
@ -61,8 +61,8 @@ def convert_model(
|
||||
run_cmd += f' --{target_save_precision_type}'
|
||||
|
||||
if (
|
||||
target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
):
|
||||
run_cmd += f' --reference_model="{source_model_type}"'
|
||||
|
||||
@ -72,8 +72,8 @@ def convert_model(
|
||||
run_cmd += f' "{source_model_input}"'
|
||||
|
||||
if (
|
||||
target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
):
|
||||
target_model_path = os.path.join(
|
||||
target_model_folder_input, target_model_name_input
|
||||
@ -95,8 +95,8 @@ def convert_model(
|
||||
subprocess.run(run_cmd)
|
||||
|
||||
if (
|
||||
not target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
not target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
):
|
||||
|
||||
v2_models = [
|
||||
@ -180,7 +180,8 @@ def gradio_convert_model_tab():
|
||||
document_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_source_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.normpath(input1)),
|
||||
inputs=[source_model_input],
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
|
@ -4,25 +4,25 @@ import subprocess
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
|
||||
|
||||
def extract_lora(
|
||||
model_tuned,
|
||||
model_org,
|
||||
save_to,
|
||||
save_precision,
|
||||
dim,
|
||||
v2,
|
||||
conv_dim,
|
||||
device,
|
||||
model_tuned,
|
||||
model_org,
|
||||
save_to,
|
||||
save_precision,
|
||||
dim,
|
||||
v2,
|
||||
conv_dim,
|
||||
device,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model_tuned == '':
|
||||
@ -43,7 +43,7 @@ def extract_lora(
|
||||
return
|
||||
|
||||
run_cmd = (
|
||||
f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"'
|
||||
f'{PYTHON} "{os.path.join("networks", "extract_lora_from_models.py")}"'
|
||||
)
|
||||
run_cmd += f' --save_precision {save_precision}'
|
||||
run_cmd += f' --save_to "{save_to}"'
|
||||
@ -90,7 +90,8 @@ def gradio_extract_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_model_tuned_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[model_tuned, model_ext, model_ext_name],
|
||||
outputs=model_tuned,
|
||||
show_progress=False,
|
||||
@ -105,7 +106,8 @@ def gradio_extract_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_model_org_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[model_org, model_ext, model_ext_name],
|
||||
outputs=model_org,
|
||||
show_progress=False,
|
||||
|
@ -4,7 +4,7 @@ import subprocess
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -136,7 +136,8 @@ def gradio_extract_lycoris_locon_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_db_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[db_model, model_ext, model_ext_name],
|
||||
outputs=db_model,
|
||||
show_progress=False,
|
||||
@ -151,7 +152,8 @@ def gradio_extract_lycoris_locon_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_base_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[base_model, model_ext, model_ext_name],
|
||||
outputs=base_model,
|
||||
show_progress=False,
|
||||
|
@ -4,7 +4,7 @@ import subprocess
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -81,7 +81,8 @@ def gradio_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_a_model,
|
||||
show_progress=False,
|
||||
@ -96,7 +97,8 @@ def gradio_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_b_model,
|
||||
show_progress=False,
|
||||
|
@ -3,7 +3,7 @@ import subprocess
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import get_file_path, get_saveasfile_path
|
||||
from .common_gui import get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper
|
||||
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -96,7 +96,8 @@ def gradio_resize_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[model, lora_ext, lora_ext_name],
|
||||
outputs=model,
|
||||
show_progress=False,
|
||||
|
@ -4,7 +4,7 @@ import subprocess
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path, get_saveasfile_path,
|
||||
get_file_path, get_saveasfile_path, get_file_path_gradio_wrapper,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
@ -87,7 +87,8 @@ def gradio_svd_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_a_model,
|
||||
show_progress=False,
|
||||
@ -102,7 +103,8 @@ def gradio_svd_merge_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_b_model,
|
||||
show_progress=False,
|
||||
|
@ -4,7 +4,7 @@ import subprocess
|
||||
import gradio as gr
|
||||
|
||||
from .common_gui import (
|
||||
get_file_path,
|
||||
get_file_path, get_file_path_gradio_wrapper,
|
||||
)
|
||||
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
@ -68,7 +68,8 @@ def gradio_verify_lora_tab():
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
button_lora_model_file.click(
|
||||
get_file_path,
|
||||
lambda input1, input2, input3, *args, **kwargs:
|
||||
get_file_path_gradio_wrapper(file_path=os.path.join(input1, input2 + input3)),
|
||||
inputs=[lora_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_model,
|
||||
show_progress=False,
|
||||
|
@ -28,7 +28,7 @@ from library.common_gui import (
|
||||
run_cmd_training,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist, show_message_box,
|
||||
check_if_model_exist, show_message_box, get_file_path_gradio_wrapper,
|
||||
)
|
||||
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
@ -254,7 +254,7 @@ def open_configuration(
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path(file_path)
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
# load variables from JSON file
|
||||
@ -1031,14 +1031,14 @@ def lora_tab(
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
open_configuration,
|
||||
lambda *args, **kwargs: open_configuration(),
|
||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list + [LoCon_row],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_load_config.click(
|
||||
open_configuration,
|
||||
lambda *args, **kwargs: open_configuration(),
|
||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list + [LoCon_row],
|
||||
show_progress=False,
|
||||
|
@ -28,7 +28,7 @@ from library.common_gui import (
|
||||
gradio_source_model,
|
||||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist,
|
||||
check_if_model_exist, get_file_path_gradio_wrapper,
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
@ -240,7 +240,7 @@ def open_configuration(
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path(file_path)
|
||||
file_path = get_file_path_gradio_wrapper(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
# load variables from JSON file
|
||||
@ -673,7 +673,7 @@ def ti_tab(
|
||||
)
|
||||
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
||||
weights_file_input.click(
|
||||
get_file_path,
|
||||
lambda *args, **kwargs: get_file_path_gradio_wrapper,
|
||||
outputs=weights,
|
||||
show_progress=False,
|
||||
)
|
||||
@ -899,14 +899,14 @@ def ti_tab(
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
open_configuration,
|
||||
lambda *args, **kwargs: open_configuration(),
|
||||
inputs=[dummy_db_true, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_load_config.click(
|
||||
open_configuration,
|
||||
lambda *args, **kwargs: open_configuration(),
|
||||
inputs=[dummy_db_false, config_file_name] + settings_list,
|
||||
outputs=[config_file_name] + settings_list,
|
||||
show_progress=False,
|
||||
|
Loading…
Reference in New Issue
Block a user