Fix basic captioning logic

This commit is contained in:
bmaltais 2023-03-15 19:31:52 -04:00
parent 7a94c523f5
commit baf009d2b1
5 changed files with 254 additions and 129 deletions

View File

@ -189,6 +189,8 @@ This will store a backup file with your current locally installed pip packa
## Change History ## Change History
* 2023/03/16 (v21.2.5):
- Fix basic captioning logic
* 2023/03/12 (v21.2.4): * 2023/03/12 (v21.2.4):
- Fix issue with kohya locon not training the convolution layers - Fix issue with kohya locon not training the convolution layers
- Update LyCORIS module version - Update LyCORIS module version

View File

@ -6,35 +6,33 @@ import os
def caption_images( def caption_images(
caption_text_input, caption_text,
images_dir_input, images_dir,
overwrite_input, overwrite,
caption_file_ext, caption_ext,
prefix, prefix,
postfix, postfix,
find, find_text,
replace, replace_text,
): ):
# Check for images_dir_input # Check for images_dir
if images_dir_input == '': if not images_dir:
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_file_ext == '': if not caption_ext:
msgbox('Please provide an extension for the caption files.') msgbox('Please provide an extension for the caption files.')
return return
if not caption_text_input == '': if caption_text:
print( print(f'Captioning files in {images_dir} with {caption_text}...')
f'Captioning files in {images_dir_input} with {caption_text_input}...'
)
run_cmd = f'python "tools/caption.py"' run_cmd = f'python "tools/caption.py"'
run_cmd += f' --caption_text="{caption_text_input}"' run_cmd += f' --caption_text="{caption_text}"'
if overwrite_input: if overwrite:
run_cmd += f' --overwrite' run_cmd += f' --overwrite'
if caption_file_ext != '': if caption_ext:
run_cmd += f' --caption_file_ext="{caption_file_ext}"' run_cmd += f' --caption_file_ext="{caption_ext}"'
run_cmd += f' "{images_dir_input}"' run_cmd += f' "{images_dir}"'
print(run_cmd) print(run_cmd)
@ -44,24 +42,24 @@ def caption_images(
else: else:
subprocess.run(run_cmd) subprocess.run(run_cmd)
if overwrite_input: if overwrite:
if not prefix == '' or not postfix == '': if prefix or postfix:
# Add prefix and postfix # Add prefix and postfix
add_pre_postfix( add_pre_postfix(
folder=images_dir_input, folder=images_dir,
caption_file_ext=caption_file_ext, caption_file_ext=caption_ext,
prefix=prefix, prefix=prefix,
postfix=postfix, postfix=postfix,
) )
if not find == '': if find_text:
find_replace( find_replace(
folder=images_dir_input, folder=images_dir,
caption_file_ext=caption_file_ext, caption_file_ext=caption_ext,
find=find, find=find_text,
replace=replace, replace=replace_text,
) )
else: else:
if not prefix == '' or not postfix == '': if prefix or postfix:
msgbox( msgbox(
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
) )
@ -69,37 +67,31 @@ def caption_images(
print('...captioning done') print('...captioning done')
###
# Gradio UI # Gradio UI
###
def gradio_basic_caption_gui_tab(): def gradio_basic_caption_gui_tab():
with gr.Tab('Basic Captioning'): with gr.Tab('Basic Captioning'):
gr.Markdown( gr.Markdown(
'This utility will allow the creation of simple caption files for each images in a folder.' 'This utility will allow the creation of simple caption files for each image in a folder.'
) )
with gr.Row(): with gr.Row():
images_dir_input = gr.Textbox( images_dir = gr.Textbox(
label='Image folder to caption', label='Image folder to caption',
placeholder='Directory containing the images to caption', placeholder='Directory containing the images to caption',
interactive=True, interactive=True,
) )
button_images_dir_input = gr.Button( folder_button = gr.Button('📂', elem_id='open_folder_small')
'📂', elem_id='open_folder_small' folder_button.click(
)
button_images_dir_input.click(
get_folder_path, get_folder_path,
outputs=images_dir_input, outputs=images_dir,
show_progress=False, show_progress=False,
) )
caption_file_ext = gr.Textbox( caption_ext = gr.Textbox(
label='Caption file extension', label='Caption file extension',
placeholder='Extention for caption file. eg: .caption, .txt', placeholder='Extension for caption file. eg: .caption, .txt',
value='.txt', value='.txt',
interactive=True, interactive=True,
) )
overwrite_input = gr.Checkbox( overwrite = gr.Checkbox(
label='Overwrite existing captions in folder', label='Overwrite existing captions in folder',
interactive=True, interactive=True,
value=False, value=False,
@ -110,7 +102,7 @@ def gradio_basic_caption_gui_tab():
placeholder='(Optional)', placeholder='(Optional)',
interactive=True, interactive=True,
) )
caption_text_input = gr.Textbox( caption_text = gr.Textbox(
label='Caption text', label='Caption text',
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
interactive=True, interactive=True,
@ -121,29 +113,28 @@ def gradio_basic_caption_gui_tab():
interactive=True, interactive=True,
) )
with gr.Row(): with gr.Row():
find = gr.Textbox( find_text = gr.Textbox(
label='Find text', label='Find text',
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
interactive=True, interactive=True,
) )
replace = gr.Textbox( replace_text = gr.Textbox(
label='Replacement text', label='Replacement text',
placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing', placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing',
interactive=True, interactive=True,
) )
caption_button = gr.Button('Caption images') caption_button = gr.Button('Caption images')
caption_button.click( caption_button.click(
caption_images, caption_images,
inputs=[ inputs=[
caption_text_input, caption_text,
images_dir_input, images_dir,
overwrite_input, overwrite,
caption_file_ext, caption_ext,
prefix, prefix,
postfix, postfix,
find, find_text,
replace, replace_text,
], ],
show_progress=False, show_progress=False,
) )

View File

@ -32,77 +32,112 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
def update_my_data(my_data): def update_my_data(my_data):
if my_data.get('use_8bit_adam', False) == True: # Update optimizer based on use_8bit_adam flag
use_8bit_adam = my_data.get('use_8bit_adam', False)
if use_8bit_adam:
my_data['optimizer'] = 'AdamW8bit' my_data['optimizer'] = 'AdamW8bit'
# my_data['use_8bit_adam'] = False elif 'optimizer' not in my_data:
if (
my_data.get('optimizer', 'missing') == 'missing'
and my_data.get('use_8bit_adam', False) == False
):
my_data['optimizer'] = 'AdamW' my_data['optimizer'] = 'AdamW'
if my_data.get('model_list', 'custom') == []: # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
print('Old config with empty model list. Setting to custom...') model_list = my_data.get('model_list', [])
pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
my_data['model_list'] = 'custom' my_data['model_list'] = 'custom'
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom # Convert epoch and save_every_n_epochs values to int if they are strings
if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
my_data['model_list'] = 'custom'
# Fix old config files that contain epoch as str instead of int
for key in ['epoch', 'save_every_n_epochs']: for key in ['epoch', 'save_every_n_epochs']:
value = my_data.get(key, -1) value = my_data.get(key, -1)
if type(value) == str: if isinstance(value, str) and value:
if value != '':
my_data[key] = int(value) my_data[key] = int(value)
else: elif not value:
my_data[key] = -1 my_data[key] = -1
# Update LoRA_type if it is set to LoCon
if my_data.get('LoRA_type', 'Standard') == 'LoCon': if my_data.get('LoRA_type', 'Standard') == 'LoCon':
my_data['LoRA_type'] = 'LyCORIS/LoCon' my_data['LoRA_type'] = 'LyCORIS/LoCon'
return my_data return my_data
# def update_my_data(my_data):
# if my_data.get('use_8bit_adam', False) == True:
# my_data['optimizer'] = 'AdamW8bit'
# # my_data['use_8bit_adam'] = False
# if (
# my_data.get('optimizer', 'missing') == 'missing'
# and my_data.get('use_8bit_adam', False) == False
# ):
# my_data['optimizer'] = 'AdamW'
# if my_data.get('model_list', 'custom') == []:
# print('Old config with empty model list. Setting to custom...')
# my_data['model_list'] = 'custom'
# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
# my_data['model_list'] = 'custom'
# # Fix old config files that contain epoch as str instead of int
# for key in ['epoch', 'save_every_n_epochs']:
# value = my_data.get(key, -1)
# if type(value) == str:
# if value != '':
# my_data[key] = int(value)
# else:
# my_data[key] = -1
# if my_data.get('LoRA_type', 'Standard') == 'LoCon':
# my_data['LoRA_type'] = 'LyCORIS/LoCon'
# return my_data
def get_dir_and_file(file_path): def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path) dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name) return (dir_path, file_name)
def has_ext_files(directory, extension): # def has_ext_files(directory, extension):
# Iterate through all the files in the directory # # Iterate through all the files in the directory
for file in os.listdir(directory): # for file in os.listdir(directory):
# If the file name ends with extension, return True # # If the file name ends with extension, return True
if file.endswith(extension): # if file.endswith(extension):
return True # return True
# If no extension files were found, return False # # If no extension files were found, return False
return False # return False
def get_file_path( def get_file_path(
file_path='', defaultextension='.json', extension_name='Config files' file_path='', default_extension='.json', extension_name='Config files'
): ):
current_file_path = file_path current_file_path = file_path
# print(f'current file path: {current_file_path}') # print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
# Create a hidden Tkinter root window
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
# Show the open file dialog and get the selected file path
file_path = filedialog.askopenfilename( file_path = filedialog.askopenfilename(
filetypes=( filetypes=(
(f'{extension_name}', f'{defaultextension}'), (extension_name, f'*{default_extension}'),
('All files', '*'), ('All files', '*.*'),
), ),
defaultextension=defaultextension, defaultextension=default_extension,
initialfile=initial_file, initialfile=initial_file,
initialdir=initial_dir, initialdir=initial_dir,
) )
# Destroy the hidden root window
root.destroy() root.destroy()
if file_path == '': # If no file is selected, use the current file path
if not file_path:
file_path = current_file_path file_path = current_file_path
return file_path return file_path
@ -230,52 +265,146 @@ def get_saveasfilename_path(
def add_pre_postfix( def add_pre_postfix(
folder='', prefix='', postfix='', caption_file_ext='.caption' folder: str = '',
): prefix: str = '',
if not has_ext_files(folder, caption_file_ext): postfix: str = '',
msgbox( caption_file_ext: str = '.caption'
f'No files with extension {caption_file_ext} were found in {folder}...' ) -> None:
) """
return Add prefix and/or postfix to the content of caption files within a folder.
If no caption files are found, create one with the requested prefix and/or postfix.
Args:
folder (str): Path to the folder containing caption files.
prefix (str, optional): Prefix to add to the content of the caption files.
postfix (str, optional): Postfix to add to the content of the caption files.
caption_file_ext (str, optional): Extension of the caption files.
"""
if prefix == '' and postfix == '': if prefix == '' and postfix == '':
return return
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
if not prefix == '': image_files = [f for f in os.listdir(folder) if f.lower().endswith(image_extensions)]
prefix = f'{prefix} '
if not postfix == '':
postfix = f' {postfix}'
for file in files: for image_file in image_files:
with open(os.path.join(folder, file), 'r+') as f: caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
caption_file_path = os.path.join(folder, caption_file_name)
if not os.path.exists(caption_file_path):
with open(caption_file_path, 'w') as f:
separator = ' ' if prefix and postfix else ''
f.write(f'{prefix}{separator}{postfix}')
else:
with open(caption_file_path, 'r+') as f:
content = f.read() content = f.read()
content = content.rstrip() content = content.rstrip()
f.seek(0, 0) f.seek(0, 0)
f.write(f'{prefix}{content}{postfix}')
f.close() prefix_separator = ' ' if prefix else ''
postfix_separator = ' ' if postfix else ''
f.write(f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}')
# def add_pre_postfix(
# folder='', prefix='', postfix='', caption_file_ext='.caption'
# ):
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if prefix == '' and postfix == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# if not prefix == '':
# prefix = f'{prefix} '
# if not postfix == '':
# postfix = f' {postfix}'
# for file in files:
# with open(os.path.join(folder, file), 'r+') as f:
# content = f.read()
# content = content.rstrip()
# f.seek(0, 0)
# f.write(f'{prefix} {content} {postfix}')
# f.close()
def find_replace(folder='', caption_file_ext='.caption', find='', replace=''): def has_ext_files(folder_path: str, file_extension: str) -> bool:
"""
Check if there are any files with the specified extension in the given folder.
Args:
folder_path (str): Path to the folder containing files.
file_extension (str): Extension of the files to look for.
Returns:
bool: True if files with the specified extension are found, False otherwise.
"""
for file in os.listdir(folder_path):
if file.endswith(file_extension):
return True
return False
def find_replace(
folder_path: str = '',
caption_file_ext: str = '.caption',
search_text: str = '',
replace_text: str = ''
) -> None:
"""
Find and replace text in caption files within a folder.
Args:
folder_path (str, optional): Path to the folder containing caption files.
caption_file_ext (str, optional): Extension of the caption files.
search_text (str, optional): Text to search for in the caption files.
replace_text (str, optional): Text to replace the search text with.
"""
print('Running caption find/replace') print('Running caption find/replace')
if not has_ext_files(folder, caption_file_ext):
if not has_ext_files(folder_path, caption_file_ext):
msgbox( msgbox(
f'No files with extension {caption_file_ext} were found in {folder}...' f'No files with extension {caption_file_ext} were found in {folder_path}...'
) )
return return
if find == '': if search_text == '':
return return
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] caption_files = [f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)]
for file in files:
with open(os.path.join(folder, file), 'r', errors='ignore') as f: for caption_file in caption_files:
with open(os.path.join(folder_path, caption_file), 'r', errors='ignore') as f:
content = f.read() content = f.read()
f.close
content = content.replace(find, replace) content = content.replace(search_text, replace_text)
with open(os.path.join(folder, file), 'w') as f:
with open(os.path.join(folder_path, caption_file), 'w') as f:
f.write(content) f.write(content)
f.close()
# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
# print('Running caption find/replace')
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if find == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# for file in files:
# with open(os.path.join(folder, file), 'r', errors='ignore') as f:
# content = f.read()
# f.close
# content = content.replace(find, replace)
# with open(os.path.join(folder, file), 'w') as f:
# f.write(content)
# f.close()
def color_aug_changed(color_aug): def color_aug_changed(color_aug):

View File

@ -418,13 +418,16 @@ def train_model(
] ]
) )
print(f'Folder {folder}: {num_images} images found')
# Calculate the total number of steps for this folder # Calculate the total number of steps for this folder
steps = repeats * num_images steps = repeats * num_images
total_steps += steps
# Print the result # Print the result
print(f'Folder {folder}: {steps} steps') print(f'Folder {folder}: {steps} steps')
total_steps += steps
# calculate max_train_steps # calculate max_train_steps
max_train_steps = int( max_train_steps = int(
math.ceil( math.ceil(

View File

@ -116,7 +116,7 @@ def main():
linear_mode_param, conv_mode_param, linear_mode_param, conv_mode_param,
args.device, args.device,
args.use_sparse_bias, args.sparsity, args.use_sparse_bias, args.sparsity,
# not args.disable_small_conv not args.disable_cp
) )
if args.safetensors: if args.safetensors: