KohyaSS/library/git_caption_gui.py
JSTayco 7b5639cff5 Huge WIP
This is a massive WIP and should not be trusted or used right now. However, major milestones have been crossed. Both message boxes and file dialogs are now properly subprocessed and work on macOS. I think by extension, it may work on runpod environments as well, but that remains to be tested.
2023-03-30 01:40:00 -07:00

138 lines
3.7 KiB
Python

import os
import subprocess
import gradio as gr
from .common_gui import get_folder_path, add_pre_postfix
# Interpreter used to launch helper scripts: the system ``python3`` on POSIX
# platforms, otherwise the interpreter inside the local Windows virtualenv.
if os.name == 'posix':
    PYTHON = 'python3'
else:
    PYTHON = './venv/Scripts/python.exe'
def caption_images(
    train_data_dir,
    caption_ext,
    batch_size,
    max_data_loader_n_workers,
    max_length,
    model_id,
    prefix,
    postfix,
):
    """Run the GIT captioning script on a folder of images.

    Launches ``finetune/make_captions_by_git.py`` in a child process using
    the module-level ``PYTHON`` interpreter, then hands the folder, caption
    extension, prefix and postfix to ``add_pre_postfix``.

    Args:
        train_data_dir: Folder containing the images to caption. If empty,
            a message box is shown and nothing is run.
        caption_ext: Extension for the caption files, forwarded as
            ``--caption_extension``. If empty, a message box is shown and
            nothing is run.
        batch_size: Forwarded as ``--batch_size`` after ``int`` conversion.
        max_data_loader_n_workers: Forwarded as
            ``--max_data_loader_n_workers`` after ``int`` conversion.
        max_length: Forwarded as ``--max_length`` after ``int`` conversion.
        model_id: Optional model id; forwarded as ``--model_id`` only when
            non-empty.
        prefix: Passed to ``add_pre_postfix`` as ``prefix``.
        postfix: Passed to ``add_pre_postfix`` as ``postfix``.

    Returns:
        None.
    """
    # NOTE(review): ``show_message_box`` is not imported in the visible part
    # of this module -- confirm it is in scope before relying on these paths.
    if not train_data_dir:
        show_message_box('Image folder is missing...')
        return

    if not caption_ext:
        show_message_box('Please provide an extension for the caption files.')
        return

    print(f'GIT captioning files in {train_data_dir}...')

    # Build the command as an argument list so it works without a shell on
    # both Windows and POSIX, and so paths with spaces need no manual quoting.
    run_cmd = [PYTHON, 'finetune/make_captions_by_git.py']
    if model_id:
        run_cmd.append(f'--model_id={model_id}')
    run_cmd.append(f'--batch_size={int(batch_size)}')
    run_cmd.append(
        f'--max_data_loader_n_workers={int(max_data_loader_n_workers)}'
    )
    run_cmd.append(f'--max_length={int(max_length)}')
    run_cmd.append(f'--caption_extension={caption_ext}')
    run_cmd.append(train_data_dir)

    print(' '.join(run_cmd))

    # Run the command; as before, its exit status is not checked.
    subprocess.run(run_cmd)

    # Add prefix and postfix
    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_ext,
        prefix=prefix,
        postfix=postfix,
    )

    print('...captioning done')
###
# Gradio UI
###
def gradio_git_caption_gui_tab():
    """Build the 'GIT Captioning' tab and wire its controls.

    Creates the input widgets (image folder, caption extension, prefix,
    postfix, batch size, worker count, max length, model id), a folder
    picker button bound to ``get_folder_path``, and a 'Caption images'
    button that calls ``caption_images`` with the widget values.

    Returns:
        None.
    """
    with gr.Tab('GIT Captioning'):
        gr.Markdown(
            'This utility will use GIT to create caption files for each image in a folder.'
        )
        with gr.Row():
            train_data_dir = gr.Textbox(
                label='Image folder to caption',
                placeholder='Directory containing the images to caption',
                interactive=True,
            )
            button_train_data_dir_input = gr.Button(
                '📂', elem_id='open_folder_small'
            )
            # The folder picker writes its result into the textbox above.
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )
        with gr.Row():
            caption_ext = gr.Textbox(
                label='Caption file extension',
                placeholder='Extension for caption file. eg: .caption, .txt',
                value='.txt',
                interactive=True,
            )

            # Labels name GIT (not BLIP): this tab drives the GIT captioner.
            prefix = gr.Textbox(
                label='Prefix to add to GIT caption',
                placeholder='(Optional)',
                interactive=True,
            )

            postfix = gr.Textbox(
                label='Postfix to add to GIT caption',
                placeholder='(Optional)',
                interactive=True,
            )

            batch_size = gr.Number(
                value=1, label='Batch size', interactive=True
            )

        with gr.Row():
            max_data_loader_n_workers = gr.Number(
                value=2, label='Number of workers', interactive=True
            )
            max_length = gr.Number(
                value=75, label='Max length', interactive=True
            )
            model_id = gr.Textbox(
                label='Model',
                placeholder='(Optional) model id for GIT in Hugging Face',
                interactive=True,
            )

        caption_button = gr.Button('Caption images')

        # Input order must match the positional parameters of caption_images.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_ext,
                batch_size,
                max_data_loader_n_workers,
                max_length,
                model_id,
                prefix,
                postfix,
            ],
            show_progress=False,
        )