Fix issue with msgbox
This commit is contained in:
parent
d0ffba6650
commit
7e7a8b6aab
@ -1,4 +1,5 @@
|
|||||||
from tkinter import filedialog, Tk
|
from tkinter import filedialog, Tk
|
||||||
|
from easygui import msgbox
|
||||||
import os
|
import os
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import easygui
|
import easygui
|
||||||
@ -60,28 +61,20 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
|
|||||||
|
|
||||||
|
|
||||||
def update_my_data(my_data):
|
def update_my_data(my_data):
|
||||||
# Update optimizer based on use_8bit_adam flag
|
# Update the optimizer based on the use_8bit_adam flag
|
||||||
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
||||||
if use_8bit_adam:
|
my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')
|
||||||
my_data['optimizer'] = 'AdamW8bit'
|
|
||||||
elif 'optimizer' not in my_data:
|
|
||||||
my_data['optimizer'] = 'AdamW'
|
|
||||||
|
|
||||||
# Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
|
# Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
|
||||||
model_list = my_data.get('model_list', [])
|
model_list = my_data.get('model_list', [])
|
||||||
pretrained_model_name_or_path = my_data.get(
|
pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
|
||||||
'pretrained_model_name_or_path', ''
|
if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
|
||||||
)
|
|
||||||
if (
|
|
||||||
not model_list
|
|
||||||
or pretrained_model_name_or_path not in ALL_PRESET_MODELS
|
|
||||||
):
|
|
||||||
my_data['model_list'] = 'custom'
|
my_data['model_list'] = 'custom'
|
||||||
|
|
||||||
# Convert epoch and save_every_n_epochs values to int if they are strings
|
# Convert epoch and save_every_n_epochs values to int if they are strings
|
||||||
for key in ['epoch', 'save_every_n_epochs']:
|
for key in ['epoch', 'save_every_n_epochs']:
|
||||||
value = my_data.get(key, -1)
|
value = my_data.get(key, -1)
|
||||||
if isinstance(value, str) and value:
|
if isinstance(value, str) and value.isdigit():
|
||||||
my_data[key] = int(value)
|
my_data[key] = int(value)
|
||||||
elif not value:
|
elif not value:
|
||||||
my_data[key] = -1
|
my_data[key] = -1
|
||||||
@ -89,13 +82,19 @@ def update_my_data(my_data):
|
|||||||
# Update LoRA_type if it is set to LoCon
|
# Update LoRA_type if it is set to LoCon
|
||||||
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
|
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
|
||||||
my_data['LoRA_type'] = 'LyCORIS/LoCon'
|
my_data['LoRA_type'] = 'LyCORIS/LoCon'
|
||||||
|
|
||||||
# Update model save choices due to changes for LoRA and TI training
|
# Update model save choices due to changes for LoRA and TI training
|
||||||
if (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
|
if (
|
||||||
|
(my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
|
||||||
|
and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
|
||||||
|
):
|
||||||
|
message = (
|
||||||
|
'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
|
||||||
|
)
|
||||||
if my_data.get('LoRA_type'):
|
if my_data.get('LoRA_type'):
|
||||||
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to LoRA')
|
print(message.format('LoRA'))
|
||||||
if my_data.get('num_vectors_per_token'):
|
if my_data.get('num_vectors_per_token'):
|
||||||
print('Updating save_model_as to safetensors because the current value in config file is no longer applicable to TI')
|
print(message.format('TI'))
|
||||||
my_data['save_model_as'] = 'safetensors'
|
my_data['save_model_as'] = 'safetensors'
|
||||||
|
|
||||||
return my_data
|
return my_data
|
||||||
|
Loading…
Reference in New Issue
Block a user