Update presets

This commit is contained in:
bmaltais 2023-03-05 21:42:28 -05:00
parent dbf959db68
commit 9e6b4cb69b

View File

@ -9,34 +9,27 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
# define a list of substrings to search for
all_preset_models = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
# define a list of substrings to search for v2 base models
substrings_v2 = [
V2_BASE_MODELS = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
# define a list of substrings to search for v_parameterization models
substrings_v_parameterization = [
V_PARAMETERIZATION_MODELS = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
]
# define a list of substrings to v1.x models
substrings_v1_model = [
V1_MODELS = [
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
# define a list of substrings to search for
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
def update_my_data(my_data):
if my_data.get('use_8bit_adam', False) == True:
@ -54,7 +47,7 @@ def update_my_data(my_data):
my_data['model_list'] = 'custom'
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
if not my_data.get('pretrained_model_name_or_path', '') in all_preset_models:
if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
my_data['model_list'] = 'custom'
return my_data
@ -319,14 +312,14 @@ def set_pretrained_model_name_or_path_input(
model_list, pretrained_model_name_or_path, v2, v_parameterization
):
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(model_list) in substrings_v2:
if str(model_list) in V2_BASE_MODELS:
print('SD v2 model detected. Setting --v2 parameter')
v2 = True
v_parameterization = False
pretrained_model_name_or_path = str(model_list)
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in substrings_v_parameterization:
if str(model_list) in V_PARAMETERIZATION_MODELS:
print(
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
)
@ -334,17 +327,17 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = True
pretrained_model_name_or_path = str(model_list)
if str(model_list) in substrings_v1_model:
if str(model_list) in V1_MODELS:
v2 = False
v_parameterization = False
pretrained_model_name_or_path = str(model_list)
if model_list == 'custom':
if (
str(pretrained_model_name_or_path) in substrings_v1_model
or str(pretrained_model_name_or_path) in substrings_v2
str(pretrained_model_name_or_path) in V1_MODELS
or str(pretrained_model_name_or_path) in V2_BASE_MODELS
or str(pretrained_model_name_or_path)
in substrings_v_parameterization
in V_PARAMETERIZATION_MODELS
):
pretrained_model_name_or_path = ''
v2 = False
@ -355,16 +348,16 @@ def set_v2_checkbox(
model_list, v2, v_parameterization
):
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(model_list) in substrings_v2:
if str(model_list) in V2_BASE_MODELS:
v2 = True
v_parameterization = False
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in substrings_v_parameterization:
if str(model_list) in V_PARAMETERIZATION_MODELS:
v2 = True
v_parameterization = True
if str(model_list) in substrings_v1_model:
if str(model_list) in V1_MODELS:
v2 = False
v_parameterization = False
@ -377,7 +370,7 @@ def set_model_list(
v_parameterization,
):
if not pretrained_model_name_or_path in all_preset_models:
if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
model_list = 'custom'
else:
model_list = pretrained_model_name_or_path