diff --git a/README.md b/README.md index a8a6596..cce5fe8 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,10 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/19 (v21.2.5): + - Fix basic captioning logic + - Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1. + - Update linux scripts * 2023/03/12 (v21.2.4): - Fix issue with kohya locon not training the convolution layers - Update LyCORIS module version diff --git a/dreambooth_gui.py b/dreambooth_gui.py index dee017c..7d24b70 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -377,7 +377,9 @@ def train_model( print(f'max_train_steps = {max_train_steps}') # calculate stop encoder training - if stop_text_encoder_training_pct == None: + if int(stop_text_encoder_training_pct) == -1: + stop_text_encoder_training = -1 + elif stop_text_encoder_training_pct == None: stop_text_encoder_training = 0 else: stop_text_encoder_training = math.ceil( @@ -624,7 +626,7 @@ def dreambooth_tab( placeholder='512,512', ) stop_text_encoder_training = gr.Slider( - minimum=0, + minimum=-1, maximum=100, value=0, step=1, diff --git a/gui.sh b/gui.sh index 90b26db..e4eca6f 100755 --- a/gui.sh +++ b/gui.sh @@ -1,3 +1,13 @@ #!/bin/bash + +# Activate the virtual environment source venv/bin/activate -python kohya_gui.py + +# Validate the requirements and store the exit code +python tools/validate_requirements.py +exit_code=$? + +# If the exit code is 0, run the kohya_gui.py script with the command-line arguments +if [ $exit_code -eq 0 ]; then + python kohya_gui.py "$@" +fi diff --git a/kohya_gui.py b/kohya_gui.py index d643228..f8e0d8c 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -6,7 +6,9 @@ from finetune_gui import finetune_tab from textual_inversion_gui import ti_tab from library.utilities import utilities_tab from library.extract_lora_gui import gradio_extract_lora_tab +from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab from library.merge_lora_gui import gradio_merge_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab from lora_gui import lora_tab @@ -43,7 +45,9 @@ def UI(**kwargs): enable_copy_info_button=True, ) gradio_extract_lora_tab() + gradio_extract_lycoris_locon_tab() gradio_merge_lora_tab() + gradio_resize_lora_tab() # Show the interface launch_kwargs = {} diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index 21852c3..1672880 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -6,35 +6,33 @@ import os def caption_images( - caption_text_input, - images_dir_input, - overwrite_input, - caption_file_ext, + caption_text, + images_dir, + overwrite, + caption_ext, prefix, postfix, - find, - replace, + find_text, + replace_text, ): - # Check for images_dir_input - if images_dir_input == '': + # Check for images_dir + if not images_dir: msgbox('Image folder is missing...') return - if caption_file_ext == '': + if not caption_ext: msgbox('Please provide an extension for the caption files.') return - if not caption_text_input == '': - print( - f'Captioning files in {images_dir_input} with {caption_text_input}...' - ) + if caption_text: + print(f'Captioning files in {images_dir} with {caption_text}...') run_cmd = f'python "tools/caption.py"' - run_cmd += f' --caption_text="{caption_text_input}"' - if overwrite_input: + run_cmd += f' --caption_text="{caption_text}"' + if overwrite: run_cmd += f' --overwrite' - if caption_file_ext != '': - run_cmd += f' --caption_file_ext="{caption_file_ext}"' - run_cmd += f' "{images_dir_input}"' + if caption_ext: + run_cmd += f' --caption_file_ext="{caption_ext}"' + run_cmd += f' "{images_dir}"' print(run_cmd) @@ -44,24 +42,24 @@ def caption_images( else: subprocess.run(run_cmd) - if overwrite_input: - if not prefix == '' or not postfix == '': + if overwrite: + if prefix or postfix: # Add prefix and postfix add_pre_postfix( - folder=images_dir_input, - caption_file_ext=caption_file_ext, + folder=images_dir, + caption_file_ext=caption_ext, prefix=prefix, postfix=postfix, ) - if not find == '': + if find_text: find_replace( - folder=images_dir_input, - caption_file_ext=caption_file_ext, - find=find, - replace=replace, + folder=images_dir, + caption_file_ext=caption_ext, + find=find_text, + replace=replace_text, ) else: - if not prefix == '' or not postfix == '': + if prefix or postfix: msgbox( 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' ) @@ -69,37 +67,31 @@ def caption_images( print('...captioning done') -### # Gradio UI -### - - def gradio_basic_caption_gui_tab(): with gr.Tab('Basic Captioning'): gr.Markdown( - 'This utility will allow the creation of simple caption files for each images in a folder.' + 'This utility will allow the creation of simple caption files for each image in a folder.' ) with gr.Row(): - images_dir_input = gr.Textbox( + images_dir = gr.Textbox( label='Image folder to caption', placeholder='Directory containing the images to caption', interactive=True, ) - button_images_dir_input = gr.Button( - '📂', elem_id='open_folder_small' - ) - button_images_dir_input.click( + folder_button = gr.Button('📂', elem_id='open_folder_small') + folder_button.click( get_folder_path, - outputs=images_dir_input, + outputs=images_dir, show_progress=False, ) - caption_file_ext = gr.Textbox( + caption_ext = gr.Textbox( label='Caption file extension', - placeholder='Extention for caption file. eg: .caption, .txt', + placeholder='Extension for caption file. eg: .caption, .txt', value='.txt', interactive=True, ) - overwrite_input = gr.Checkbox( + overwrite = gr.Checkbox( label='Overwrite existing captions in folder', interactive=True, value=False, @@ -110,7 +102,7 @@ def gradio_basic_caption_gui_tab(): placeholder='(Optional)', interactive=True, ) - caption_text_input = gr.Textbox( + caption_text = gr.Textbox( label='Caption text', placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', interactive=True, @@ -121,29 +113,28 @@ def gradio_basic_caption_gui_tab(): interactive=True, ) with gr.Row(): - find = gr.Textbox( + find_text = gr.Textbox( label='Find text', placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix', interactive=True, ) - replace = gr.Textbox( + replace_text = gr.Textbox( label='Replacement text', placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing', interactive=True, + ) + caption_button = gr.Button('Caption images') + caption_button.click( + caption_images, + inputs=[ + caption_text, + images_dir, + overwrite, + caption_ext, + prefix, + postfix, + find_text, + replace_text, + ], + show_progress=False, ) - caption_button = gr.Button('Caption images') - - caption_button.click( - caption_images, - inputs=[ - caption_text_input, - images_dir_input, - overwrite_input, - caption_file_ext, - prefix, - postfix, - find, - replace, - ], - show_progress=False, - ) diff --git a/library/common_gui.py b/library/common_gui.py index 33e93bf..05accda 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -32,77 +32,112 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS def update_my_data(my_data): - if my_data.get('use_8bit_adam', False) == True: + # Update optimizer based on use_8bit_adam flag + use_8bit_adam = my_data.get('use_8bit_adam', False) + if use_8bit_adam: my_data['optimizer'] = 'AdamW8bit' - # my_data['use_8bit_adam'] = False - - if ( - my_data.get('optimizer', 'missing') == 'missing' - and my_data.get('use_8bit_adam', False) == False - ): + elif 'optimizer' not in my_data: my_data['optimizer'] = 'AdamW' - - if my_data.get('model_list', 'custom') == []: - print('Old config with empty model list. Setting to custom...') - my_data['model_list'] = 'custom' - - # If Pretrained model name or path is not one of the preset models then set the preset_model to custom - if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS: - my_data['model_list'] = 'custom' - # Fix old config files that contain epoch as str instead of int + # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model + model_list = my_data.get('model_list', []) + pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '') + if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS: + my_data['model_list'] = 'custom' + + # Convert epoch and save_every_n_epochs values to int if they are strings for key in ['epoch', 'save_every_n_epochs']: value = my_data.get(key, -1) - if type(value) == str: - if value != '': - my_data[key] = int(value) - else: - my_data[key] = -1 - + if isinstance(value, str) and value: + my_data[key] = int(value) + elif not value: + my_data[key] = -1 + + # Update LoRA_type if it is set to LoCon if my_data.get('LoRA_type', 'Standard') == 'LoCon': my_data['LoRA_type'] = 'LyCORIS/LoCon' - + return my_data +# def update_my_data(my_data): +# if my_data.get('use_8bit_adam', False) == True: +# my_data['optimizer'] = 'AdamW8bit' +# # my_data['use_8bit_adam'] = False + +# if ( +# my_data.get('optimizer', 'missing') == 'missing' +# and my_data.get('use_8bit_adam', False) == False +# ): +# my_data['optimizer'] = 'AdamW' + +# if my_data.get('model_list', 'custom') == []: +# print('Old config with empty model list. Setting to custom...') +# my_data['model_list'] = 'custom' + +# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom +# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS: +# my_data['model_list'] = 'custom' + +# # Fix old config files that contain epoch as str instead of int +# for key in ['epoch', 'save_every_n_epochs']: +# value = my_data.get(key, -1) +# if type(value) == str: +# if value != '': +# my_data[key] = int(value) +# else: +# my_data[key] = -1 + +# if my_data.get('LoRA_type', 'Standard') == 'LoCon': +# my_data['LoRA_type'] = 'LyCORIS/LoCon' + +# return my_data + + def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) return (dir_path, file_name) -def has_ext_files(directory, extension): - # Iterate through all the files in the directory - for file in os.listdir(directory): - # If the file name ends with extension, return True - if file.endswith(extension): - return True - # If no extension files were found, return False - return False +# def has_ext_files(directory, extension): +# # Iterate through all the files in the directory +# for file in os.listdir(directory): +# # If the file name ends with extension, return True +# if file.endswith(extension): +# return True +# # If no extension files were found, return False +# return False def get_file_path( - file_path='', defaultextension='.json', extension_name='Config files' + file_path='', default_extension='.json', extension_name='Config files' ): current_file_path = file_path # print(f'current file path: {current_file_path}') initial_dir, initial_file = get_dir_and_file(file_path) + # Create a hidden Tkinter root window root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() + + # Show the open file dialog and get the selected file path file_path = filedialog.askopenfilename( filetypes=( - (f'{extension_name}', f'{defaultextension}'), - ('All files', '*'), + (extension_name, f'*{default_extension}'), + ('All files', '*.*'), ), - defaultextension=defaultextension, + defaultextension=default_extension, initialfile=initial_file, initialdir=initial_dir, ) + + # Destroy the hidden root window root.destroy() - if file_path == '': + # If no file is selected, use the current file path + if not file_path: file_path = current_file_path return file_path @@ -230,52 +265,146 @@ def get_saveasfilename_path( def add_pre_postfix( - folder='', prefix='', postfix='', caption_file_ext='.caption' -): - if not has_ext_files(folder, caption_file_ext): - msgbox( - f'No files with extension {caption_file_ext} were found in {folder}...' - ) - return + folder: str = '', + prefix: str = '', + postfix: str = '', + caption_file_ext: str = '.caption' + ) -> None: + """ + Add prefix and/or postfix to the content of caption files within a folder. + If no caption files are found, create one with the requested prefix and/or postfix. + + Args: + folder (str): Path to the folder containing caption files. + prefix (str, optional): Prefix to add to the content of the caption files. + postfix (str, optional): Postfix to add to the content of the caption files. + caption_file_ext (str, optional): Extension of the caption files. + """ if prefix == '' and postfix == '': return - files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] - if not prefix == '': - prefix = f'{prefix} ' - if not postfix == '': - postfix = f' {postfix}' + image_extensions = ('.jpg', '.jpeg', '.png', '.webp') + image_files = [f for f in os.listdir(folder) if f.lower().endswith(image_extensions)] - for file in files: - with open(os.path.join(folder, file), 'r+') as f: - content = f.read() - content = content.rstrip() - f.seek(0, 0) - f.write(f'{prefix}{content}{postfix}') - f.close() + for image_file in image_files: + caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext + caption_file_path = os.path.join(folder, caption_file_name) + + if not os.path.exists(caption_file_path): + with open(caption_file_path, 'w') as f: + separator = ' ' if prefix and postfix else '' + f.write(f'{prefix}{separator}{postfix}') + else: + with open(caption_file_path, 'r+') as f: + content = f.read() + content = content.rstrip() + f.seek(0, 0) + + prefix_separator = ' ' if prefix else '' + postfix_separator = ' ' if postfix else '' + f.write(f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}') + +# def add_pre_postfix( +# folder='', prefix='', postfix='', caption_file_ext='.caption' +# ): +# if not has_ext_files(folder, caption_file_ext): +# msgbox( +# f'No files with extension {caption_file_ext} were found in {folder}...' +# ) +# return + +# if prefix == '' and postfix == '': +# return + +# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] +# if not prefix == '': +# prefix = f'{prefix} ' +# if not postfix == '': +# postfix = f' {postfix}' + +# for file in files: +# with open(os.path.join(folder, file), 'r+') as f: +# content = f.read() +# content = content.rstrip() +# f.seek(0, 0) +# f.write(f'{prefix} {content} {postfix}') +# f.close() -def find_replace(folder='', caption_file_ext='.caption', find='', replace=''): +def has_ext_files(folder_path: str, file_extension: str) -> bool: + """ + Check if there are any files with the specified extension in the given folder. + + Args: + folder_path (str): Path to the folder containing files. + file_extension (str): Extension of the files to look for. + + Returns: + bool: True if files with the specified extension are found, False otherwise. + """ + for file in os.listdir(folder_path): + if file.endswith(file_extension): + return True + return False + +def find_replace( + folder_path: str = '', + caption_file_ext: str = '.caption', + search_text: str = '', + replace_text: str = '' + ) -> None: + """ + Find and replace text in caption files within a folder. + + Args: + folder_path (str, optional): Path to the folder containing caption files. + caption_file_ext (str, optional): Extension of the caption files. + search_text (str, optional): Text to search for in the caption files. + replace_text (str, optional): Text to replace the search text with. + """ print('Running caption find/replace') - if not has_ext_files(folder, caption_file_ext): + + if not has_ext_files(folder_path, caption_file_ext): msgbox( - f'No files with extension {caption_file_ext} were found in {folder}...' + f'No files with extension {caption_file_ext} were found in {folder_path}...' ) return - if find == '': + if search_text == '': return - files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] - for file in files: - with open(os.path.join(folder, file), 'r', errors='ignore') as f: + caption_files = [f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)] + + for caption_file in caption_files: + with open(os.path.join(folder_path, caption_file), 'r', errors='ignore') as f: content = f.read() - f.close - content = content.replace(find, replace) - with open(os.path.join(folder, file), 'w') as f: + + content = content.replace(search_text, replace_text) + + with open(os.path.join(folder_path, caption_file), 'w') as f: f.write(content) - f.close() + +# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''): +# print('Running caption find/replace') +# if not has_ext_files(folder, caption_file_ext): +# msgbox( +# f'No files with extension {caption_file_ext} were found in {folder}...' +# ) +# return + +# if find == '': +# return + +# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] +# for file in files: +# with open(os.path.join(folder, file), 'r', errors='ignore') as f: +# content = f.read() +# f.close +# content = content.replace(find, replace) +# with open(os.path.join(folder, file), 'w') as f: +# f.write(content) +# f.close() def color_aug_changed(color_aug): diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py new file mode 100644 index 0000000..43de7a2 --- /dev/null +++ b/library/extract_lycoris_locon_gui.py @@ -0,0 +1,273 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 +PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' + + +def extract_lycoris_locon( + db_model, base_model, output_name, device, + is_v2, mode, linear_dim, conv_dim, + linear_threshold, conv_threshold, + linear_ratio, conv_ratio, + linear_quantile, conv_quantile, + use_sparse_bias, sparsity, disable_cp +): + # Check for caption_text_input + if db_model == '': + msgbox('Invalid finetuned model file') + return + + if base_model == '': + msgbox('Invalid base model file') + return + + # Check if source model exist + if not os.path.isfile(db_model): + msgbox('The provided finetuned model is not a file') + return + + if not os.path.isfile(base_model): + msgbox('The provided base model is not a file') + return + + run_cmd = ( + f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"' + ) + if is_v2: + run_cmd += f' --is_v2' + run_cmd += f' --device {device}' + run_cmd += f' --mode {mode}' + run_cmd += f' --safetensors' + run_cmd += f' --linear_dim {linear_dim}' + run_cmd += f' --conv_dim {conv_dim}' + run_cmd += f' --linear_threshold {linear_threshold}' + run_cmd += f' --conv_threshold {conv_threshold}' + run_cmd += f' --linear_ratio {linear_ratio}' + run_cmd += f' --conv_ratio {conv_ratio}' + run_cmd += f' --linear_quantile {linear_quantile}' + run_cmd += f' --conv_quantile {conv_quantile}' + if use_sparse_bias: + run_cmd += f' --use_sparse_bias' + run_cmd += f' --sparsity {sparsity}' + if disable_cp: + run_cmd += f' --disable_cp' + run_cmd += f' "{base_model}"' + run_cmd += f' "{db_model}"' + run_cmd += f' "{output_name}"' + + print(run_cmd) + + # Run the command + if os.name == 'posix': + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + +### +# Gradio UI +### +# def update_mode(mode): +# # 'fixed', 'threshold','ratio','quantile' +# if mode == 'fixed': +# return gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False) +# if mode == 'threshold': +# return gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False) +# if mode == 'ratio': +# return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False) +# if mode == 'threshold': +# return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True) + +def update_mode(mode): + # Create a list of possible mode values + modes = ['fixed', 'threshold', 'ratio', 'quantile'] + + # Initialize an empty list to store visibility updates + updates = [] + + # Iterate through the possible modes + for m in modes: + # Add a visibility update for each mode, setting it to True if the input mode matches the current mode in the loop + updates.append(gr.Row.update(visible=(mode == m))) + + # Return the visibility updates as a tuple + return tuple(updates) + +def gradio_extract_lycoris_locon_tab(): + with gr.Tab('Extract LyCORIS LoCON'): + gr.Markdown( + 'This utility can extract a LyCORIS LoCon network from a finetuned model.' + ) + lora_ext = gr.Textbox(value='*.safetensors', visible=False) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) + model_ext_name = gr.Textbox(value='Model types', visible=False) + + with gr.Row(): + db_model = gr.Textbox( + label='Finetuned model', + placeholder='Path to the finetuned model to extract', + interactive=True, + ) + button_db_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_db_model_file.click( + get_file_path, + inputs=[db_model, model_ext, model_ext_name], + outputs=db_model, + show_progress=False, + ) + + base_model = gr.Textbox( + label='Stable Diffusion base model', + placeholder='Stable Diffusion original model: ckpt or safetensors file', + interactive=True, + ) + button_base_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_base_model_file.click( + get_file_path, + inputs=[base_model, model_ext, model_ext_name], + outputs=base_model, + show_progress=False, + ) + with gr.Row(): + output_name = gr.Textbox( + label='Save to', + placeholder='path where to save the extracted LoRA model...', + interactive=True, + ) + button_output_name = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_output_name.click( + get_saveasfilename_path, + inputs=[output_name, lora_ext, lora_ext_name], + outputs=output_name, + show_progress=False, + ) + device = gr.Dropdown( + label='Device', + choices=['cpu', 'cuda',], + value='cuda', + interactive=True, + ) + is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) + mode = gr.Dropdown( + label='Mode', + choices=['fixed', 'threshold','ratio','quantile'], + value='fixed', + interactive=True, + ) + with gr.Row(visible=True) as fixed: + linear_dim = gr.Slider( + minimum=1, + maximum=1024, + label='Network Dimension', + value=1, + step=1, + interactive=True, + ) + conv_dim = gr.Slider( + minimum=1, + maximum=1024, + label='Conv Dimension', + value=1, + step=1, + interactive=True, + ) + with gr.Row(visible=False) as threshold: + linear_threshold = gr.Slider( + minimum=0, + maximum=1, + label='Linear threshold', + value=0, + step=0.01, + interactive=True, + ) + conv_threshold = gr.Slider( + minimum=0, + maximum=1, + label='Conv threshold', + value=0, + step=0.01, + interactive=True, + ) + with gr.Row(visible=False) as ratio: + linear_ratio = gr.Slider( + minimum=0, + maximum=1, + label='Linear ratio', + value=0, + step=0.01, + interactive=True, + ) + conv_ratio = gr.Slider( + minimum=0, + maximum=1, + label='Conv ratio', + value=0, + step=0.01, + interactive=True, + ) + with gr.Row(visible=False) as quantile: + linear_quantile = gr.Slider( + minimum=0, + maximum=1, + label='Linear quantile', + value=0.75, + step=0.01, + interactive=True, + ) + conv_quantile = gr.Slider( + minimum=0, + maximum=1, + label='Conv quantile', + value=0.75, + step=0.01, + interactive=True, + ) + with gr.Row(): + use_sparse_bias = gr.Checkbox(label='Use sparse biais', value=False, interactive=True) + sparsity = gr.Slider( + minimum=0, + maximum=1, + label='Sparsity', + value=0.98, + step=0.01, + interactive=True, + ) + disable_cp = gr.Checkbox(label='Disable CP decomposition', value=False, interactive=True) + mode.change( + update_mode, + inputs=[mode], + outputs=[ + fixed, threshold, ratio, quantile, + ] + ) + + extract_button = gr.Button('Extract LyCORIS LoCon') + + extract_button.click( + extract_lycoris_locon, + inputs=[db_model, base_model, output_name, device, + is_v2, mode, linear_dim, conv_dim, + linear_threshold, conv_threshold, + linear_ratio, conv_ratio, + linear_quantile, conv_quantile, + use_sparse_bias, sparsity, disable_cp], + show_progress=False, + ) diff --git a/lora_gui.py b/lora_gui.py index 39e690f..a043c4b 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -417,13 +417,16 @@ def train_model( or f.endswith('.webp') ] ) + + print(f'Folder {folder}: {num_images} images found') # Calculate the total number of steps for this folder steps = repeats * num_images - total_steps += steps # Print the result print(f'Folder {folder}: {steps} steps') + + total_steps += steps # calculate max_train_steps max_train_steps = int( diff --git a/tools/lycoris_locon_extract.py b/tools/lycoris_locon_extract.py index 2b10375..75b5549 100644 --- a/tools/lycoris_locon_extract.py +++ b/tools/lycoris_locon_extract.py @@ -116,7 +116,7 @@ def main(): linear_mode_param, conv_mode_param, args.device, args.use_sparse_bias, args.sparsity, - # not args.disable_small_conv + not args.disable_cp ) if args.safetensors: diff --git a/tools/validate_requirements.py b/tools/validate_requirements.py index 9af2ce4..f158bdc 100644 --- a/tools/validate_requirements.py +++ b/tools/validate_requirements.py @@ -1,3 +1,4 @@ +import os import sys import pkg_resources @@ -32,7 +33,8 @@ if missing_requirements or wrong_version_requirements: print("Error: The following packages have the wrong version:") for requirement, expected_version, actual_version in wrong_version_requirements: print(f" - {requirement} (expected version {expected_version}, found version {actual_version})") - print('\nRun \033[33mupgrade.ps1\033[0m or \033[33mpip install -U -r requirements.txt\033[0m to resolve the missing requirements listed above...') + upgrade_script = "upgrade.ps1" if os.name == "nt" else "upgrade.sh" + print(f"\nRun \033[33m{upgrade_script}\033[0m or \033[33mpip install -U -r requirements.txt\033[0m to resolve the missing requirements listed above...") sys.exit(1) diff --git a/upgrade.sh b/upgrade.sh new file mode 100755 index 0000000..f01e7b7 --- /dev/null +++ b/upgrade.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Check if there are any changes that need to be committed +if [[ -n $(git status --short) ]]; then + echo "There are changes that need to be committed. Please stash or undo your changes before running this script." >&2 + exit 1 +fi + +# Pull the latest changes from the remote repository +git pull + +# Activate the virtual environment +source venv/bin/activate + +# Upgrade the required packages +pip install --upgrade -r requirements.txt