Fix issue with msgbox

bmaltais 2023-03-25 12:39:02 -04:00
parent d0ffba6650
commit 7e7a8b6aab
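The fix itself is a single added import: the file already does `import easygui`, but it presumably calls `msgbox` by its bare name, so `from easygui import msgbox` is needed for that name to resolve. A minimal sketch of the call pattern the import supports; the helper name and message text below are illustrative, not taken from the commit:

    from easygui import msgbox

    def warn_model_exists(output_name):
        # Pop a blocking easygui dialog with a single OK button.
        # easygui.msgbox(msg, title=...) is the standard signature.
        msgbox(
            f'A model named "{output_name}" already exists and may be overwritten.',
            title='Model exists',
        )

    if __name__ == '__main__':
        warn_model_exists('my_lora_v1')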


@@ -1,4 +1,5 @@
 from tkinter import filedialog, Tk
+from easygui import msgbox
 import os
 import gradio as gr
 import easygui
@@ -60,28 +61,20 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
 def update_my_data(my_data):
-    # Update optimizer based on use_8bit_adam flag
+    # Update the optimizer based on the use_8bit_adam flag
     use_8bit_adam = my_data.get('use_8bit_adam', False)
-    if use_8bit_adam:
-        my_data['optimizer'] = 'AdamW8bit'
-    elif 'optimizer' not in my_data:
-        my_data['optimizer'] = 'AdamW'
+    my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')
 
     # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
     model_list = my_data.get('model_list', [])
-    pretrained_model_name_or_path = my_data.get(
-        'pretrained_model_name_or_path', ''
-    )
-    if (
-        not model_list
-        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
-    ):
+    pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
+    if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
         my_data['model_list'] = 'custom'
 
     # Convert epoch and save_every_n_epochs values to int if they are strings
     for key in ['epoch', 'save_every_n_epochs']:
         value = my_data.get(key, -1)
-        if isinstance(value, str) and value:
+        if isinstance(value, str) and value.isdigit():
             my_data[key] = int(value)
         elif not value:
             my_data[key] = -1
@@ -89,13 +82,19 @@ def update_my_data(my_data):
     # Update LoRA_type if it is set to LoCon
     if my_data.get('LoRA_type', 'Standard') == 'LoCon':
         my_data['LoRA_type'] = 'LyCORIS/LoCon'
 
     # Update model save choices due to changes for LoRA and TI training
-    if (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
+    if (
+        (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
+        and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
+    ):
+        message = (
+            'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
+        )
         if my_data.get('LoRA_type'):
-            print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to LoRA')
+            print(message.format('LoRA'))
         if my_data.get('num_vectors_per_token'):
-            print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to TI')
+            print(message.format('TI'))
         my_data['save_model_as'] = 'safetensors'
 
     return my_data