Fix issues with code detected while testing
This commit is contained in:
parent
35c1d42570
commit
2eb7b3bdc3
@ -224,19 +224,19 @@ def open_config_file(
|
||||
if not file_path == '' and not file_path == None:
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
my_data_db = json.load(f)
|
||||
my_data = json.load(f)
|
||||
print('Loading config...')
|
||||
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
||||
my_data = update_my_data(my_data)
|
||||
else:
|
||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||
my_data_db = {}
|
||||
my_data = {}
|
||||
|
||||
values = [file_path]
|
||||
for key, value in parameters:
|
||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||
if not key in ['file_path']:
|
||||
values.append(my_data_db.get(key, value))
|
||||
values.append(my_data.get(key, value))
|
||||
return tuple(values)
|
||||
|
||||
|
||||
|
@ -49,6 +49,15 @@ def update_my_data(my_data):
|
||||
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
||||
if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
|
||||
my_data['model_list'] = 'custom'
|
||||
|
||||
# Fix old config files that contain epoch as str instead of int
|
||||
for key in ['epoch', 'save_every_n_epochs']:
|
||||
value = my_data.get(key, -1)
|
||||
if type(value) == str:
|
||||
if value != '':
|
||||
my_data[key] = int(value)
|
||||
else:
|
||||
my_data[key] = -1
|
||||
|
||||
return my_data
|
||||
|
||||
@ -526,8 +535,8 @@ def gradio_training(
|
||||
value=1,
|
||||
step=1,
|
||||
)
|
||||
epoch = gr.Textbox(label='Epoch', value=1)
|
||||
save_every_n_epochs = gr.Textbox(label='Save every N epochs', value=1)
|
||||
epoch = gr.Number(label='Epoch', value=1, precision=0)
|
||||
save_every_n_epochs = gr.Number(label='Save every N epochs', value=1, precision=0)
|
||||
caption_extension = gr.Textbox(
|
||||
label='Caption Extension',
|
||||
placeholder='(Optional) Extension for caption files. default: .caption',
|
||||
@ -634,8 +643,8 @@ def run_cmd_training(**kwargs):
|
||||
f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
|
||||
if kwargs.get('max_train_steps')
|
||||
else '',
|
||||
f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"'
|
||||
if kwargs.get('save_every_n_epochs')
|
||||
f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
|
||||
if int(kwargs.get('save_every_n_epochs'))
|
||||
else '',
|
||||
f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
|
||||
if kwargs.get('mixed_precision')
|
||||
@ -643,7 +652,9 @@ def run_cmd_training(**kwargs):
|
||||
f' --save_precision="{kwargs.get("save_precision", "")}"'
|
||||
if kwargs.get('save_precision')
|
||||
else '',
|
||||
f' --seed="{kwargs.get("seed", "")}"' if kwargs.get('seed') else '',
|
||||
f' --seed="{kwargs.get("seed", "")}"'
|
||||
if kwargs.get('seed') != ""
|
||||
else '',
|
||||
f' --caption_extension="{kwargs.get("caption_extension", "")}"'
|
||||
if kwargs.get('caption_extension')
|
||||
else '',
|
||||
|
@ -232,19 +232,19 @@ def open_configuration(
|
||||
if not file_path == '' and not file_path == None:
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
my_data_db = json.load(f)
|
||||
my_data = json.load(f)
|
||||
print('Loading config...')
|
||||
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
||||
my_data = update_my_data(my_data)
|
||||
else:
|
||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||
my_data_db = {}
|
||||
my_data = {}
|
||||
|
||||
values = [file_path]
|
||||
for key, value in parameters:
|
||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||
if not key in ['file_path']:
|
||||
values.append(my_data_db.get(key, value))
|
||||
values.append(my_data.get(key, value))
|
||||
return tuple(values)
|
||||
|
||||
|
||||
|
@ -172,6 +172,8 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
prompt_replacement = None
|
||||
|
||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||
if use_template:
|
||||
print("use template for training captions. is object: {args.use_object_template}")
|
||||
|
Loading…
Reference in New Issue
Block a user