diff --git a/library/common_gui.py b/library/common_gui.py index d196728..5963acc 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -9,34 +9,27 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 -# define a list of substrings to search for -all_preset_models = [ - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - 'CompVis/stable-diffusion-v1-4', - 'runwayml/stable-diffusion-v1-5', -] - # define a list of substrings to search for v2 base models -substrings_v2 = [ +V2_BASE_MODELS = [ 'stabilityai/stable-diffusion-2-1-base', 'stabilityai/stable-diffusion-2-base', ] # define a list of substrings to search for v_parameterization models -substrings_v_parameterization = [ +V_PARAMETERIZATION_MODELS = [ 'stabilityai/stable-diffusion-2-1', 'stabilityai/stable-diffusion-2', ] # define a list of substrings to v1.x models -substrings_v1_model = [ +V1_MODELS = [ 'CompVis/stable-diffusion-v1-4', 'runwayml/stable-diffusion-v1-5', ] +# define a list of substrings to search for +ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + def update_my_data(my_data): if my_data.get('use_8bit_adam', False) == True: @@ -54,7 +47,7 @@ def update_my_data(my_data): my_data['model_list'] = 'custom' # If Pretrained model name or path is not one of the preset models then set the preset_model to custom - if not my_data.get('pretrained_model_name_or_path', '') in all_preset_models: + if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS: my_data['model_list'] = 'custom' return my_data @@ -319,14 +312,14 @@ def set_pretrained_model_name_or_path_input( model_list, pretrained_model_name_or_path, v2, v_parameterization ): # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list - if str(model_list) in substrings_v2: + if str(model_list) in V2_BASE_MODELS: print('SD v2 model detected. Setting --v2 parameter') v2 = True v_parameterization = False pretrained_model_name_or_path = str(model_list) # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list - if str(model_list) in substrings_v_parameterization: + if str(model_list) in V_PARAMETERIZATION_MODELS: print( 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' ) @@ -334,17 +327,17 @@ def set_pretrained_model_name_or_path_input( v_parameterization = True pretrained_model_name_or_path = str(model_list) - if str(model_list) in substrings_v1_model: + if str(model_list) in V1_MODELS: v2 = False v_parameterization = False pretrained_model_name_or_path = str(model_list) if model_list == 'custom': if ( - str(pretrained_model_name_or_path) in substrings_v1_model - or str(pretrained_model_name_or_path) in substrings_v2 + str(pretrained_model_name_or_path) in V1_MODELS + or str(pretrained_model_name_or_path) in V2_BASE_MODELS or str(pretrained_model_name_or_path) - in substrings_v_parameterization + in V_PARAMETERIZATION_MODELS ): pretrained_model_name_or_path = '' v2 = False @@ -355,16 +348,16 @@ def set_v2_checkbox( model_list, v2, v_parameterization ): # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list - if str(model_list) in substrings_v2: + if str(model_list) in V2_BASE_MODELS: v2 = True v_parameterization = False # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list - if str(model_list) in substrings_v_parameterization: + if str(model_list) in V_PARAMETERIZATION_MODELS: v2 = True v_parameterization = True - if str(model_list) in substrings_v1_model: + if str(model_list) in V1_MODELS: v2 = False v_parameterization = False @@ -377,7 +370,7 @@ def set_model_list( v_parameterization, ): - if not pretrained_model_name_or_path in all_preset_models: + if not pretrained_model_name_or_path in ALL_PRESET_MODELS: model_list = 'custom' else: model_list = pretrained_model_name_or_path