diff --git a/kohya_gui.py b/kohya_gui.py
index bc7faa8..1031810 100644
--- a/kohya_gui.py
+++ b/kohya_gui.py
@@ -19,7 +19,7 @@ def UI(username, password):
             print('Load CSS...')
             css += file.read() + '\n'
 
-    interface = gr.Blocks(css=css)
+    interface = gr.Blocks(css=css, title="Kohya_ss GUI")
 
     with interface:
         with gr.Tab('Dreambooth'):
diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py
index b058c0f..0f14dd8 100644
--- a/library/extract_lora_gui.py
+++ b/library/extract_lora_gui.py
@@ -109,11 +109,11 @@ def gradio_extract_lora_tab():
             )
         with gr.Row():
             dim = gr.Slider(
-                minimum=1,
-                maximum=128,
+                minimum=4,
+                maximum=1024,
                 label='Network Dimension',
-                value=8,
-                step=1,
+                value=128,
+                step=4,
                 interactive=True,
             )
             v2 = gr.Checkbox(label='v2', value=False, interactive=True)
diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py
new file mode 100644
index 0000000..1863324
--- /dev/null
+++ b/library/git_caption_gui.py
@@ -0,0 +1,126 @@
+import gradio as gr
+from easygui import msgbox
+import subprocess
+import os
+from .common_gui import get_folder_path, add_pre_postfix
+
+
+def caption_images(
+    train_data_dir,
+    caption_ext,
+    batch_size,
+    max_data_loader_n_workers,
+    max_length,
+    model_id,
+    prefix,
+    postfix,
+):
+    # Check for images_dir_input
+    if train_data_dir == '':
+        msgbox('Image folder is missing...')
+        return
+
+    if caption_ext == '':
+        msgbox('Please provide an extension for the caption files.')
+        return
+
+    print(f'GIT captioning files in {train_data_dir}...')
+    run_cmd = f'.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"'
+    if not model_id == '':
+        run_cmd += f' --model_id="{model_id}"'
+    run_cmd += f' --batch_size="{int(batch_size)}"'
+    run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
+    run_cmd += f' --max_length="{int(max_length)}"'
+    if caption_ext != '':
+        run_cmd += f' --caption_extension="{caption_ext}"'
+    run_cmd += f' "{train_data_dir}"'
+
+    print(run_cmd)
+
+    # Run the command
+    subprocess.run(run_cmd)
+
+    # Add prefix and postfix
+    add_pre_postfix(
+        folder=train_data_dir,
+        caption_file_ext=caption_ext,
+        prefix=prefix,
+        postfix=postfix,
+    )
+
+    print('...captioning done')
+
+
+###
+# Gradio UI
+###
+
+
+def gradio_git_caption_gui_tab():
+    with gr.Tab('GIT Captioning'):
+        gr.Markdown(
+            'This utility will use GIT to create caption files for each image in a folder.'
+        )
+        with gr.Row():
+            train_data_dir = gr.Textbox(
+                label='Image folder to caption',
+                placeholder='Directory containing the images to caption',
+                interactive=True,
+            )
+            button_train_data_dir_input = gr.Button(
+                '📂', elem_id='open_folder_small'
+            )
+            button_train_data_dir_input.click(
+                get_folder_path, outputs=train_data_dir
+            )
+        with gr.Row():
+            caption_ext = gr.Textbox(
+                label='Caption file extension',
+                placeholder='Extension for caption file. eg: .caption, .txt',
+                value='.txt',
+                interactive=True,
+            )
+
+            prefix = gr.Textbox(
+                label='Prefix to add to GIT caption',
+                placeholder='(Optional)',
+                interactive=True,
+            )
+
+            postfix = gr.Textbox(
+                label='Postfix to add to GIT caption',
+                placeholder='(Optional)',
+                interactive=True,
+            )
+
+            batch_size = gr.Number(
+                value=1, label='Batch size', interactive=True
+            )
+
+        with gr.Row():
+            max_data_loader_n_workers = gr.Number(
+                value=2, label='Number of workers', interactive=True
+            )
+            max_length = gr.Number(
+                value=75, label='Max length', interactive=True
+            )
+            model_id = gr.Textbox(
+                label="Model",
+                placeholder="(Optional) model id for GIT in Hugging Face", interactive=True
+            )
+
+        caption_button = gr.Button('Caption images')
+
+        caption_button.click(
+            caption_images,
+            inputs=[
+                train_data_dir,
+                caption_ext,
+                batch_size,
+                max_data_loader_n_workers,
+                max_length,
+                model_id,
+                prefix,
+                postfix,
+            ],
+        )
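For reference, here is the kind of command string `caption_images()` above ends up assembling. This is a minimal sketch with made-up inputs: the folder path is hypothetical, while the script path, flag names, and default values come from the diff.

```python
# Hypothetical inputs; flag spellings mirror caption_images() above.
train_data_dir = 'D:/dataset/images'  # made-up example folder

run_cmd = '.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"'
run_cmd += ' --batch_size="1"'                 # GUI default
run_cmd += ' --max_data_loader_n_workers="2"'  # GUI default
run_cmd += ' --max_length="75"'                # GUI default
run_cmd += ' --caption_extension=".txt"'
run_cmd += f' "{train_data_dir}"'

print(run_cmd)
# .\venv\Scripts\python.exe "finetune/make_captions_by_git.py" --batch_size="1"
# --max_data_loader_n_workers="2" --max_length="75" --caption_extension=".txt"
# "D:/dataset/images"
```

Note that `--model_id` is only appended when a model id is entered in the GUI, so it is omitted here.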
diff --git a/library/utilities.py b/library/utilities.py
index 17795bc..523c2c2 100644
--- a/library/utilities.py
+++ b/library/utilities.py
@@ -9,6 +9,7 @@ import argparse
 from library.basic_caption_gui import gradio_basic_caption_gui_tab
 from library.convert_model_gui import gradio_convert_model_tab
 from library.blip_caption_gui import gradio_blip_caption_gui_tab
+from library.git_caption_gui import gradio_git_caption_gui_tab
 from library.wd14_caption_gui import gradio_wd14_caption_gui_tab
 
 
@@ -23,6 +24,7 @@ def utilities_tab(
     with gr.Tab('Captioning'):
         gradio_basic_caption_gui_tab()
         gradio_blip_caption_gui_tab()
+        gradio_git_caption_gui_tab()
         gradio_wd14_caption_gui_tab()
     gradio_convert_model_tab()
 
diff --git a/lora_gui.py b/lora_gui.py
index 70dbfd0..de7c36c 100644
--- a/lora_gui.py
+++ b/lora_gui.py
@@ -291,11 +291,11 @@ def train_model(
     if unet_lr == '':
         unet_lr = 0
 
-    if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
-        msgbox(
-            'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
-        )
-        return
+    # if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
+    #     msgbox(
+    #         'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
+    #     )
+    #     return
 
     # Get a list of all subfolders in train_data_dir
     subfolders = [
@@ -383,15 +383,26 @@ def train_model(
     if not float(prior_loss_weight) == 1.0:
         run_cmd += f' --prior_loss_weight={prior_loss_weight}'
     run_cmd += f' --network_module=networks.lora'
-    if not float(text_encoder_lr) == 0:
-        run_cmd += f' --text_encoder_lr={text_encoder_lr}'
+
+    if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
+        if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
+            run_cmd += f' --text_encoder_lr={text_encoder_lr}'
+            run_cmd += f' --unet_lr={unet_lr}'
+        elif not (float(text_encoder_lr) == 0):
+            run_cmd += f' --text_encoder_lr={text_encoder_lr}'
+            run_cmd += f' --network_train_text_encoder_only'
+        else:
+            run_cmd += f' --unet_lr={unet_lr}'
+            run_cmd += f' --network_train_unet_only'
     else:
-        run_cmd += f' --network_train_unet_only'
-    if not float(unet_lr) == 0:
-        run_cmd += f' --unet_lr={unet_lr}'
-    else:
-        run_cmd += f' --network_train_text_encoder_only'
+        if float(text_encoder_lr) == 0:
+            msgbox(
+                'Please input learning rate values.'
+            )
+            return
+
     run_cmd += f' --network_dim={network_dim}'
+
     if not lora_network_weights == '':
         run_cmd += f' --network_weights="{lora_network_weights}"'
     if int(gradient_accumulation_steps) > 1:
@@ -400,6 +411,8 @@
     run_cmd += f' --output_name="{output_name}"'
     if not lr_scheduler_num_cycles == '':
         run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
+    else:
+        run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
     if not lr_scheduler_power == '':
         run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
 
@@ -612,19 +625,19 @@
                 placeholder='Optional',
             )
             network_dim = gr.Slider(
-                minimum=1,
-                maximum=128,
+                minimum=4,
+                maximum=1024,
                 label='Network Rank (Dimension)',
                 value=8,
-                step=1,
+                step=4,
                 interactive=True,
             )
             network_alpha = gr.Slider(
-                minimum=1,
-                maximum=128,
+                minimum=4,
+                maximum=1024,
                 label='Network Alpha',
                 value=1,
-                step=1,
+                step=4,
                 interactive=True,
             )
             with gr.Row():
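The reworked learning-rate branch in `train_model()` reduces to a small truth table over `text_encoder_lr` and `unet_lr`. Here is the same selection logic as a standalone sketch (the `lr_flags` helper is illustrative only, not part of the codebase):

```python
def lr_flags(text_encoder_lr, unet_lr):
    # Mirrors the branching added above: pass both LRs when both are
    # non-zero, restrict training to the side whose LR was provided when
    # only one is set, and reject the case where neither was given.
    te, un = float(text_encoder_lr or 0), float(unet_lr or 0)
    if te and un:
        return [f'--text_encoder_lr={te}', f'--unet_lr={un}']
    if te:
        return [f'--text_encoder_lr={te}', '--network_train_text_encoder_only']
    if un:
        return [f'--unet_lr={un}', '--network_train_unet_only']
    raise ValueError('Please input learning rate values.')


print(lr_flags(5e-5, 0))
# ['--text_encoder_lr=5e-05', '--network_train_text_encoder_only']
```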
diff --git a/requirements.txt b/requirements.txt
index 709a834..eeb0bdc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,13 +9,14 @@ pytorch_lightning
 bitsandbytes==0.35.0
 tensorboard
 safetensors==0.2.6
-gradio
+gradio==3.16.2
 altair
 easygui
+tk
 # for BLIP captioning
 requests
-timm==0.4.12
-fairscale==0.4.4
+timm
+fairscale
 # for WD14 captioning
 tensorflow<2.11
 huggingface-hub
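Since `gradio` is now pinned while `timm` and `fairscale` are unpinned, a stale venv can silently keep old versions. A quick post-install sanity check (a sketch, not part of the repo):

```python
# Run after `pip install -r requirements.txt` to confirm the pin took effect.
from importlib.metadata import version

assert version('gradio') == '3.16.2', f"unexpected gradio {version('gradio')}"
print('gradio', version('gradio'), 'OK')
```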