198 lines
7.0 KiB
Python
198 lines
7.0 KiB
Python
import gradio as gr
|
|
from easygui import diropenbox, msgbox
|
|
from .common_gui import get_folder_path
|
|
import shutil
|
|
import os
|
|
|
|
|
|
def copy_info_to_Directories_tab(training_folder):
|
|
img_folder = os.path.join(training_folder, 'img')
|
|
if os.path.exists(os.path.join(training_folder, 'reg')):
|
|
reg_folder = os.path.join(training_folder, 'reg')
|
|
else:
|
|
reg_folder = ''
|
|
model_folder = os.path.join(training_folder, 'model')
|
|
log_folder = os.path.join(training_folder, 'log')
|
|
|
|
return img_folder, reg_folder, model_folder, log_folder
|
|
|
|
|
|
def dreambooth_folder_preparation(
|
|
util_training_images_dir_input,
|
|
util_training_images_repeat_input,
|
|
util_instance_prompt_input,
|
|
util_regularization_images_dir_input,
|
|
util_regularization_images_repeat_input,
|
|
util_class_prompt_input,
|
|
util_training_dir_output,
|
|
):
|
|
|
|
# Check if the input variables are empty
|
|
if not len(util_training_dir_output):
|
|
print(
|
|
"Destination training directory is missing... can't perform the required task..."
|
|
)
|
|
return
|
|
else:
|
|
# Create the util_training_dir_output directory if it doesn't exist
|
|
os.makedirs(util_training_dir_output, exist_ok=True)
|
|
|
|
# Check for instance prompt
|
|
if util_instance_prompt_input == '':
|
|
msgbox('Instance prompt missing...')
|
|
return
|
|
|
|
# Check for class prompt
|
|
if util_class_prompt_input == '':
|
|
msgbox('Class prompt missing...')
|
|
return
|
|
|
|
# Create the training_dir path
|
|
if util_training_images_dir_input == '':
|
|
print(
|
|
"Training images directory is missing... can't perform the required task..."
|
|
)
|
|
return
|
|
else:
|
|
training_dir = os.path.join(
|
|
util_training_dir_output,
|
|
f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}',
|
|
)
|
|
|
|
# Remove folders if they exist
|
|
if os.path.exists(training_dir):
|
|
print(f'Removing existing directory {training_dir}...')
|
|
shutil.rmtree(training_dir)
|
|
|
|
# Copy the training images to their respective directories
|
|
print(f'Copy {util_training_images_dir_input} to {training_dir}...')
|
|
shutil.copytree(util_training_images_dir_input, training_dir)
|
|
|
|
# Create the regularization_dir path
|
|
if (
|
|
util_class_prompt_input == ''
|
|
or not util_regularization_images_repeat_input > 0
|
|
):
|
|
print(
|
|
'Regularization images directory or repeats is missing... not copying regularisation images...'
|
|
)
|
|
else:
|
|
regularization_dir = os.path.join(
|
|
util_training_dir_output,
|
|
f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}',
|
|
)
|
|
|
|
# Remove folders if they exist
|
|
if os.path.exists(regularization_dir):
|
|
print(f'Removing existing directory {regularization_dir}...')
|
|
shutil.rmtree(regularization_dir)
|
|
|
|
# Copy the regularisation images to their respective directories
|
|
print(
|
|
f'Copy {util_regularization_images_dir_input} to {regularization_dir}...'
|
|
)
|
|
shutil.copytree(
|
|
util_regularization_images_dir_input, regularization_dir
|
|
)
|
|
|
|
print(
|
|
f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
|
|
)
|
|
|
|
|
|
def gradio_dreambooth_folder_creation_tab(
|
|
train_data_dir_input,
|
|
reg_data_dir_input,
|
|
output_dir_input,
|
|
logging_dir_input,
|
|
):
|
|
with gr.Tab('Dreambooth folder preparation'):
|
|
gr.Markdown(
|
|
'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth method to function correctly.'
|
|
)
|
|
with gr.Row():
|
|
util_instance_prompt_input = gr.Textbox(
|
|
label='Instance prompt',
|
|
placeholder='Eg: asd',
|
|
interactive=True,
|
|
)
|
|
util_class_prompt_input = gr.Textbox(
|
|
label='Class prompt',
|
|
placeholder='Eg: person',
|
|
interactive=True,
|
|
)
|
|
with gr.Row():
|
|
util_training_images_dir_input = gr.Textbox(
|
|
label='Training images',
|
|
placeholder='Directory containing the training images',
|
|
interactive=True,
|
|
)
|
|
button_util_training_images_dir_input = gr.Button(
|
|
'📂', elem_id='open_folder_small'
|
|
)
|
|
button_util_training_images_dir_input.click(
|
|
get_folder_path, outputs=util_training_images_dir_input
|
|
)
|
|
util_training_images_repeat_input = gr.Number(
|
|
label='Repeats',
|
|
value=40,
|
|
interactive=True,
|
|
elem_id='number_input',
|
|
)
|
|
with gr.Row():
|
|
util_regularization_images_dir_input = gr.Textbox(
|
|
label='Regularisation images',
|
|
placeholder='(Optional) Directory containing the regularisation images',
|
|
interactive=True,
|
|
)
|
|
button_util_regularization_images_dir_input = gr.Button(
|
|
'📂', elem_id='open_folder_small'
|
|
)
|
|
button_util_regularization_images_dir_input.click(
|
|
get_folder_path, outputs=util_regularization_images_dir_input
|
|
)
|
|
util_regularization_images_repeat_input = gr.Number(
|
|
label='Repeats',
|
|
value=1,
|
|
interactive=True,
|
|
elem_id='number_input',
|
|
)
|
|
with gr.Row():
|
|
util_training_dir_output = gr.Textbox(
|
|
label='Destination training directory',
|
|
placeholder='Directory where formatted training and regularisation folders will be placed',
|
|
interactive=True,
|
|
)
|
|
button_util_training_dir_output = gr.Button(
|
|
'📂', elem_id='open_folder_small'
|
|
)
|
|
button_util_training_dir_output.click(
|
|
get_folder_path, outputs=util_training_dir_output
|
|
)
|
|
button_prepare_training_data = gr.Button('Prepare training data')
|
|
button_prepare_training_data.click(
|
|
dreambooth_folder_preparation,
|
|
inputs=[
|
|
util_training_images_dir_input,
|
|
util_training_images_repeat_input,
|
|
util_instance_prompt_input,
|
|
util_regularization_images_dir_input,
|
|
util_regularization_images_repeat_input,
|
|
util_class_prompt_input,
|
|
util_training_dir_output,
|
|
],
|
|
)
|
|
button_copy_info_to_Directories_tab = gr.Button(
|
|
'Copy info to Directories Tab'
|
|
)
|
|
button_copy_info_to_Directories_tab.click(
|
|
copy_info_to_Directories_tab,
|
|
inputs=[util_training_dir_output],
|
|
outputs=[
|
|
train_data_dir_input,
|
|
reg_data_dir_input,
|
|
output_dir_input,
|
|
logging_dir_input,
|
|
],
|
|
)
|