Fix issues with code detected while testing

This commit is contained in:
bmaltais 2023-03-08 21:16:54 -05:00
parent 35c1d42570
commit 2eb7b3bdc3
4 changed files with 24 additions and 11 deletions

View File

@ -224,19 +224,19 @@ def open_config_file(
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
my_data_db = json.load(f) my_data = json.load(f)
print('Loading config...') print('Loading config...')
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
my_data = update_my_data(my_data) my_data = update_my_data(my_data)
else: else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
my_data_db = {} my_data = {}
values = [file_path] values = [file_path]
for key, value in parameters: for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']: if not key in ['file_path']:
values.append(my_data_db.get(key, value)) values.append(my_data.get(key, value))
return tuple(values) return tuple(values)

View File

@ -50,6 +50,15 @@ def update_my_data(my_data):
if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS: if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
my_data['model_list'] = 'custom' my_data['model_list'] = 'custom'
# Fix old config files that contain epoch as str instead of int
for key in ['epoch', 'save_every_n_epochs']:
value = my_data.get(key, -1)
if type(value) == str:
if value != '':
my_data[key] = int(value)
else:
my_data[key] = -1
return my_data return my_data
@ -526,8 +535,8 @@ def gradio_training(
value=1, value=1,
step=1, step=1,
) )
epoch = gr.Textbox(label='Epoch', value=1) epoch = gr.Number(label='Epoch', value=1, precision=0)
save_every_n_epochs = gr.Textbox(label='Save every N epochs', value=1) save_every_n_epochs = gr.Number(label='Save every N epochs', value=1, precision=0)
caption_extension = gr.Textbox( caption_extension = gr.Textbox(
label='Caption Extension', label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption', placeholder='(Optional) Extension for caption files. default: .caption',
@ -634,8 +643,8 @@ def run_cmd_training(**kwargs):
f' --max_train_steps="{kwargs.get("max_train_steps", "")}"' f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
if kwargs.get('max_train_steps') if kwargs.get('max_train_steps')
else '', else '',
f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"' f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
if kwargs.get('save_every_n_epochs') if int(kwargs.get('save_every_n_epochs'))
else '', else '',
f' --mixed_precision="{kwargs.get("mixed_precision", "")}"' f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
if kwargs.get('mixed_precision') if kwargs.get('mixed_precision')
@ -643,7 +652,9 @@ def run_cmd_training(**kwargs):
f' --save_precision="{kwargs.get("save_precision", "")}"' f' --save_precision="{kwargs.get("save_precision", "")}"'
if kwargs.get('save_precision') if kwargs.get('save_precision')
else '', else '',
f' --seed="{kwargs.get("seed", "")}"' if kwargs.get('seed') else '', f' --seed="{kwargs.get("seed", "")}"'
if kwargs.get('seed') != ""
else '',
f' --caption_extension="{kwargs.get("caption_extension", "")}"' f' --caption_extension="{kwargs.get("caption_extension", "")}"'
if kwargs.get('caption_extension') if kwargs.get('caption_extension')
else '', else '',

View File

@ -232,19 +232,19 @@ def open_configuration(
if not file_path == '' and not file_path == None: if not file_path == '' and not file_path == None:
# load variables from JSON file # load variables from JSON file
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
my_data_db = json.load(f) my_data = json.load(f)
print('Loading config...') print('Loading config...')
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
my_data = update_my_data(my_data) my_data = update_my_data(my_data)
else: else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
my_data_db = {} my_data = {}
values = [file_path] values = [file_path]
for key, value in parameters: for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']: if not key in ['file_path']:
values.append(my_data_db.get(key, value)) values.append(my_data.get(key, value))
return tuple(values) return tuple(values)

View File

@ -172,6 +172,8 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
prompt_replacement = None
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template: if use_template:
print("use template for training captions. is object: {args.use_object_template}") print("use template for training captions. is object: {args.use_object_template}")