diff --git a/library/common_gui.py b/library/common_gui.py
index ec83c22..8ff6a27 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -1,4 +1,5 @@
 from tkinter import filedialog, Tk
+from easygui import msgbox
 import os
 import gradio as gr
 import easygui
@@ -60,27 +61,23 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
 
 def update_my_data(my_data):
-    # Update optimizer based on use_8bit_adam flag
+    # Update the optimizer based on the use_8bit_adam flag
     use_8bit_adam = my_data.get('use_8bit_adam', False)
-    if use_8bit_adam:
-        my_data['optimizer'] = 'AdamW8bit'
-    elif 'optimizer' not in my_data:
-        my_data['optimizer'] = 'AdamW'
+    if use_8bit_adam:
+        my_data['optimizer'] = 'AdamW8bit'
+    else:
+        # setdefault keeps an optimizer the user already saved in the config
+        my_data.setdefault('optimizer', 'AdamW')
 
     # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
     model_list = my_data.get('model_list', [])
-    pretrained_model_name_or_path = my_data.get(
-        'pretrained_model_name_or_path', ''
-    )
-    if (
-        not model_list
-        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
-    ):
+    pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
+    if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
         my_data['model_list'] = 'custom'
 
     # Convert epoch and save_every_n_epochs values to int if they are strings
     for key in ['epoch', 'save_every_n_epochs']:
         value = my_data.get(key, -1)
-        if isinstance(value, str) and value:
+        if isinstance(value, str) and value.isdigit():
             my_data[key] = int(value)
         elif not value:
             my_data[key] = -1
@@ -89,12 +86,18 @@ def update_my_data(my_data):
     # Update LoRA_type if it is set to LoCon
     if my_data.get('LoRA_type', 'Standard') == 'LoCon':
         my_data['LoRA_type'] = 'LyCORIS/LoCon'
 
     # Update model save choices due to changes for LoRA and TI training
-    if (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
+    if (
+        (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
+        and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
+    ):
+        message = (
+            'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
+        )
         if my_data.get('LoRA_type'):
-            print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to LoRA')
+            print(message.format('LoRA'))
         if my_data.get('num_vectors_per_token'):
-            print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to TI')
+            print(message.format('TI'))
         my_data['save_model_as'] = 'safetensors'
     return my_data