Add logic to v2 checkbox

This commit is contained in:
bmaltais 2023-03-05 21:38:20 -05:00
parent cc7aee2301
commit dbf959db68

View File

@ -10,7 +10,7 @@ save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄 document_symbol = '\U0001F4C4' # 📄
# define a list of substrings to search for # define a list of substrings to search for
preset_models = [ all_preset_models = [
'stabilityai/stable-diffusion-2-1-base', 'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base', 'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1', 'stabilityai/stable-diffusion-2-1',
@ -19,6 +19,24 @@ preset_models = [
'runwayml/stable-diffusion-v1-5', 'runwayml/stable-diffusion-v1-5',
] ]
# define a list of substrings to search for v2 base models
substrings_v2 = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
# define a list of substrings to search for v_parameterization models
substrings_v_parameterization = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
]
# define a list of substrings to v1.x models
substrings_v1_model = [
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
def update_my_data(my_data): def update_my_data(my_data):
if my_data.get('use_8bit_adam', False) == True: if my_data.get('use_8bit_adam', False) == True:
@ -36,7 +54,7 @@ def update_my_data(my_data):
my_data['model_list'] = 'custom' my_data['model_list'] = 'custom'
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
if not my_data.get('pretrained_model_name_or_path', '') in preset_models: if not my_data.get('pretrained_model_name_or_path', '') in all_preset_models:
my_data['model_list'] = 'custom' my_data['model_list'] = 'custom'
return my_data return my_data
@ -300,12 +318,6 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
def set_pretrained_model_name_or_path_input( def set_pretrained_model_name_or_path_input(
model_list, pretrained_model_name_or_path, v2, v_parameterization model_list, pretrained_model_name_or_path, v2, v_parameterization
): ):
# define a list of substrings to search for
substrings_v2 = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(model_list) in substrings_v2: if str(model_list) in substrings_v2:
print('SD v2 model detected. Setting --v2 parameter') print('SD v2 model detected. Setting --v2 parameter')
@ -313,12 +325,6 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = False v_parameterization = False
pretrained_model_name_or_path = str(model_list) pretrained_model_name_or_path = str(model_list)
# define a list of substrings to search for v-objective
substrings_v_parameterization = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in substrings_v_parameterization: if str(model_list) in substrings_v_parameterization:
print( print(
@ -328,12 +334,6 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = True v_parameterization = True
pretrained_model_name_or_path = str(model_list) pretrained_model_name_or_path = str(model_list)
# define a list of substrings to v1.x
substrings_v1_model = [
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
if str(model_list) in substrings_v1_model: if str(model_list) in substrings_v1_model:
v2 = False v2 = False
v_parameterization = False v_parameterization = False
@ -351,6 +351,25 @@ def set_pretrained_model_name_or_path_input(
v_parameterization = False v_parameterization = False
return model_list, pretrained_model_name_or_path, v2, v_parameterization return model_list, pretrained_model_name_or_path, v2, v_parameterization
def set_v2_checkbox(
model_list, v2, v_parameterization
):
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(model_list) in substrings_v2:
v2 = True
v_parameterization = False
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in substrings_v_parameterization:
v2 = True
v_parameterization = True
if str(model_list) in substrings_v1_model:
v2 = False
v_parameterization = False
return v2, v_parameterization
def set_model_list( def set_model_list(
model_list, model_list,
pretrained_model_name_or_path, pretrained_model_name_or_path,
@ -358,7 +377,7 @@ def set_model_list(
v_parameterization, v_parameterization,
): ):
if not pretrained_model_name_or_path in preset_models: if not pretrained_model_name_or_path in all_preset_models:
model_list = 'custom' model_list = 'custom'
else: else:
model_list = pretrained_model_name_or_path model_list = pretrained_model_name_or_path
@ -458,6 +477,8 @@ def gradio_source_model():
v_parameterization = gr.Checkbox( v_parameterization = gr.Checkbox(
label='v_parameterization', value=False label='v_parameterization', value=False
) )
v2.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
v_parameterization.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
model_list.change( model_list.change(
set_pretrained_model_name_or_path_input, set_pretrained_model_name_or_path_input,
inputs=[ inputs=[
@ -472,6 +493,7 @@ def gradio_source_model():
v2, v2,
v_parameterization, v_parameterization,
], ],
show_progress=False,
) )
# Update the model list and parameters when user click outside the button or field # Update the model list and parameters when user click outside the button or field
pretrained_model_name_or_path.change( pretrained_model_name_or_path.change(
@ -487,6 +509,7 @@ def gradio_source_model():
v2, v2,
v_parameterization, v_parameterization,
], ],
show_progress=False,
) )
return ( return (
pretrained_model_name_or_path, pretrained_model_name_or_path,