Improve custom preset handling
This commit is contained in:
parent
a57cdd5d42
commit
cc7aee2301
@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa
|
||||
|
||||
## Change History
|
||||
|
||||
* 2023/03/05 (v21.1.5):
|
||||
- Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount!
|
||||
- Improve how custom preset is set and handles. Still not perfect but better.
|
||||
* 2023/03/05 (v21.1.4):
|
||||
- Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files.
|
||||
* 2023/03/04 (v21.1.3):
|
||||
|
@ -9,18 +9,36 @@ refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
|
||||
# define a list of substrings to search for
|
||||
preset_models = [
|
||||
'stabilityai/stable-diffusion-2-1-base',
|
||||
'stabilityai/stable-diffusion-2-base',
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
'stabilityai/stable-diffusion-2',
|
||||
'CompVis/stable-diffusion-v1-4',
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
]
|
||||
|
||||
|
||||
def update_my_data(my_data):
|
||||
if my_data.get('use_8bit_adam', False) == True:
|
||||
my_data['optimizer'] = 'AdamW8bit'
|
||||
# my_data['use_8bit_adam'] = False
|
||||
|
||||
if my_data.get('optimizer', 'missing') == 'missing' and my_data.get('use_8bit_adam', False) == False:
|
||||
|
||||
if (
|
||||
my_data.get('optimizer', 'missing') == 'missing'
|
||||
and my_data.get('use_8bit_adam', False) == False
|
||||
):
|
||||
my_data['optimizer'] = 'AdamW'
|
||||
|
||||
if my_data.get('model_list', 'custom') == []:
|
||||
print('Old config with empty model list. Setting to custom...')
|
||||
my_data['model_list'] = 'custom'
|
||||
|
||||
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
||||
if not my_data.get('pretrained_model_name_or_path', '') in preset_models:
|
||||
my_data['model_list'] = 'custom'
|
||||
|
||||
return my_data
|
||||
|
||||
|
||||
@ -293,8 +311,7 @@ def set_pretrained_model_name_or_path_input(
|
||||
print('SD v2 model detected. Setting --v2 parameter')
|
||||
v2 = True
|
||||
v_parameterization = False
|
||||
|
||||
return model_list, v2, v_parameterization
|
||||
pretrained_model_name_or_path = str(model_list)
|
||||
|
||||
# define a list of substrings to search for v-objective
|
||||
substrings_v_parameterization = [
|
||||
@ -309,8 +326,7 @@ def set_pretrained_model_name_or_path_input(
|
||||
)
|
||||
v2 = True
|
||||
v_parameterization = True
|
||||
|
||||
return model_list, v2, v_parameterization
|
||||
pretrained_model_name_or_path = str(model_list)
|
||||
|
||||
# define a list of substrings to v1.x
|
||||
substrings_v1_model = [
|
||||
@ -321,8 +337,7 @@ def set_pretrained_model_name_or_path_input(
|
||||
if str(model_list) in substrings_v1_model:
|
||||
v2 = False
|
||||
v_parameterization = False
|
||||
|
||||
return model_list, v2, v_parameterization
|
||||
pretrained_model_name_or_path = str(model_list)
|
||||
|
||||
if model_list == 'custom':
|
||||
if (
|
||||
@ -334,11 +349,26 @@ def set_pretrained_model_name_or_path_input(
|
||||
pretrained_model_name_or_path = ''
|
||||
v2 = False
|
||||
v_parameterization = False
|
||||
return pretrained_model_name_or_path, v2, v_parameterization
|
||||
return model_list, pretrained_model_name_or_path, v2, v_parameterization
|
||||
|
||||
###
|
||||
### Gradio common GUI section
|
||||
###
|
||||
def set_model_list(
|
||||
model_list,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
):
|
||||
|
||||
if not pretrained_model_name_or_path in preset_models:
|
||||
model_list = 'custom'
|
||||
else:
|
||||
model_list = pretrained_model_name_or_path
|
||||
|
||||
return model_list, v2, v_parameterization
|
||||
|
||||
|
||||
###
|
||||
### Gradio common GUI section
|
||||
###
|
||||
|
||||
|
||||
def gradio_config():
|
||||
@ -362,6 +392,15 @@ def gradio_config():
|
||||
)
|
||||
|
||||
|
||||
def get_pretrained_model_name_or_path_file(
|
||||
model_list, pretrained_model_name_or_path
|
||||
):
|
||||
pretrained_model_name_or_path = get_any_file_path(
|
||||
pretrained_model_name_or_path
|
||||
)
|
||||
set_model_list(model_list, pretrained_model_name_or_path)
|
||||
|
||||
|
||||
def gradio_source_model():
|
||||
with gr.Tab('Source model'):
|
||||
# Define the input elements
|
||||
@ -428,11 +467,27 @@ def gradio_source_model():
|
||||
v_parameterization,
|
||||
],
|
||||
outputs=[
|
||||
model_list,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
],
|
||||
)
|
||||
# Update the model list and parameters when user click outside the button or field
|
||||
pretrained_model_name_or_path.change(
|
||||
set_model_list,
|
||||
inputs=[
|
||||
model_list,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
],
|
||||
outputs=[
|
||||
model_list,
|
||||
v2,
|
||||
v_parameterization,
|
||||
],
|
||||
)
|
||||
return (
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
|
Loading…
Reference in New Issue
Block a user