Fix issues with code detected while testing
parent 35c1d42570
commit 2eb7b3bdc3
@@ -224,19 +224,19 @@ def open_config_file(
     if not file_path == '' and not file_path == None:
         # load variables from JSON file
         with open(file_path, 'r') as f:
-            my_data_db = json.load(f)
+            my_data = json.load(f)
             print('Loading config...')
             # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
             my_data = update_my_data(my_data)
     else:
         file_path = original_file_path  # In case a file_path was provided and the user decide to cancel the open action
-        my_data_db = {}
+        my_data = {}

     values = [file_path]
     for key, value in parameters:
         # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
         if not key in ['file_path']:
-            values.append(my_data_db.get(key, value))
+            values.append(my_data.get(key, value))
     return tuple(values)

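For context (an illustrative sketch, not part of the commit): the old version of this hunk mixed two names, loading the JSON into my_data_db but then calling update_my_data(my_data) and, in the else branch, initializing my_data_db while the lookup below uses my_data, so loading a config would fail with a NameError. The fix standardizes on my_data, which is then the dictionary that .get() falls back on. The helper and parameter list below are hypothetical stand-ins, not repository code.

import json
import tempfile

# Hypothetical (key, default) pairs standing in for the `parameters` iterable.
parameters = [('file_path', ''), ('epoch', 1), ('save_every_n_epochs', 1)]

def load_values(file_path):
    # Mirror of the fixed logic: load JSON if a path was chosen, else fall back to {}.
    if file_path not in ('', None):
        with open(file_path, 'r') as f:
            my_data = json.load(f)
    else:
        my_data = {}
    values = [file_path]
    for key, default in parameters:
        if key != 'file_path':
            # dict.get() returns the stored value when present, the UI default otherwise.
            values.append(my_data.get(key, default))
    return tuple(values)

# Quick check with a throwaway config file.
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump({'epoch': 5}, f)
print(load_values(f.name))  # (<path>, 5, 1): stored epoch, default save_every_n_epochs
print(load_values(''))      # ('', 1, 1): all defaults when no file was chosen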
@@ -49,6 +49,15 @@ def update_my_data(my_data):
     # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
     if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
         my_data['model_list'] = 'custom'

+    # Fix old config files that contain epoch as str instead of int
+    for key in ['epoch', 'save_every_n_epochs']:
+        value = my_data.get(key, -1)
+        if type(value) == str:
+            if value != '':
+                my_data[key] = int(value)
+            else:
+                my_data[key] = -1
+
     return my_data

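As a rough standalone illustration (not repository code), the added loop normalizes configs saved by the older Textbox-based UI, where numeric fields came back as strings; the sample config below is hypothetical.

# Hypothetical old-style config as the previous Textbox widgets would have saved it.
old_config = {'epoch': '10', 'save_every_n_epochs': '', 'learning_rate': 1e-5}

def normalize(my_data):
    # Same idea as the added block: coerce string values to int, map '' to -1.
    for key in ['epoch', 'save_every_n_epochs']:
        value = my_data.get(key, -1)
        if isinstance(value, str):
            my_data[key] = int(value) if value != '' else -1
    return my_data

print(normalize(dict(old_config)))
# {'epoch': 10, 'save_every_n_epochs': -1, 'learning_rate': 1e-05}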
@@ -526,8 +535,8 @@ def gradio_training(
             value=1,
             step=1,
         )
-        epoch = gr.Textbox(label='Epoch', value=1)
-        save_every_n_epochs = gr.Textbox(label='Save every N epochs', value=1)
+        epoch = gr.Number(label='Epoch', value=1, precision=0)
+        save_every_n_epochs = gr.Number(label='Save every N epochs', value=1, precision=0)
         caption_extension = gr.Textbox(
             label='Caption Extension',
             placeholder='(Optional) Extension for caption files. default: .caption',
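Switching from gr.Textbox to gr.Number means these fields hand back numeric values instead of strings, and precision=0 rounds them to integers. A minimal self-contained sketch, assuming Gradio 3.x (where Number accepts a precision argument); the echo function and layout are illustrative only.

import gradio as gr

def echo(epoch, save_every_n_epochs):
    # Both values arrive as numbers, so no manual casting is needed here.
    return f'epoch={epoch!r}, save_every_n_epochs={save_every_n_epochs!r}'

with gr.Blocks() as demo:
    epoch = gr.Number(label='Epoch', value=1, precision=0)
    save_every_n_epochs = gr.Number(label='Save every N epochs', value=1, precision=0)
    out = gr.Textbox(label='Result')
    gr.Button('Show').click(echo, inputs=[epoch, save_every_n_epochs], outputs=out)

if __name__ == '__main__':
    demo.launch()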
@@ -634,8 +643,8 @@ def run_cmd_training(**kwargs):
         f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
         if kwargs.get('max_train_steps')
         else '',
-        f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"'
-        if kwargs.get('save_every_n_epochs')
+        f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
+        if int(kwargs.get('save_every_n_epochs'))
         else '',
         f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
         if kwargs.get('mixed_precision')
@@ -643,7 +652,9 @@ def run_cmd_training(**kwargs):
         f' --save_precision="{kwargs.get("save_precision", "")}"'
         if kwargs.get('save_precision')
         else '',
-        f' --seed="{kwargs.get("seed", "")}"' if kwargs.get('seed') else '',
+        f' --seed="{kwargs.get("seed", "")}"'
+        if kwargs.get('seed') != ""
+        else '',
         f' --caption_extension="{kwargs.get("caption_extension", "")}"'
         if kwargs.get('caption_extension')
         else '',
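A small standalone sketch of the conditional flag-building pattern these two hunks adjust (the run_flags helper is illustrative, not the repository function). Comparing seed against '' instead of relying on truthiness keeps a seed of 0 from being silently dropped, and wrapping save_every_n_epochs in int() presumably guards against the value arriving as a float or string and being rendered as something like 1.0 in the command line.

def run_flags(**kwargs):
    # Illustrative mirror of the pattern: each fragment is either a CLI flag or ''.
    fragments = [
        f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
        if int(kwargs.get('save_every_n_epochs'))
        else '',
        f' --seed="{kwargs.get("seed", "")}"'
        if kwargs.get('seed') != ''
        else '',
    ]
    return ''.join(fragments)

print(run_flags(save_every_n_epochs=1.0, seed=0))
# ' --save_every_n_epochs="1" --seed="0"'  (seed=0 is kept; 1.0 is rendered as 1)
print(run_flags(save_every_n_epochs=0, seed=''))
# ''  (both fragments collapse to empty strings)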
@@ -232,19 +232,19 @@ def open_configuration(
     if not file_path == '' and not file_path == None:
         # load variables from JSON file
         with open(file_path, 'r') as f:
-            my_data_db = json.load(f)
+            my_data = json.load(f)
             print('Loading config...')
             # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
             my_data = update_my_data(my_data)
     else:
         file_path = original_file_path  # In case a file_path was provided and the user decide to cancel the open action
-        my_data_db = {}
+        my_data = {}

     values = [file_path]
     for key, value in parameters:
         # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
         if not key in ['file_path']:
-            values.append(my_data_db.get(key, value))
+            values.append(my_data.get(key, value))
     return tuple(values)

@@ -172,6 +172,8 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

+    prompt_replacement = None
+
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
         print("use template for training captions. is object: {args.use_object_template}")
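(The Japanese comment on the context line translates roughly as: "make captions: a very crude implementation that rewrites the caption into the string 'tokenstring tokenstring1 tokenstring2 ... tokenstringn'.") Assigning prompt_replacement = None unconditionally presumably guarantees the name exists on every code path before it is read further down; a tiny illustrative sketch of the failure mode being avoided (the variables and tuple below are hypothetical):

use_template = False

# Without the unconditional assignment, the name only exists on some branches...
if use_template:
    prompt_replacement = ('token', 'replacement')

try:
    print(prompt_replacement)
except NameError as e:
    print(f'NameError without the fix: {e}')

# ...whereas assigning None first makes later reads safe on every path.
prompt_replacement = None
if use_template:
    prompt_replacement = ('token', 'replacement')
print(prompt_replacement)  # None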