From cc7aee23019053cce371e8a98f598238ebca32d9 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 5 Mar 2023 21:10:24 -0500 Subject: [PATCH] Improve custom preset handling --- README.md | 3 ++ library/common_gui.py | 79 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 08a706f..4a0fc1e 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/05 (v21.1.5): + - Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount! + - Improve how custom preset is set and handles. Still not perfect but better. * 2023/03/05 (v21.1.4): - Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files. * 2023/03/04 (v21.1.3): diff --git a/library/common_gui.py b/library/common_gui.py index f792ffc..58131b4 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -9,18 +9,36 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +# define a list of substrings to search for +preset_models = [ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + 'CompVis/stable-diffusion-v1-4', + 'runwayml/stable-diffusion-v1-5', +] + def update_my_data(my_data): if my_data.get('use_8bit_adam', False) == True: my_data['optimizer'] = 'AdamW8bit' # my_data['use_8bit_adam'] = False - - if my_data.get('optimizer', 'missing') == 'missing' and my_data.get('use_8bit_adam', False) == False: + + if ( + my_data.get('optimizer', 'missing') == 'missing' + and my_data.get('use_8bit_adam', False) == False + ): my_data['optimizer'] = 'AdamW' if my_data.get('model_list', 'custom') == []: print('Old config with empty model list. Setting to custom...') my_data['model_list'] = 'custom' + + # If Pretrained model name or path is not one of the preset models then set the preset_model to custom + if not my_data.get('pretrained_model_name_or_path', '') in preset_models: + my_data['model_list'] = 'custom' + return my_data @@ -293,8 +311,7 @@ def set_pretrained_model_name_or_path_input( print('SD v2 model detected. Setting --v2 parameter') v2 = True v_parameterization = False - - return model_list, v2, v_parameterization + pretrained_model_name_or_path = str(model_list) # define a list of substrings to search for v-objective substrings_v_parameterization = [ @@ -309,8 +326,7 @@ def set_pretrained_model_name_or_path_input( ) v2 = True v_parameterization = True - - return model_list, v2, v_parameterization + pretrained_model_name_or_path = str(model_list) # define a list of substrings to v1.x substrings_v1_model = [ @@ -321,8 +337,7 @@ def set_pretrained_model_name_or_path_input( if str(model_list) in substrings_v1_model: v2 = False v_parameterization = False - - return model_list, v2, v_parameterization + pretrained_model_name_or_path = str(model_list) if model_list == 'custom': if ( @@ -334,11 +349,26 @@ def set_pretrained_model_name_or_path_input( pretrained_model_name_or_path = '' v2 = False v_parameterization = False - return pretrained_model_name_or_path, v2, v_parameterization + return model_list, pretrained_model_name_or_path, v2, v_parameterization - ### - ### Gradio common GUI section - ### +def set_model_list( + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, +): + + if not pretrained_model_name_or_path in preset_models: + model_list = 'custom' + else: + model_list = pretrained_model_name_or_path + + return model_list, v2, v_parameterization + + +### +### Gradio common GUI section +### def gradio_config(): @@ -362,6 +392,15 @@ def gradio_config(): ) +def get_pretrained_model_name_or_path_file( + model_list, pretrained_model_name_or_path +): + pretrained_model_name_or_path = get_any_file_path( + pretrained_model_name_or_path + ) + set_model_list(model_list, pretrained_model_name_or_path) + + def gradio_source_model(): with gr.Tab('Source model'): # Define the input elements @@ -428,11 +467,27 @@ def gradio_source_model(): v_parameterization, ], outputs=[ + model_list, pretrained_model_name_or_path, v2, v_parameterization, ], ) + # Update the model list and parameters when user click outside the button or field + pretrained_model_name_or_path.change( + set_model_list, + inputs=[ + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, + ], + outputs=[ + model_list, + v2, + v_parameterization, + ], + ) return ( pretrained_model_name_or_path, v2,