Fix issue with msgbox

This commit is contained in:
bmaltais 2023-03-25 12:39:02 -04:00
parent d0ffba6650
commit 7e7a8b6aab

View File

@ -1,4 +1,5 @@
from tkinter import filedialog, Tk
from easygui import msgbox
import os
import gradio as gr
import easygui
@ -60,28 +61,20 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
def update_my_data(my_data):
# Update optimizer based on use_8bit_adam flag
# Update the optimizer based on the use_8bit_adam flag
use_8bit_adam = my_data.get('use_8bit_adam', False)
if use_8bit_adam:
my_data['optimizer'] = 'AdamW8bit'
elif 'optimizer' not in my_data:
my_data['optimizer'] = 'AdamW'
my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')
# Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
model_list = my_data.get('model_list', [])
pretrained_model_name_or_path = my_data.get(
'pretrained_model_name_or_path', ''
)
if (
not model_list
or pretrained_model_name_or_path not in ALL_PRESET_MODELS
):
pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
my_data['model_list'] = 'custom'
# Convert epoch and save_every_n_epochs values to int if they are strings
for key in ['epoch', 'save_every_n_epochs']:
value = my_data.get(key, -1)
if isinstance(value, str) and value:
if isinstance(value, str) and value.isdigit():
my_data[key] = int(value)
elif not value:
my_data[key] = -1
@ -89,13 +82,19 @@ def update_my_data(my_data):
# Update LoRA_type if it is set to LoCon
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
my_data['LoRA_type'] = 'LyCORIS/LoCon'
# Update model save choices due to changes for LoRA and TI training
if (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
if (
(my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
):
message = (
'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
)
if my_data.get('LoRA_type'):
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to LoRA')
print(message.format('LoRA'))
if my_data.get('num_vectors_per_token'):
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to TI')
print(message.format('TI'))
my_data['save_model_as'] = 'safetensors'
return my_data