Add logic to v2 checkbox

This commit is contained in:
bmaltais 2023-03-05 21:38:20 -05:00
parent cc7aee2301
commit dbf959db68

View File

@ -10,7 +10,7 @@ save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
# define a list of substrings to search for
preset_models = [
all_preset_models = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
@ -19,6 +19,24 @@ preset_models = [
'runwayml/stable-diffusion-v1-5',
]
# define a list of substrings to search for v2 base models
substrings_v2 = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
# define a list of substrings to search for v_parameterization models
substrings_v_parameterization = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
]
# define a list of substrings to v1.x models
substrings_v1_model = [
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
def update_my_data(my_data):
if my_data.get('use_8bit_adam', False) == True:
@ -36,7 +54,7 @@ def update_my_data(my_data):
my_data['model_list'] = 'custom'
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
if not my_data.get('pretrained_model_name_or_path', '') in preset_models:
if not my_data.get('pretrained_model_name_or_path', '') in all_preset_models:
my_data['model_list'] = 'custom'
return my_data
@ -300,12 +318,6 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
def set_pretrained_model_name_or_path_input(
model_list, pretrained_model_name_or_path, v2, v_parameterization
):
# define a list of substrings to search for
substrings_v2 = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(model_list) in substrings_v2:
print('SD v2 model detected. Setting --v2 parameter')
@ -313,12 +325,6 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = False
pretrained_model_name_or_path = str(model_list)
# define a list of substrings to search for v-objective
substrings_v_parameterization = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in substrings_v_parameterization:
print(
@ -328,12 +334,6 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = True
pretrained_model_name_or_path = str(model_list)
# define a list of substrings to v1.x
substrings_v1_model = [
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
if str(model_list) in substrings_v1_model:
v2 = False
v_parameterization = False
@ -351,6 +351,25 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = False
return model_list, pretrained_model_name_or_path, v2, v_parameterization
def set_v2_checkbox(
model_list, v2, v_parameterization
):
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(model_list) in substrings_v2:
v2 = True
v_parameterization = False
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in substrings_v_parameterization:
v2 = True
v_parameterization = True
if str(model_list) in substrings_v1_model:
v2 = False
v_parameterization = False
return v2, v_parameterization
def set_model_list(
model_list,
pretrained_model_name_or_path,
@ -358,7 +377,7 @@ def set_model_list(
v_parameterization,
):
if not pretrained_model_name_or_path in preset_models:
if not pretrained_model_name_or_path in all_preset_models:
model_list = 'custom'
else:
model_list = pretrained_model_name_or_path
@ -458,6 +477,8 @@ def gradio_source_model():
v_parameterization = gr.Checkbox(
label='v_parameterization', value=False
)
v2.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
v_parameterization.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
model_list.change(
set_pretrained_model_name_or_path_input,
inputs=[
@ -472,6 +493,7 @@ def gradio_source_model():
v2,
v_parameterization,
],
show_progress=False,
)
# Update the model list and parameters when user click outside the button or field
pretrained_model_name_or_path.change(
@ -487,6 +509,7 @@ def gradio_source_model():
v2,
v_parameterization,
],
show_progress=False,
)
return (
pretrained_model_name_or_path,