Remove bad save_model_as choices for LoRA and TI

This commit is contained in:
bmaltais 2023-03-25 09:08:02 -04:00
parent a9aafff259
commit d0ffba6650
4 changed files with 27 additions and 45 deletions

View File

@ -218,6 +218,7 @@ This will store your a backup file with your current locally installed pip packa
Let me know how this work. From the look of it it appear to be well tought out. I modified a few things to make it fit better with the rest of the code in the repo.
- Fix for issue https://github.com/bmaltais/kohya_ss/issues/433 by implementing default of 0.
- Removed non applicable save_model_as choices for LoRA and TI.
* 2023/03/24 (v21.3.3)
- Add support for custom user gui files. THey will be created at installation time or when upgrading is missing. You will see two files in the root of the folder. One named `gui-user.bat` and the other `gui-user.ps1`. Edit the file based on your prefered terminal. Simply add the parameters you want to pass the gui in there and execute it to start the gui with them. Enjoy!
* 2023/03/23 (v21.3.2)

View File

@ -89,44 +89,18 @@ def update_my_data(my_data):
# Update LoRA_type if it is set to LoCon
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
my_data['LoRA_type'] = 'LyCORIS/LoCon'
# Update model save choices due to changes for LoRA and TI training
if (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
if my_data.get('LoRA_type'):
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to LoRA')
if my_data.get('num_vectors_per_token'):
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to TI')
my_data['save_model_as'] = 'safetensors'
return my_data
# def update_my_data(my_data):
# if my_data.get('use_8bit_adam', False) == True:
# my_data['optimizer'] = 'AdamW8bit'
# # my_data['use_8bit_adam'] = False
# if (
# my_data.get('optimizer', 'missing') == 'missing'
# and my_data.get('use_8bit_adam', False) == False
# ):
# my_data['optimizer'] = 'AdamW'
# if my_data.get('model_list', 'custom') == []:
# print('Old config with empty model list. Setting to custom...')
# my_data['model_list'] = 'custom'
# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
# my_data['model_list'] = 'custom'
# # Fix old config files that contain epoch as str instead of int
# for key in ['epoch', 'save_every_n_epochs']:
# value = my_data.get(key, -1)
# if type(value) == str:
# if value != '':
# my_data[key] = int(value)
# else:
# my_data[key] = -1
# if my_data.get('LoRA_type', 'Standard') == 'LoCon':
# my_data['LoRA_type'] = 'LyCORIS/LoCon'
# return my_data
def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name)
@ -604,7 +578,13 @@ def get_pretrained_model_name_or_path_file(
set_model_list(model_list, pretrained_model_name_or_path)
def gradio_source_model():
def gradio_source_model(save_model_as_choices = [
'same as source model',
'ckpt',
'diffusers',
'diffusers_safetensors',
'safetensors',
]):
with gr.Tab('Source model'):
# Define the input elements
with gr.Row():
@ -646,13 +626,7 @@ def gradio_source_model():
)
save_model_as = gr.Dropdown(
label='Save trained model as',
choices=[
'same as source model',
'ckpt',
'diffusers',
'diffusers_safetensors',
'safetensors',
],
choices=save_model_as_choices,
value='safetensors',
)

View File

@ -257,7 +257,8 @@ def open_configuration(
with open(file_path, 'r') as f:
my_data = json.load(f)
print('Loading config...')
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
# Update values to fix deprecated use_8bit_adam checkbox, set appropriate optimizer if it is set to True, etc.
my_data = update_my_data(my_data)
else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
@ -648,7 +649,10 @@ def lora_tab(
v_parameterization,
save_model_as,
model_list,
) = gradio_source_model()
) = gradio_source_model(save_model_as_choices = [
'ckpt',
'safetensors',
])
with gr.Tab('Folders'):
with gr.Row():

View File

@ -570,7 +570,10 @@ def ti_tab(
v_parameterization,
save_model_as,
model_list,
) = gradio_source_model()
) = gradio_source_model(save_model_as_choices = [
'ckpt',
'safetensors',
])
with gr.Tab('Folders'):
with gr.Row():