Add logic to v2 checkbox
This commit is contained in:
parent
cc7aee2301
commit
dbf959db68
@ -10,7 +10,7 @@ save_style_symbol = '\U0001f4be' # 💾
|
|||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
# define a list of substrings to search for
|
# define a list of substrings to search for
|
||||||
preset_models = [
|
all_preset_models = [
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
'stabilityai/stable-diffusion-2-base',
|
'stabilityai/stable-diffusion-2-base',
|
||||||
'stabilityai/stable-diffusion-2-1',
|
'stabilityai/stable-diffusion-2-1',
|
||||||
@ -19,6 +19,24 @@ preset_models = [
|
|||||||
'runwayml/stable-diffusion-v1-5',
|
'runwayml/stable-diffusion-v1-5',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# define a list of substrings to search for v2 base models
|
||||||
|
substrings_v2 = [
|
||||||
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
|
'stabilityai/stable-diffusion-2-base',
|
||||||
|
]
|
||||||
|
|
||||||
|
# define a list of substrings to search for v_parameterization models
|
||||||
|
substrings_v_parameterization = [
|
||||||
|
'stabilityai/stable-diffusion-2-1',
|
||||||
|
'stabilityai/stable-diffusion-2',
|
||||||
|
]
|
||||||
|
|
||||||
|
# define a list of substrings to v1.x models
|
||||||
|
substrings_v1_model = [
|
||||||
|
'CompVis/stable-diffusion-v1-4',
|
||||||
|
'runwayml/stable-diffusion-v1-5',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def update_my_data(my_data):
|
def update_my_data(my_data):
|
||||||
if my_data.get('use_8bit_adam', False) == True:
|
if my_data.get('use_8bit_adam', False) == True:
|
||||||
@ -36,7 +54,7 @@ def update_my_data(my_data):
|
|||||||
my_data['model_list'] = 'custom'
|
my_data['model_list'] = 'custom'
|
||||||
|
|
||||||
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
||||||
if not my_data.get('pretrained_model_name_or_path', '') in preset_models:
|
if not my_data.get('pretrained_model_name_or_path', '') in all_preset_models:
|
||||||
my_data['model_list'] = 'custom'
|
my_data['model_list'] = 'custom'
|
||||||
|
|
||||||
return my_data
|
return my_data
|
||||||
@ -300,12 +318,6 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
|||||||
def set_pretrained_model_name_or_path_input(
|
def set_pretrained_model_name_or_path_input(
|
||||||
model_list, pretrained_model_name_or_path, v2, v_parameterization
|
model_list, pretrained_model_name_or_path, v2, v_parameterization
|
||||||
):
|
):
|
||||||
# define a list of substrings to search for
|
|
||||||
substrings_v2 = [
|
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
|
||||||
'stabilityai/stable-diffusion-2-base',
|
|
||||||
]
|
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||||
if str(model_list) in substrings_v2:
|
if str(model_list) in substrings_v2:
|
||||||
print('SD v2 model detected. Setting --v2 parameter')
|
print('SD v2 model detected. Setting --v2 parameter')
|
||||||
@ -313,12 +325,6 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
pretrained_model_name_or_path = str(model_list)
|
pretrained_model_name_or_path = str(model_list)
|
||||||
|
|
||||||
# define a list of substrings to search for v-objective
|
|
||||||
substrings_v_parameterization = [
|
|
||||||
'stabilityai/stable-diffusion-2-1',
|
|
||||||
'stabilityai/stable-diffusion-2',
|
|
||||||
]
|
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
||||||
if str(model_list) in substrings_v_parameterization:
|
if str(model_list) in substrings_v_parameterization:
|
||||||
print(
|
print(
|
||||||
@ -328,12 +334,6 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
v_parameterization = True
|
v_parameterization = True
|
||||||
pretrained_model_name_or_path = str(model_list)
|
pretrained_model_name_or_path = str(model_list)
|
||||||
|
|
||||||
# define a list of substrings to v1.x
|
|
||||||
substrings_v1_model = [
|
|
||||||
'CompVis/stable-diffusion-v1-4',
|
|
||||||
'runwayml/stable-diffusion-v1-5',
|
|
||||||
]
|
|
||||||
|
|
||||||
if str(model_list) in substrings_v1_model:
|
if str(model_list) in substrings_v1_model:
|
||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
@ -351,6 +351,25 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
return model_list, pretrained_model_name_or_path, v2, v_parameterization
|
return model_list, pretrained_model_name_or_path, v2, v_parameterization
|
||||||
|
|
||||||
|
def set_v2_checkbox(
|
||||||
|
model_list, v2, v_parameterization
|
||||||
|
):
|
||||||
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||||
|
if str(model_list) in substrings_v2:
|
||||||
|
v2 = True
|
||||||
|
v_parameterization = False
|
||||||
|
|
||||||
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
||||||
|
if str(model_list) in substrings_v_parameterization:
|
||||||
|
v2 = True
|
||||||
|
v_parameterization = True
|
||||||
|
|
||||||
|
if str(model_list) in substrings_v1_model:
|
||||||
|
v2 = False
|
||||||
|
v_parameterization = False
|
||||||
|
|
||||||
|
return v2, v_parameterization
|
||||||
|
|
||||||
def set_model_list(
|
def set_model_list(
|
||||||
model_list,
|
model_list,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
@ -358,7 +377,7 @@ def set_model_list(
|
|||||||
v_parameterization,
|
v_parameterization,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not pretrained_model_name_or_path in preset_models:
|
if not pretrained_model_name_or_path in all_preset_models:
|
||||||
model_list = 'custom'
|
model_list = 'custom'
|
||||||
else:
|
else:
|
||||||
model_list = pretrained_model_name_or_path
|
model_list = pretrained_model_name_or_path
|
||||||
@ -458,6 +477,8 @@ def gradio_source_model():
|
|||||||
v_parameterization = gr.Checkbox(
|
v_parameterization = gr.Checkbox(
|
||||||
label='v_parameterization', value=False
|
label='v_parameterization', value=False
|
||||||
)
|
)
|
||||||
|
v2.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
|
||||||
|
v_parameterization.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
|
||||||
model_list.change(
|
model_list.change(
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input,
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -472,6 +493,7 @@ def gradio_source_model():
|
|||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
# Update the model list and parameters when user click outside the button or field
|
# Update the model list and parameters when user click outside the button or field
|
||||||
pretrained_model_name_or_path.change(
|
pretrained_model_name_or_path.change(
|
||||||
@ -487,6 +509,7 @@ def gradio_source_model():
|
|||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
|
Loading…
Reference in New Issue
Block a user