Fix basic captioning logic
This commit is contained in:
parent
7a94c523f5
commit
baf009d2b1
@ -189,6 +189,8 @@ This will store your a backup file with your current locally installed pip packa
|
||||
|
||||
## Change History
|
||||
|
||||
* 2023/03/16 (v21.2.5):
|
||||
- Fix basic captioning logic
|
||||
* 2023/03/12 (v21.2.4):
|
||||
- Fix issue with kohya locon not training the convolution layers
|
||||
- Update LyCORIS module version
|
||||
|
@ -6,35 +6,33 @@ import os
|
||||
|
||||
|
||||
def caption_images(
|
||||
caption_text_input,
|
||||
images_dir_input,
|
||||
overwrite_input,
|
||||
caption_file_ext,
|
||||
caption_text,
|
||||
images_dir,
|
||||
overwrite,
|
||||
caption_ext,
|
||||
prefix,
|
||||
postfix,
|
||||
find,
|
||||
replace,
|
||||
find_text,
|
||||
replace_text,
|
||||
):
|
||||
# Check for images_dir_input
|
||||
if images_dir_input == '':
|
||||
# Check for images_dir
|
||||
if not images_dir:
|
||||
msgbox('Image folder is missing...')
|
||||
return
|
||||
|
||||
if caption_file_ext == '':
|
||||
if not caption_ext:
|
||||
msgbox('Please provide an extension for the caption files.')
|
||||
return
|
||||
|
||||
if not caption_text_input == '':
|
||||
print(
|
||||
f'Captioning files in {images_dir_input} with {caption_text_input}...'
|
||||
)
|
||||
if caption_text:
|
||||
print(f'Captioning files in {images_dir} with {caption_text}...')
|
||||
run_cmd = f'python "tools/caption.py"'
|
||||
run_cmd += f' --caption_text="{caption_text_input}"'
|
||||
if overwrite_input:
|
||||
run_cmd += f' --caption_text="{caption_text}"'
|
||||
if overwrite:
|
||||
run_cmd += f' --overwrite'
|
||||
if caption_file_ext != '':
|
||||
run_cmd += f' --caption_file_ext="{caption_file_ext}"'
|
||||
run_cmd += f' "{images_dir_input}"'
|
||||
if caption_ext:
|
||||
run_cmd += f' --caption_file_ext="{caption_ext}"'
|
||||
run_cmd += f' "{images_dir}"'
|
||||
|
||||
print(run_cmd)
|
||||
|
||||
@ -44,24 +42,24 @@ def caption_images(
|
||||
else:
|
||||
subprocess.run(run_cmd)
|
||||
|
||||
if overwrite_input:
|
||||
if not prefix == '' or not postfix == '':
|
||||
if overwrite:
|
||||
if prefix or postfix:
|
||||
# Add prefix and postfix
|
||||
add_pre_postfix(
|
||||
folder=images_dir_input,
|
||||
caption_file_ext=caption_file_ext,
|
||||
folder=images_dir,
|
||||
caption_file_ext=caption_ext,
|
||||
prefix=prefix,
|
||||
postfix=postfix,
|
||||
)
|
||||
if not find == '':
|
||||
if find_text:
|
||||
find_replace(
|
||||
folder=images_dir_input,
|
||||
caption_file_ext=caption_file_ext,
|
||||
find=find,
|
||||
replace=replace,
|
||||
folder=images_dir,
|
||||
caption_file_ext=caption_ext,
|
||||
find=find_text,
|
||||
replace=replace_text,
|
||||
)
|
||||
else:
|
||||
if not prefix == '' or not postfix == '':
|
||||
if prefix or postfix:
|
||||
msgbox(
|
||||
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
|
||||
)
|
||||
@ -69,37 +67,31 @@ def caption_images(
|
||||
print('...captioning done')
|
||||
|
||||
|
||||
###
|
||||
# Gradio UI
|
||||
###
|
||||
|
||||
|
||||
def gradio_basic_caption_gui_tab():
|
||||
with gr.Tab('Basic Captioning'):
|
||||
gr.Markdown(
|
||||
'This utility will allow the creation of simple caption files for each images in a folder.'
|
||||
'This utility will allow the creation of simple caption files for each image in a folder.'
|
||||
)
|
||||
with gr.Row():
|
||||
images_dir_input = gr.Textbox(
|
||||
images_dir = gr.Textbox(
|
||||
label='Image folder to caption',
|
||||
placeholder='Directory containing the images to caption',
|
||||
interactive=True,
|
||||
)
|
||||
button_images_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
button_images_dir_input.click(
|
||||
folder_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
folder_button.click(
|
||||
get_folder_path,
|
||||
outputs=images_dir_input,
|
||||
outputs=images_dir,
|
||||
show_progress=False,
|
||||
)
|
||||
caption_file_ext = gr.Textbox(
|
||||
caption_ext = gr.Textbox(
|
||||
label='Caption file extension',
|
||||
placeholder='Extention for caption file. eg: .caption, .txt',
|
||||
placeholder='Extension for caption file. eg: .caption, .txt',
|
||||
value='.txt',
|
||||
interactive=True,
|
||||
)
|
||||
overwrite_input = gr.Checkbox(
|
||||
overwrite = gr.Checkbox(
|
||||
label='Overwrite existing captions in folder',
|
||||
interactive=True,
|
||||
value=False,
|
||||
@ -110,7 +102,7 @@ def gradio_basic_caption_gui_tab():
|
||||
placeholder='(Optional)',
|
||||
interactive=True,
|
||||
)
|
||||
caption_text_input = gr.Textbox(
|
||||
caption_text = gr.Textbox(
|
||||
label='Caption text',
|
||||
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
|
||||
interactive=True,
|
||||
@ -121,29 +113,28 @@ def gradio_basic_caption_gui_tab():
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
find = gr.Textbox(
|
||||
find_text = gr.Textbox(
|
||||
label='Find text',
|
||||
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
|
||||
interactive=True,
|
||||
)
|
||||
replace = gr.Textbox(
|
||||
replace_text = gr.Textbox(
|
||||
label='Replacement text',
|
||||
placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing',
|
||||
interactive=True,
|
||||
)
|
||||
caption_button = gr.Button('Caption images')
|
||||
caption_button.click(
|
||||
caption_images,
|
||||
inputs=[
|
||||
caption_text,
|
||||
images_dir,
|
||||
overwrite,
|
||||
caption_ext,
|
||||
prefix,
|
||||
postfix,
|
||||
find_text,
|
||||
replace_text,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
caption_button = gr.Button('Caption images')
|
||||
|
||||
caption_button.click(
|
||||
caption_images,
|
||||
inputs=[
|
||||
caption_text_input,
|
||||
images_dir_input,
|
||||
overwrite_input,
|
||||
caption_file_ext,
|
||||
prefix,
|
||||
postfix,
|
||||
find,
|
||||
replace,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
@ -32,77 +32,112 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
|
||||
|
||||
|
||||
def update_my_data(my_data):
|
||||
if my_data.get('use_8bit_adam', False) == True:
|
||||
# Update optimizer based on use_8bit_adam flag
|
||||
use_8bit_adam = my_data.get('use_8bit_adam', False)
|
||||
if use_8bit_adam:
|
||||
my_data['optimizer'] = 'AdamW8bit'
|
||||
# my_data['use_8bit_adam'] = False
|
||||
|
||||
if (
|
||||
my_data.get('optimizer', 'missing') == 'missing'
|
||||
and my_data.get('use_8bit_adam', False) == False
|
||||
):
|
||||
elif 'optimizer' not in my_data:
|
||||
my_data['optimizer'] = 'AdamW'
|
||||
|
||||
if my_data.get('model_list', 'custom') == []:
|
||||
print('Old config with empty model list. Setting to custom...')
|
||||
my_data['model_list'] = 'custom'
|
||||
|
||||
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
||||
if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
|
||||
my_data['model_list'] = 'custom'
|
||||
|
||||
# Fix old config files that contain epoch as str instead of int
|
||||
# Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
|
||||
model_list = my_data.get('model_list', [])
|
||||
pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
|
||||
if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
|
||||
my_data['model_list'] = 'custom'
|
||||
|
||||
# Convert epoch and save_every_n_epochs values to int if they are strings
|
||||
for key in ['epoch', 'save_every_n_epochs']:
|
||||
value = my_data.get(key, -1)
|
||||
if type(value) == str:
|
||||
if value != '':
|
||||
my_data[key] = int(value)
|
||||
else:
|
||||
my_data[key] = -1
|
||||
|
||||
if isinstance(value, str) and value:
|
||||
my_data[key] = int(value)
|
||||
elif not value:
|
||||
my_data[key] = -1
|
||||
|
||||
# Update LoRA_type if it is set to LoCon
|
||||
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
|
||||
my_data['LoRA_type'] = 'LyCORIS/LoCon'
|
||||
|
||||
|
||||
return my_data
|
||||
|
||||
|
||||
# def update_my_data(my_data):
|
||||
# if my_data.get('use_8bit_adam', False) == True:
|
||||
# my_data['optimizer'] = 'AdamW8bit'
|
||||
# # my_data['use_8bit_adam'] = False
|
||||
|
||||
# if (
|
||||
# my_data.get('optimizer', 'missing') == 'missing'
|
||||
# and my_data.get('use_8bit_adam', False) == False
|
||||
# ):
|
||||
# my_data['optimizer'] = 'AdamW'
|
||||
|
||||
# if my_data.get('model_list', 'custom') == []:
|
||||
# print('Old config with empty model list. Setting to custom...')
|
||||
# my_data['model_list'] = 'custom'
|
||||
|
||||
# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
||||
# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
|
||||
# my_data['model_list'] = 'custom'
|
||||
|
||||
# # Fix old config files that contain epoch as str instead of int
|
||||
# for key in ['epoch', 'save_every_n_epochs']:
|
||||
# value = my_data.get(key, -1)
|
||||
# if type(value) == str:
|
||||
# if value != '':
|
||||
# my_data[key] = int(value)
|
||||
# else:
|
||||
# my_data[key] = -1
|
||||
|
||||
# if my_data.get('LoRA_type', 'Standard') == 'LoCon':
|
||||
# my_data['LoRA_type'] = 'LyCORIS/LoCon'
|
||||
|
||||
# return my_data
|
||||
|
||||
|
||||
def get_dir_and_file(file_path):
|
||||
dir_path, file_name = os.path.split(file_path)
|
||||
return (dir_path, file_name)
|
||||
|
||||
|
||||
def has_ext_files(directory, extension):
|
||||
# Iterate through all the files in the directory
|
||||
for file in os.listdir(directory):
|
||||
# If the file name ends with extension, return True
|
||||
if file.endswith(extension):
|
||||
return True
|
||||
# If no extension files were found, return False
|
||||
return False
|
||||
# def has_ext_files(directory, extension):
|
||||
# # Iterate through all the files in the directory
|
||||
# for file in os.listdir(directory):
|
||||
# # If the file name ends with extension, return True
|
||||
# if file.endswith(extension):
|
||||
# return True
|
||||
# # If no extension files were found, return False
|
||||
# return False
|
||||
|
||||
|
||||
def get_file_path(
|
||||
file_path='', defaultextension='.json', extension_name='Config files'
|
||||
file_path='', default_extension='.json', extension_name='Config files'
|
||||
):
|
||||
current_file_path = file_path
|
||||
# print(f'current file path: {current_file_path}')
|
||||
|
||||
initial_dir, initial_file = get_dir_and_file(file_path)
|
||||
|
||||
# Create a hidden Tkinter root window
|
||||
root = Tk()
|
||||
root.wm_attributes('-topmost', 1)
|
||||
root.withdraw()
|
||||
|
||||
# Show the open file dialog and get the selected file path
|
||||
file_path = filedialog.askopenfilename(
|
||||
filetypes=(
|
||||
(f'{extension_name}', f'{defaultextension}'),
|
||||
('All files', '*'),
|
||||
(extension_name, f'*{default_extension}'),
|
||||
('All files', '*.*'),
|
||||
),
|
||||
defaultextension=defaultextension,
|
||||
defaultextension=default_extension,
|
||||
initialfile=initial_file,
|
||||
initialdir=initial_dir,
|
||||
)
|
||||
|
||||
# Destroy the hidden root window
|
||||
root.destroy()
|
||||
|
||||
if file_path == '':
|
||||
# If no file is selected, use the current file path
|
||||
if not file_path:
|
||||
file_path = current_file_path
|
||||
|
||||
return file_path
|
||||
@ -230,52 +265,146 @@ def get_saveasfilename_path(
|
||||
|
||||
|
||||
def add_pre_postfix(
|
||||
folder='', prefix='', postfix='', caption_file_ext='.caption'
|
||||
):
|
||||
if not has_ext_files(folder, caption_file_ext):
|
||||
msgbox(
|
||||
f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||
)
|
||||
return
|
||||
folder: str = '',
|
||||
prefix: str = '',
|
||||
postfix: str = '',
|
||||
caption_file_ext: str = '.caption'
|
||||
) -> None:
|
||||
"""
|
||||
Add prefix and/or postfix to the content of caption files within a folder.
|
||||
If no caption files are found, create one with the requested prefix and/or postfix.
|
||||
|
||||
Args:
|
||||
folder (str): Path to the folder containing caption files.
|
||||
prefix (str, optional): Prefix to add to the content of the caption files.
|
||||
postfix (str, optional): Postfix to add to the content of the caption files.
|
||||
caption_file_ext (str, optional): Extension of the caption files.
|
||||
"""
|
||||
|
||||
if prefix == '' and postfix == '':
|
||||
return
|
||||
|
||||
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
|
||||
if not prefix == '':
|
||||
prefix = f'{prefix} '
|
||||
if not postfix == '':
|
||||
postfix = f' {postfix}'
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
|
||||
image_files = [f for f in os.listdir(folder) if f.lower().endswith(image_extensions)]
|
||||
|
||||
for file in files:
|
||||
with open(os.path.join(folder, file), 'r+') as f:
|
||||
content = f.read()
|
||||
content = content.rstrip()
|
||||
f.seek(0, 0)
|
||||
f.write(f'{prefix}{content}{postfix}')
|
||||
f.close()
|
||||
for image_file in image_files:
|
||||
caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
|
||||
caption_file_path = os.path.join(folder, caption_file_name)
|
||||
|
||||
if not os.path.exists(caption_file_path):
|
||||
with open(caption_file_path, 'w') as f:
|
||||
separator = ' ' if prefix and postfix else ''
|
||||
f.write(f'{prefix}{separator}{postfix}')
|
||||
else:
|
||||
with open(caption_file_path, 'r+') as f:
|
||||
content = f.read()
|
||||
content = content.rstrip()
|
||||
f.seek(0, 0)
|
||||
|
||||
prefix_separator = ' ' if prefix else ''
|
||||
postfix_separator = ' ' if postfix else ''
|
||||
f.write(f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}')
|
||||
|
||||
# def add_pre_postfix(
|
||||
# folder='', prefix='', postfix='', caption_file_ext='.caption'
|
||||
# ):
|
||||
# if not has_ext_files(folder, caption_file_ext):
|
||||
# msgbox(
|
||||
# f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||
# )
|
||||
# return
|
||||
|
||||
# if prefix == '' and postfix == '':
|
||||
# return
|
||||
|
||||
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
|
||||
# if not prefix == '':
|
||||
# prefix = f'{prefix} '
|
||||
# if not postfix == '':
|
||||
# postfix = f' {postfix}'
|
||||
|
||||
# for file in files:
|
||||
# with open(os.path.join(folder, file), 'r+') as f:
|
||||
# content = f.read()
|
||||
# content = content.rstrip()
|
||||
# f.seek(0, 0)
|
||||
# f.write(f'{prefix} {content} {postfix}')
|
||||
# f.close()
|
||||
|
||||
|
||||
def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
|
||||
def has_ext_files(folder_path: str, file_extension: str) -> bool:
|
||||
"""
|
||||
Check if there are any files with the specified extension in the given folder.
|
||||
|
||||
Args:
|
||||
folder_path (str): Path to the folder containing files.
|
||||
file_extension (str): Extension of the files to look for.
|
||||
|
||||
Returns:
|
||||
bool: True if files with the specified extension are found, False otherwise.
|
||||
"""
|
||||
for file in os.listdir(folder_path):
|
||||
if file.endswith(file_extension):
|
||||
return True
|
||||
return False
|
||||
|
||||
def find_replace(
|
||||
folder_path: str = '',
|
||||
caption_file_ext: str = '.caption',
|
||||
search_text: str = '',
|
||||
replace_text: str = ''
|
||||
) -> None:
|
||||
"""
|
||||
Find and replace text in caption files within a folder.
|
||||
|
||||
Args:
|
||||
folder_path (str, optional): Path to the folder containing caption files.
|
||||
caption_file_ext (str, optional): Extension of the caption files.
|
||||
search_text (str, optional): Text to search for in the caption files.
|
||||
replace_text (str, optional): Text to replace the search text with.
|
||||
"""
|
||||
print('Running caption find/replace')
|
||||
if not has_ext_files(folder, caption_file_ext):
|
||||
|
||||
if not has_ext_files(folder_path, caption_file_ext):
|
||||
msgbox(
|
||||
f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||
f'No files with extension {caption_file_ext} were found in {folder_path}...'
|
||||
)
|
||||
return
|
||||
|
||||
if find == '':
|
||||
if search_text == '':
|
||||
return
|
||||
|
||||
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
|
||||
for file in files:
|
||||
with open(os.path.join(folder, file), 'r', errors='ignore') as f:
|
||||
caption_files = [f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)]
|
||||
|
||||
for caption_file in caption_files:
|
||||
with open(os.path.join(folder_path, caption_file), 'r', errors='ignore') as f:
|
||||
content = f.read()
|
||||
f.close
|
||||
content = content.replace(find, replace)
|
||||
with open(os.path.join(folder, file), 'w') as f:
|
||||
|
||||
content = content.replace(search_text, replace_text)
|
||||
|
||||
with open(os.path.join(folder_path, caption_file), 'w') as f:
|
||||
f.write(content)
|
||||
f.close()
|
||||
|
||||
# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
|
||||
# print('Running caption find/replace')
|
||||
# if not has_ext_files(folder, caption_file_ext):
|
||||
# msgbox(
|
||||
# f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||
# )
|
||||
# return
|
||||
|
||||
# if find == '':
|
||||
# return
|
||||
|
||||
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
|
||||
# for file in files:
|
||||
# with open(os.path.join(folder, file), 'r', errors='ignore') as f:
|
||||
# content = f.read()
|
||||
# f.close
|
||||
# content = content.replace(find, replace)
|
||||
# with open(os.path.join(folder, file), 'w') as f:
|
||||
# f.write(content)
|
||||
# f.close()
|
||||
|
||||
|
||||
def color_aug_changed(color_aug):
|
||||
|
@ -417,13 +417,16 @@ def train_model(
|
||||
or f.endswith('.webp')
|
||||
]
|
||||
)
|
||||
|
||||
print(f'Folder {folder}: {num_images} images found')
|
||||
|
||||
# Calculate the total number of steps for this folder
|
||||
steps = repeats * num_images
|
||||
total_steps += steps
|
||||
|
||||
# Print the result
|
||||
print(f'Folder {folder}: {steps} steps')
|
||||
|
||||
total_steps += steps
|
||||
|
||||
# calculate max_train_steps
|
||||
max_train_steps = int(
|
||||
|
@ -116,7 +116,7 @@ def main():
|
||||
linear_mode_param, conv_mode_param,
|
||||
args.device,
|
||||
args.use_sparse_bias, args.sparsity,
|
||||
# not args.disable_small_conv
|
||||
not args.disable_cp
|
||||
)
|
||||
|
||||
if args.safetensors:
|
||||
|
Loading…
Reference in New Issue
Block a user