Fix issue with msgbox
This commit is contained in:
parent
d0ffba6650
commit
7e7a8b6aab
@ -1,4 +1,5 @@
|
||||
from tkinter import filedialog, Tk
|
||||
from easygui import msgbox
|
||||
import os
|
||||
import gradio as gr
|
||||
import easygui
|
||||
@ -60,28 +61,20 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
|
||||
|
||||
def update_my_data(my_data):
|
||||
# Update optimizer based on use_8bit_adam flag
|
||||
# Update the optimizer based on the use_8bit_adam flag
|
||||
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
||||
if use_8bit_adam:
|
||||
my_data['optimizer'] = 'AdamW8bit'
|
||||
elif 'optimizer' not in my_data:
|
||||
my_data['optimizer'] = 'AdamW'
|
||||
my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')
|
||||
|
||||
# Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
|
||||
model_list = my_data.get('model_list', [])
|
||||
pretrained_model_name_or_path = my_data.get(
|
||||
'pretrained_model_name_or_path', ''
|
||||
)
|
||||
if (
|
||||
not model_list
|
||||
or pretrained_model_name_or_path not in ALL_PRESET_MODELS
|
||||
):
|
||||
pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
|
||||
if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
|
||||
my_data['model_list'] = 'custom'
|
||||
|
||||
# Convert epoch and save_every_n_epochs values to int if they are strings
|
||||
for key in ['epoch', 'save_every_n_epochs']:
|
||||
value = my_data.get(key, -1)
|
||||
if isinstance(value, str) and value:
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
my_data[key] = int(value)
|
||||
elif not value:
|
||||
my_data[key] = -1
|
||||
@ -89,13 +82,19 @@ def update_my_data(my_data):
|
||||
# Update LoRA_type if it is set to LoCon
|
||||
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
|
||||
my_data['LoRA_type'] = 'LyCORIS/LoCon'
|
||||
|
||||
|
||||
# Update model save choices due to changes for LoRA and TI training
|
||||
if (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
|
||||
if (
|
||||
(my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
|
||||
and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
|
||||
):
|
||||
message = (
|
||||
'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
|
||||
)
|
||||
if my_data.get('LoRA_type'):
|
||||
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to LoRA')
|
||||
print(message.format('LoRA'))
|
||||
if my_data.get('num_vectors_per_token'):
|
||||
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to TI')
|
||||
print(message.format('TI'))
|
||||
my_data['save_model_as'] = 'safetensors'
|
||||
|
||||
return my_data
|
||||
|
Loading…
Reference in New Issue
Block a user