KohyaSS/library/dreambooth_folder_creation_gui.py

212 lines
7.8 KiB
Python
Raw Normal View History

import os
import shutil
2022-12-16 18:16:23 +00:00
import gradio as gr
2022-12-16 18:16:23 +00:00
from .common_gui import get_folder_path
2022-12-17 01:26:26 +00:00
2022-12-23 01:18:51 +00:00
def copy_info_to_Folders_tab(training_folder):
    """Derive the Folders-tab paths from a prepared training directory.

    Args:
        training_folder: Root directory laid out by
            ``dreambooth_folder_preparation``.

    Returns:
        A 4-tuple ``(img, reg, model, log)`` of paths under
        ``training_folder``. ``reg`` is the empty string when no ``reg``
        entry exists there, so the regularisation field is left blank.
    """

    def _under_root(name):
        # Every returned path is a direct child of the training folder.
        return os.path.join(training_folder, name)

    reg_candidate = _under_root('reg')
    reg_folder = reg_candidate if os.path.exists(reg_candidate) else ''
    return _under_root('img'), reg_folder, _under_root('model'), _under_root('log')
def _as_repeat_count(value):
    """Return *value* truncated to an int repeat count, or 0 if it is missing/non-numeric."""
    try:
        return int(value)
    except (TypeError, ValueError):
        # e.g. an empty gradio Number field delivering None.
        return 0


def _replace_tree(source_dir, target_dir):
    """Copy *source_dir* to *target_dir*, deleting any existing *target_dir* first."""
    if os.path.exists(target_dir):
        print(f'Removing existing directory {target_dir}...')
        shutil.rmtree(target_dir)
    print(f'Copy {source_dir} to {target_dir}...')
    shutil.copytree(source_dir, target_dir)


def dreambooth_folder_preparation(
    util_training_images_dir_input,
    util_training_images_repeat_input,
    util_instance_prompt_input,
    util_regularization_images_dir_input,
    util_regularization_images_repeat_input,
    util_class_prompt_input,
    util_training_dir_output,
):
    """Build the kohya_ss Dreambooth/LoRA folder layout under a destination directory.

    Layout created under ``util_training_dir_output``::

        img/<repeats>_<instance prompt> <class prompt>/   copy of the training images
        reg/<repeats>_<class prompt>/                     copy of the regularisation images (optional)
        log/
        model/

    An existing ``img/...`` or ``reg/...`` sub-folder with the same name is
    deleted and re-copied. Validation problems are reported with ``print``
    and the function returns early without raising.

    Args:
        util_training_images_dir_input: Directory holding the training images.
        util_training_images_repeat_input: Repeat count for the training
            images; truncated to ``int`` and must be at least 1.
        util_instance_prompt_input: Instance token used in the image folder name.
        util_regularization_images_dir_input: Optional directory of
            regularisation images; an empty value skips the ``reg`` folder.
        util_regularization_images_repeat_input: Repeat count for the
            regularisation images; truncated to ``int`` and must be at least 1
            for them to be copied.
        util_class_prompt_input: Class token used in both folder names.
        util_training_dir_output: Destination root directory.

    Returns:
        None.
    """
    # Validate every required input before touching the filesystem, so a
    # rejected request does not leave a half-built destination behind.
    if not util_training_dir_output:
        print(
            "Destination training directory is missing... can't perform the required task..."
        )
        return

    # Reported with print like the other checks: ``show_message_box`` is not
    # imported by this module, so calling it raised NameError.
    if not util_instance_prompt_input:
        print('Instance prompt missing...')
        return
    if not util_class_prompt_input:
        print('Class prompt missing...')
        return

    if not util_training_images_dir_input:
        print(
            "Training images directory is missing... can't perform the required task..."
        )
        return
    if not os.path.isdir(util_training_images_dir_input):
        print(
            f"Training images directory {util_training_images_dir_input} does not exist... can't perform the required task..."
        )
        return

    training_repeats = _as_repeat_count(util_training_images_repeat_input)
    if training_repeats < 1:
        print(
            "Training images repeats must be at least 1... can't perform the required task..."
        )
        return

    os.makedirs(util_training_dir_output, exist_ok=True)

    training_dir = os.path.join(
        util_training_dir_output,
        'img',
        f'{training_repeats}_{util_instance_prompt_input} {util_class_prompt_input}',
    )
    _replace_tree(util_training_images_dir_input, training_dir)

    if not util_regularization_images_dir_input:
        print(
            'Regularization images directory is missing... not copying regularisation images...'
        )
    else:
        # Compare the truncated value: e.g. 0.5 passes "> 0" but int(0.5) would
        # name the folder "0_<class>".
        reg_repeats = _as_repeat_count(util_regularization_images_repeat_input)
        if reg_repeats < 1:
            print('Repeats is missing... not copying regularisation images...')
        elif not os.path.isdir(util_regularization_images_dir_input):
            print(
                f'Regularization images directory {util_regularization_images_dir_input} does not exist... not copying regularisation images...'
            )
        else:
            regularization_dir = os.path.join(
                util_training_dir_output,
                'reg',
                f'{reg_repeats}_{util_class_prompt_input}',
            )
            _replace_tree(util_regularization_images_dir_input, regularization_dir)

    # Create the log and model folders; exist_ok makes re-runs harmless.
    for sub_folder in ('log', 'model'):
        os.makedirs(os.path.join(util_training_dir_output, sub_folder), exist_ok=True)

    print(
        f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
    )
2022-12-17 01:26:26 +00:00
def gradio_dreambooth_folder_creation_tab(
    # NOTE(review): these defaults are gradio components instantiated once, when
    # the module is imported, and shared by every call that omits them. Callers
    # presumably always pass the Folders-tab textboxes — confirm before relying
    # on the defaults.
    train_data_dir_input=gr.Textbox(),
    reg_data_dir_input=gr.Textbox(),
    output_dir_input=gr.Textbox(),
    logging_dir_input=gr.Textbox(),
):
    """Build the "Dreambooth/LoRA Folder preparation" gradio tab.

    The tab collects the prompts, source image folders, repeat counts and
    destination directory. "Prepare training data" runs
    ``dreambooth_folder_preparation`` with them. "Copy info to Folders Tab"
    fills the four textboxes passed in with the img/reg/model/log paths
    returned by ``copy_info_to_Folders_tab``.

    Args:
        train_data_dir_input: Textbox receiving the ``img`` folder path.
        reg_data_dir_input: Textbox receiving the ``reg`` folder path (or '').
        output_dir_input: Textbox receiving the ``model`` folder path.
        logging_dir_input: Textbox receiving the ``log`` folder path.
    """

    def _add_folder_button(target):
        """Add a 📂 button that fills *target* with a folder picked via get_folder_path."""
        # Shared so that all folder pickers get identical wiring (the
        # destination picker previously lacked show_progress=False).
        button = gr.Button('📂', elem_id='open_folder_small')
        button.click(
            get_folder_path,
            outputs=target,
            show_progress=False,
        )
        return button

    with gr.Tab('Dreambooth/LoRA Folder preparation'):
        gr.Markdown(
            'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohya_ss Dreambooth/LoRA method to function correctly.'
        )
        with gr.Row():
            util_instance_prompt_input = gr.Textbox(
                label='Instance prompt',
                placeholder='Eg: asd',
                interactive=True,
            )
            util_class_prompt_input = gr.Textbox(
                label='Class prompt',
                placeholder='Eg: person',
                interactive=True,
            )
        with gr.Row():
            util_training_images_dir_input = gr.Textbox(
                label='Training images',
                placeholder='Directory containing the training images',
                interactive=True,
            )
            _add_folder_button(util_training_images_dir_input)
            util_training_images_repeat_input = gr.Number(
                label='Repeats',
                value=40,
                interactive=True,
                elem_id='number_input',
            )
        with gr.Row():
            util_regularization_images_dir_input = gr.Textbox(
                label='Regularisation images',
                placeholder='(Optional) Directory containing the regularisation images',
                interactive=True,
            )
            _add_folder_button(util_regularization_images_dir_input)
            util_regularization_images_repeat_input = gr.Number(
                label='Repeats',
                value=1,
                interactive=True,
                elem_id='number_input',
            )
        with gr.Row():
            util_training_dir_output = gr.Textbox(
                label='Destination training directory',
                placeholder='Directory where formatted training and regularisation folders will be placed',
                interactive=True,
            )
            _add_folder_button(util_training_dir_output)

        button_prepare_training_data = gr.Button('Prepare training data')
        # Input order must match dreambooth_folder_preparation's parameter order.
        button_prepare_training_data.click(
            dreambooth_folder_preparation,
            inputs=[
                util_training_images_dir_input,
                util_training_images_repeat_input,
                util_instance_prompt_input,
                util_regularization_images_dir_input,
                util_regularization_images_repeat_input,
                util_class_prompt_input,
                util_training_dir_output,
            ],
            show_progress=False,
        )

        button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
        # Output order matches the (img, reg, model, log) tuple returned by
        # copy_info_to_Folders_tab.
        button_copy_info_to_Folders_tab.click(
            copy_info_to_Folders_tab,
            inputs=[util_training_dir_output],
            outputs=[
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ],
            show_progress=False,
        )