Improve custom preset handling

This commit is contained in:
bmaltais 2023-03-05 21:10:24 -05:00
parent a57cdd5d42
commit cc7aee2301
2 changed files with 70 additions and 12 deletions

View File

@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa
## Change History
* 2023/03/05 (v21.1.5):
- Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount!
- Improve how custom preset is set and handles. Still not perfect but better.
* 2023/03/05 (v21.1.4):
- Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files.
* 2023/03/04 (v21.1.3):

View File

@ -9,18 +9,36 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
# define a list of substrings to search for
preset_models = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
def update_my_data(my_data):
if my_data.get('use_8bit_adam', False) == True:
my_data['optimizer'] = 'AdamW8bit'
# my_data['use_8bit_adam'] = False
if my_data.get('optimizer', 'missing') == 'missing' and my_data.get('use_8bit_adam', False) == False:
if (
my_data.get('optimizer', 'missing') == 'missing'
and my_data.get('use_8bit_adam', False) == False
):
my_data['optimizer'] = 'AdamW'
if my_data.get('model_list', 'custom') == []:
print('Old config with empty model list. Setting to custom...')
my_data['model_list'] = 'custom'
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
if not my_data.get('pretrained_model_name_or_path', '') in preset_models:
my_data['model_list'] = 'custom'
return my_data
@ -293,8 +311,7 @@ def set_pretrained_model_name_or_path_input(
print('SD v2 model detected. Setting --v2 parameter')
v2 = True
v_parameterization = False
return model_list, v2, v_parameterization
pretrained_model_name_or_path = str(model_list)
# define a list of substrings to search for v-objective
substrings_v_parameterization = [
@ -309,8 +326,7 @@ def set_pretrained_model_name_or_path_input(
)
v2 = True
v_parameterization = True
return model_list, v2, v_parameterization
pretrained_model_name_or_path = str(model_list)
# define a list of substrings to v1.x
substrings_v1_model = [
@ -321,8 +337,7 @@ def set_pretrained_model_name_or_path_input(
if str(model_list) in substrings_v1_model:
v2 = False
v_parameterization = False
return model_list, v2, v_parameterization
pretrained_model_name_or_path = str(model_list)
if model_list == 'custom':
if (
@ -334,11 +349,26 @@ def set_pretrained_model_name_or_path_input(
pretrained_model_name_or_path = ''
v2 = False
v_parameterization = False
return pretrained_model_name_or_path, v2, v_parameterization
return model_list, pretrained_model_name_or_path, v2, v_parameterization
###
### Gradio common GUI section
###
def set_model_list(
model_list,
pretrained_model_name_or_path,
v2,
v_parameterization,
):
if not pretrained_model_name_or_path in preset_models:
model_list = 'custom'
else:
model_list = pretrained_model_name_or_path
return model_list, v2, v_parameterization
###
### Gradio common GUI section
###
def gradio_config():
@ -362,6 +392,15 @@ def gradio_config():
)
def get_pretrained_model_name_or_path_file(
model_list, pretrained_model_name_or_path
):
pretrained_model_name_or_path = get_any_file_path(
pretrained_model_name_or_path
)
set_model_list(model_list, pretrained_model_name_or_path)
def gradio_source_model():
with gr.Tab('Source model'):
# Define the input elements
@ -428,11 +467,27 @@ def gradio_source_model():
v_parameterization,
],
outputs=[
model_list,
pretrained_model_name_or_path,
v2,
v_parameterization,
],
)
# Update the model list and parameters when user click outside the button or field
pretrained_model_name_or_path.change(
set_model_list,
inputs=[
model_list,
pretrained_model_name_or_path,
v2,
v_parameterization,
],
outputs=[
model_list,
v2,
v_parameterization,
],
)
return (
pretrained_model_name_or_path,
v2,