From baf009d2b1b86f5dc45a9952df9e30f6cf81bb24 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 15 Mar 2023 19:31:52 -0400 Subject: [PATCH] Fix basic captioning logic --- README.md | 2 + library/basic_caption_gui.py | 113 +++++++------- library/common_gui.py | 261 ++++++++++++++++++++++++--------- lora_gui.py | 5 +- tools/lycoris_locon_extract.py | 2 +- 5 files changed, 254 insertions(+), 129 deletions(-) diff --git a/README.md b/README.md index a8a6596..5b1fd15 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,8 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/16 (v21.2.5): + - Fix basic captioning logic * 2023/03/12 (v21.2.4): - Fix issue with kohya locon not training the convolution layers - Update LyCORIS module version diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index 21852c3..1672880 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -6,35 +6,33 @@ import os def caption_images( - caption_text_input, - images_dir_input, - overwrite_input, - caption_file_ext, + caption_text, + images_dir, + overwrite, + caption_ext, prefix, postfix, - find, - replace, + find_text, + replace_text, ): - # Check for images_dir_input - if images_dir_input == '': + # Check for images_dir + if not images_dir: msgbox('Image folder is missing...') return - if caption_file_ext == '': + if not caption_ext: msgbox('Please provide an extension for the caption files.') return - if not caption_text_input == '': - print( - f'Captioning files in {images_dir_input} with {caption_text_input}...' - ) + if caption_text: + print(f'Captioning files in {images_dir} with {caption_text}...') run_cmd = f'python "tools/caption.py"' - run_cmd += f' --caption_text="{caption_text_input}"' - if overwrite_input: + run_cmd += f' --caption_text="{caption_text}"' + if overwrite: run_cmd += f' --overwrite' - if caption_file_ext != '': - run_cmd += f' --caption_file_ext="{caption_file_ext}"' - run_cmd += f' "{images_dir_input}"' + if caption_ext: + run_cmd += f' --caption_file_ext="{caption_ext}"' + run_cmd += f' "{images_dir}"' print(run_cmd) @@ -44,24 +42,24 @@ def caption_images( else: subprocess.run(run_cmd) - if overwrite_input: - if not prefix == '' or not postfix == '': + if overwrite: + if prefix or postfix: # Add prefix and postfix add_pre_postfix( - folder=images_dir_input, - caption_file_ext=caption_file_ext, + folder=images_dir, + caption_file_ext=caption_ext, prefix=prefix, postfix=postfix, ) - if not find == '': + if find_text: find_replace( - folder=images_dir_input, - caption_file_ext=caption_file_ext, - find=find, - replace=replace, + folder=images_dir, + caption_file_ext=caption_ext, + find=find_text, + replace=replace_text, ) else: - if not prefix == '' or not postfix == '': + if prefix or postfix: msgbox( 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' ) @@ -69,37 +67,31 @@ def caption_images( print('...captioning done') -### # Gradio UI -### - - def gradio_basic_caption_gui_tab(): with gr.Tab('Basic Captioning'): gr.Markdown( - 'This utility will allow the creation of simple caption files for each images in a folder.' + 'This utility will allow the creation of simple caption files for each image in a folder.' ) with gr.Row(): - images_dir_input = gr.Textbox( + images_dir = gr.Textbox( label='Image folder to caption', placeholder='Directory containing the images to caption', interactive=True, ) - button_images_dir_input = gr.Button( - '📂', elem_id='open_folder_small' - ) - button_images_dir_input.click( + folder_button = gr.Button('📂', elem_id='open_folder_small') + folder_button.click( get_folder_path, - outputs=images_dir_input, + outputs=images_dir, show_progress=False, ) - caption_file_ext = gr.Textbox( + caption_ext = gr.Textbox( label='Caption file extension', - placeholder='Extention for caption file. eg: .caption, .txt', + placeholder='Extension for caption file. eg: .caption, .txt', value='.txt', interactive=True, ) - overwrite_input = gr.Checkbox( + overwrite = gr.Checkbox( label='Overwrite existing captions in folder', interactive=True, value=False, @@ -110,7 +102,7 @@ def gradio_basic_caption_gui_tab(): placeholder='(Optional)', interactive=True, ) - caption_text_input = gr.Textbox( + caption_text = gr.Textbox( label='Caption text', placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', interactive=True, @@ -121,29 +113,28 @@ def gradio_basic_caption_gui_tab(): interactive=True, ) with gr.Row(): - find = gr.Textbox( + find_text = gr.Textbox( label='Find text', placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', interactive=True, ) - replace = gr.Textbox( + replace_text = gr.Textbox( label='Replacement text', placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing', interactive=True, + ) + caption_button = gr.Button('Caption images') + caption_button.click( + caption_images, + inputs=[ + caption_text, + images_dir, + overwrite, + caption_ext, + prefix, + postfix, + find_text, + replace_text, + ], + show_progress=False, ) - caption_button = gr.Button('Caption images') - - caption_button.click( - caption_images, - inputs=[ - caption_text_input, - images_dir_input, - overwrite_input, - caption_file_ext, - prefix, - postfix, - find, - replace, - ], - show_progress=False, - ) diff --git a/library/common_gui.py b/library/common_gui.py index 33e93bf..05accda 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -32,77 +32,112 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS def update_my_data(my_data): - if my_data.get('use_8bit_adam', False) == True: + # Update optimizer based on use_8bit_adam flag + use_8bit_adam = my_data.get('use_8bit_adam', False) + if use_8bit_adam: my_data['optimizer'] = 'AdamW8bit' - # my_data['use_8bit_adam'] = False - - if ( - my_data.get('optimizer', 'missing') == 'missing' - and my_data.get('use_8bit_adam', False) == False - ): + elif 'optimizer' not in my_data: my_data['optimizer'] = 'AdamW' - - if my_data.get('model_list', 'custom') == []: - print('Old config with empty model list. Setting to custom...') - my_data['model_list'] = 'custom' - - # If Pretrained model name or path is not one of the preset models then set the preset_model to custom - if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS: - my_data['model_list'] = 'custom' - # Fix old config files that contain epoch as str instead of int + # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model + model_list = my_data.get('model_list', []) + pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '') + if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS: + my_data['model_list'] = 'custom' + + # Convert epoch and save_every_n_epochs values to int if they are strings for key in ['epoch', 'save_every_n_epochs']: value = my_data.get(key, -1) - if type(value) == str: - if value != '': - my_data[key] = int(value) - else: - my_data[key] = -1 - + if isinstance(value, str) and value: + my_data[key] = int(value) + elif not value: + my_data[key] = -1 + + # Update LoRA_type if it is set to LoCon if my_data.get('LoRA_type', 'Standard') == 'LoCon': my_data['LoRA_type'] = 'LyCORIS/LoCon' - + return my_data +# def update_my_data(my_data): +# if my_data.get('use_8bit_adam', False) == True: +# my_data['optimizer'] = 'AdamW8bit' +# # my_data['use_8bit_adam'] = False + +# if ( +# my_data.get('optimizer', 'missing') == 'missing' +# and my_data.get('use_8bit_adam', False) == False +# ): +# my_data['optimizer'] = 'AdamW' + +# if my_data.get('model_list', 'custom') == []: +# print('Old config with empty model list. Setting to custom...') +# my_data['model_list'] = 'custom' + +# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom +# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS: +# my_data['model_list'] = 'custom' + +# # Fix old config files that contain epoch as str instead of int +# for key in ['epoch', 'save_every_n_epochs']: +# value = my_data.get(key, -1) +# if type(value) == str: +# if value != '': +# my_data[key] = int(value) +# else: +# my_data[key] = -1 + +# if my_data.get('LoRA_type', 'Standard') == 'LoCon': +# my_data['LoRA_type'] = 'LyCORIS/LoCon' + +# return my_data + + def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) return (dir_path, file_name) -def has_ext_files(directory, extension): - # Iterate through all the files in the directory - for file in os.listdir(directory): - # If the file name ends with extension, return True - if file.endswith(extension): - return True - # If no extension files were found, return False - return False +# def has_ext_files(directory, extension): +# # Iterate through all the files in the directory +# for file in os.listdir(directory): +# # If the file name ends with extension, return True +# if file.endswith(extension): +# return True +# # If no extension files were found, return False +# return False def get_file_path( - file_path='', defaultextension='.json', extension_name='Config files' + file_path='', default_extension='.json', extension_name='Config files' ): current_file_path = file_path # print(f'current file path: {current_file_path}') initial_dir, initial_file = get_dir_and_file(file_path) + # Create a hidden Tkinter root window root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() + + # Show the open file dialog and get the selected file path file_path = filedialog.askopenfilename( filetypes=( - (f'{extension_name}', f'{defaultextension}'), - ('All files', '*'), + (extension_name, f'*{default_extension}'), + ('All files', '*.*'), ), - defaultextension=defaultextension, + defaultextension=default_extension, initialfile=initial_file, initialdir=initial_dir, ) + + # Destroy the hidden root window root.destroy() - if file_path == '': + # If no file is selected, use the current file path + if not file_path: file_path = current_file_path return file_path @@ -230,52 +265,146 @@ def get_saveasfilename_path( def add_pre_postfix( - folder='', prefix='', postfix='', caption_file_ext='.caption' -): - if not has_ext_files(folder, caption_file_ext): - msgbox( - f'No files with extension {caption_file_ext} were found in {folder}...' - ) - return + folder: str = '', + prefix: str = '', + postfix: str = '', + caption_file_ext: str = '.caption' + ) -> None: + """ + Add prefix and/or postfix to the content of caption files within a folder. + If no caption files are found, create one with the requested prefix and/or postfix. + + Args: + folder (str): Path to the folder containing caption files. + prefix (str, optional): Prefix to add to the content of the caption files. + postfix (str, optional): Postfix to add to the content of the caption files. + caption_file_ext (str, optional): Extension of the caption files. + """ if prefix == '' and postfix == '': return - files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] - if not prefix == '': - prefix = f'{prefix} ' - if not postfix == '': - postfix = f' {postfix}' + image_extensions = ('.jpg', '.jpeg', '.png', '.webp') + image_files = [f for f in os.listdir(folder) if f.lower().endswith(image_extensions)] - for file in files: - with open(os.path.join(folder, file), 'r+') as f: - content = f.read() - content = content.rstrip() - f.seek(0, 0) - f.write(f'{prefix}{content}{postfix}') - f.close() + for image_file in image_files: + caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext + caption_file_path = os.path.join(folder, caption_file_name) + + if not os.path.exists(caption_file_path): + with open(caption_file_path, 'w') as f: + separator = ' ' if prefix and postfix else '' + f.write(f'{prefix}{separator}{postfix}') + else: + with open(caption_file_path, 'r+') as f: + content = f.read() + content = content.rstrip() + f.seek(0, 0) + + prefix_separator = ' ' if prefix else '' + postfix_separator = ' ' if postfix else '' + f.write(f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}') + +# def add_pre_postfix( +# folder='', prefix='', postfix='', caption_file_ext='.caption' +# ): +# if not has_ext_files(folder, caption_file_ext): +# msgbox( +# f'No files with extension {caption_file_ext} were found in {folder}...' +# ) +# return + +# if prefix == '' and postfix == '': +# return + +# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] +# if not prefix == '': +# prefix = f'{prefix} ' +# if not postfix == '': +# postfix = f' {postfix}' + +# for file in files: +# with open(os.path.join(folder, file), 'r+') as f: +# content = f.read() +# content = content.rstrip() +# f.seek(0, 0) +# f.write(f'{prefix} {content} {postfix}') +# f.close() -def find_replace(folder='', caption_file_ext='.caption', find='', replace=''): +def has_ext_files(folder_path: str, file_extension: str) -> bool: + """ + Check if there are any files with the specified extension in the given folder. + + Args: + folder_path (str): Path to the folder containing files. + file_extension (str): Extension of the files to look for. + + Returns: + bool: True if files with the specified extension are found, False otherwise. + """ + for file in os.listdir(folder_path): + if file.endswith(file_extension): + return True + return False + +def find_replace( + folder_path: str = '', + caption_file_ext: str = '.caption', + search_text: str = '', + replace_text: str = '' + ) -> None: + """ + Find and replace text in caption files within a folder. + + Args: + folder_path (str, optional): Path to the folder containing caption files. + caption_file_ext (str, optional): Extension of the caption files. + search_text (str, optional): Text to search for in the caption files. + replace_text (str, optional): Text to replace the search text with. + """ print('Running caption find/replace') - if not has_ext_files(folder, caption_file_ext): + + if not has_ext_files(folder_path, caption_file_ext): msgbox( - f'No files with extension {caption_file_ext} were found in {folder}...' + f'No files with extension {caption_file_ext} were found in {folder_path}...' ) return - if find == '': + if search_text == '': return - files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] - for file in files: - with open(os.path.join(folder, file), 'r', errors='ignore') as f: + caption_files = [f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)] + + for caption_file in caption_files: + with open(os.path.join(folder_path, caption_file), 'r', errors='ignore') as f: content = f.read() - f.close - content = content.replace(find, replace) - with open(os.path.join(folder, file), 'w') as f: + + content = content.replace(search_text, replace_text) + + with open(os.path.join(folder_path, caption_file), 'w') as f: f.write(content) - f.close() + +# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''): +# print('Running caption find/replace') +# if not has_ext_files(folder, caption_file_ext): +# msgbox( +# f'No files with extension {caption_file_ext} were found in {folder}...' +# ) +# return + +# if find == '': +# return + +# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] +# for file in files: +# with open(os.path.join(folder, file), 'r', errors='ignore') as f: +# content = f.read() +# f.close +# content = content.replace(find, replace) +# with open(os.path.join(folder, file), 'w') as f: +# f.write(content) +# f.close() def color_aug_changed(color_aug): diff --git a/lora_gui.py b/lora_gui.py index 39e690f..a043c4b 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -417,13 +417,16 @@ def train_model( or f.endswith('.webp') ] ) + + print(f'Folder {folder}: {num_images} images found') # Calculate the total number of steps for this folder steps = repeats * num_images - total_steps += steps # Print the result print(f'Folder {folder}: {steps} steps') + + total_steps += steps # calculate max_train_steps max_train_steps = int( diff --git a/tools/lycoris_locon_extract.py b/tools/lycoris_locon_extract.py index 2b10375..75b5549 100644 --- a/tools/lycoris_locon_extract.py +++ b/tools/lycoris_locon_extract.py @@ -116,7 +116,7 @@ def main(): linear_mode_param, conv_mode_param, args.device, args.use_sparse_bias, args.sparsity, - # not args.disable_small_conv + not args.disable_cp ) if args.safetensors: