From dbf959db6897a6614341cffe86c718f124e5789d Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 5 Mar 2023 21:38:20 -0500 Subject: [PATCH] Add logic to v2 checkbox --- library/common_gui.py | 65 +++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/library/common_gui.py b/library/common_gui.py index 58131b4..d196728 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -10,7 +10,7 @@ save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 # define a list of substrings to search for -preset_models = [ +all_preset_models = [ 'stabilityai/stable-diffusion-2-1-base', 'stabilityai/stable-diffusion-2-base', 'stabilityai/stable-diffusion-2-1', @@ -19,6 +19,24 @@ preset_models = [ 'runwayml/stable-diffusion-v1-5', ] +# define a list of substrings to search for v2 base models +substrings_v2 = [ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', +] + +# define a list of substrings to search for v_parameterization models +substrings_v_parameterization = [ + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', +] + +# define a list of substrings to v1.x models +substrings_v1_model = [ + 'CompVis/stable-diffusion-v1-4', + 'runwayml/stable-diffusion-v1-5', +] + def update_my_data(my_data): if my_data.get('use_8bit_adam', False) == True: @@ -36,7 +54,7 @@ def update_my_data(my_data): my_data['model_list'] = 'custom' # If Pretrained model name or path is not one of the preset models then set the preset_model to custom - if not my_data.get('pretrained_model_name_or_path', '') in preset_models: + if not my_data.get('pretrained_model_name_or_path', '') in all_preset_models: my_data['model_list'] = 'custom' return my_data @@ -300,12 +318,6 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): def set_pretrained_model_name_or_path_input( model_list, pretrained_model_name_or_path, v2, v_parameterization ): - # define a list of substrings to search for - substrings_v2 = [ - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - ] - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list if str(model_list) in substrings_v2: print('SD v2 model detected. Setting --v2 parameter') @@ -313,12 +325,6 @@ def set_pretrained_model_name_or_path_input( v_parameterization = False pretrained_model_name_or_path = str(model_list) - # define a list of substrings to search for v-objective - substrings_v_parameterization = [ - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - ] - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list if str(model_list) in substrings_v_parameterization: print( @@ -328,12 +334,6 @@ def set_pretrained_model_name_or_path_input( v_parameterization = True pretrained_model_name_or_path = str(model_list) - # define a list of substrings to v1.x - substrings_v1_model = [ - 'CompVis/stable-diffusion-v1-4', - 'runwayml/stable-diffusion-v1-5', - ] - if str(model_list) in substrings_v1_model: v2 = False v_parameterization = False @@ -351,6 +351,25 @@ def set_pretrained_model_name_or_path_input( v_parameterization = False return model_list, pretrained_model_name_or_path, v2, v_parameterization +def set_v2_checkbox( + model_list, v2, v_parameterization +): + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list + if str(model_list) in substrings_v2: + v2 = True + v_parameterization = False + + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list + if str(model_list) in substrings_v_parameterization: + v2 = True + v_parameterization = True + + if str(model_list) in substrings_v1_model: + v2 = False + v_parameterization = False + + return v2, v_parameterization + def set_model_list( model_list, pretrained_model_name_or_path, @@ -358,7 +377,7 @@ def set_model_list( v_parameterization, ): - if not pretrained_model_name_or_path in preset_models: + if not pretrained_model_name_or_path in all_preset_models: model_list = 'custom' else: model_list = pretrained_model_name_or_path @@ -458,6 +477,8 @@ def gradio_source_model(): v_parameterization = gr.Checkbox( label='v_parameterization', value=False ) + v2.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False) + v_parameterization.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False) model_list.change( set_pretrained_model_name_or_path_input, inputs=[ @@ -472,6 +493,7 @@ def gradio_source_model(): v2, v_parameterization, ], + show_progress=False, ) # Update the model list and parameters when user click outside the button or field pretrained_model_name_or_path.change( @@ -487,6 +509,7 @@ def gradio_source_model(): v2, v_parameterization, ], + show_progress=False, ) return ( pretrained_model_name_or_path,