Improve custom preset handling
This commit is contained in:
parent
a57cdd5d42
commit
cc7aee2301
@ -176,6 +176,9 @@ This will store your a backup file with your current locally installed pip packa
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
|
* 2023/03/05 (v21.1.5):
|
||||||
|
- Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount!
|
||||||
|
- Improve how custom preset is set and handles. Still not perfect but better.
|
||||||
* 2023/03/05 (v21.1.4):
|
* 2023/03/05 (v21.1.4):
|
||||||
- Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files.
|
- Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files.
|
||||||
* 2023/03/04 (v21.1.3):
|
* 2023/03/04 (v21.1.3):
|
||||||
|
@ -9,18 +9,36 @@ refresh_symbol = '\U0001f504' # 🔄
|
|||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
|
# define a list of substrings to search for
|
||||||
|
preset_models = [
|
||||||
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
|
'stabilityai/stable-diffusion-2-base',
|
||||||
|
'stabilityai/stable-diffusion-2-1',
|
||||||
|
'stabilityai/stable-diffusion-2',
|
||||||
|
'CompVis/stable-diffusion-v1-4',
|
||||||
|
'runwayml/stable-diffusion-v1-5',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def update_my_data(my_data):
|
def update_my_data(my_data):
|
||||||
if my_data.get('use_8bit_adam', False) == True:
|
if my_data.get('use_8bit_adam', False) == True:
|
||||||
my_data['optimizer'] = 'AdamW8bit'
|
my_data['optimizer'] = 'AdamW8bit'
|
||||||
# my_data['use_8bit_adam'] = False
|
# my_data['use_8bit_adam'] = False
|
||||||
|
|
||||||
if my_data.get('optimizer', 'missing') == 'missing' and my_data.get('use_8bit_adam', False) == False:
|
if (
|
||||||
|
my_data.get('optimizer', 'missing') == 'missing'
|
||||||
|
and my_data.get('use_8bit_adam', False) == False
|
||||||
|
):
|
||||||
my_data['optimizer'] = 'AdamW'
|
my_data['optimizer'] = 'AdamW'
|
||||||
|
|
||||||
if my_data.get('model_list', 'custom') == []:
|
if my_data.get('model_list', 'custom') == []:
|
||||||
print('Old config with empty model list. Setting to custom...')
|
print('Old config with empty model list. Setting to custom...')
|
||||||
my_data['model_list'] = 'custom'
|
my_data['model_list'] = 'custom'
|
||||||
|
|
||||||
|
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
||||||
|
if not my_data.get('pretrained_model_name_or_path', '') in preset_models:
|
||||||
|
my_data['model_list'] = 'custom'
|
||||||
|
|
||||||
return my_data
|
return my_data
|
||||||
|
|
||||||
|
|
||||||
@ -293,8 +311,7 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
print('SD v2 model detected. Setting --v2 parameter')
|
print('SD v2 model detected. Setting --v2 parameter')
|
||||||
v2 = True
|
v2 = True
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
pretrained_model_name_or_path = str(model_list)
|
||||||
return model_list, v2, v_parameterization
|
|
||||||
|
|
||||||
# define a list of substrings to search for v-objective
|
# define a list of substrings to search for v-objective
|
||||||
substrings_v_parameterization = [
|
substrings_v_parameterization = [
|
||||||
@ -309,8 +326,7 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
)
|
)
|
||||||
v2 = True
|
v2 = True
|
||||||
v_parameterization = True
|
v_parameterization = True
|
||||||
|
pretrained_model_name_or_path = str(model_list)
|
||||||
return model_list, v2, v_parameterization
|
|
||||||
|
|
||||||
# define a list of substrings to v1.x
|
# define a list of substrings to v1.x
|
||||||
substrings_v1_model = [
|
substrings_v1_model = [
|
||||||
@ -321,8 +337,7 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
if str(model_list) in substrings_v1_model:
|
if str(model_list) in substrings_v1_model:
|
||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
pretrained_model_name_or_path = str(model_list)
|
||||||
return model_list, v2, v_parameterization
|
|
||||||
|
|
||||||
if model_list == 'custom':
|
if model_list == 'custom':
|
||||||
if (
|
if (
|
||||||
@ -334,11 +349,26 @@ def set_pretrained_model_name_or_path_input(
|
|||||||
pretrained_model_name_or_path = ''
|
pretrained_model_name_or_path = ''
|
||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
return pretrained_model_name_or_path, v2, v_parameterization
|
return model_list, pretrained_model_name_or_path, v2, v_parameterization
|
||||||
|
|
||||||
###
|
def set_model_list(
|
||||||
### Gradio common GUI section
|
model_list,
|
||||||
###
|
pretrained_model_name_or_path,
|
||||||
|
v2,
|
||||||
|
v_parameterization,
|
||||||
|
):
|
||||||
|
|
||||||
|
if not pretrained_model_name_or_path in preset_models:
|
||||||
|
model_list = 'custom'
|
||||||
|
else:
|
||||||
|
model_list = pretrained_model_name_or_path
|
||||||
|
|
||||||
|
return model_list, v2, v_parameterization
|
||||||
|
|
||||||
|
|
||||||
|
###
|
||||||
|
### Gradio common GUI section
|
||||||
|
###
|
||||||
|
|
||||||
|
|
||||||
def gradio_config():
|
def gradio_config():
|
||||||
@ -362,6 +392,15 @@ def gradio_config():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_model_name_or_path_file(
|
||||||
|
model_list, pretrained_model_name_or_path
|
||||||
|
):
|
||||||
|
pretrained_model_name_or_path = get_any_file_path(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
set_model_list(model_list, pretrained_model_name_or_path)
|
||||||
|
|
||||||
|
|
||||||
def gradio_source_model():
|
def gradio_source_model():
|
||||||
with gr.Tab('Source model'):
|
with gr.Tab('Source model'):
|
||||||
# Define the input elements
|
# Define the input elements
|
||||||
@ -428,11 +467,27 @@ def gradio_source_model():
|
|||||||
v_parameterization,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
model_list,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
# Update the model list and parameters when user click outside the button or field
|
||||||
|
pretrained_model_name_or_path.change(
|
||||||
|
set_model_list,
|
||||||
|
inputs=[
|
||||||
|
model_list,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
v2,
|
||||||
|
v_parameterization,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
model_list,
|
||||||
|
v2,
|
||||||
|
v_parameterization,
|
||||||
|
],
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
|
Loading…
Reference in New Issue
Block a user