commit
b78df38979
@ -130,6 +130,9 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 12/19 (v18.4) update:
|
||||||
|
- Add support for shuffle_caption, save_state, resume, prior_loss_weight under "Advanced Configuration" section
|
||||||
|
- Fix issue with open/save config not working properly
|
||||||
* 12/19 (v18.3) update:
|
* 12/19 (v18.3) update:
|
||||||
- fix stop encoder training issue
|
- fix stop encoder training issue
|
||||||
* 12/19 (v18.2) update:
|
* 12/19 (v18.2) update:
|
||||||
|
@ -10,7 +10,9 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
|
from library.dreambooth_folder_creation_gui import (
|
||||||
|
gradio_dreambooth_folder_creation_tab,
|
||||||
|
)
|
||||||
from library.basic_caption_gui import gradio_basic_caption_gui_tab
|
from library.basic_caption_gui import gradio_basic_caption_gui_tab
|
||||||
from library.convert_model_gui import gradio_convert_model_tab
|
from library.convert_model_gui import gradio_convert_model_tab
|
||||||
from library.blip_caption_gui import gradio_blip_caption_gui_tab
|
from library.blip_caption_gui import gradio_blip_caption_gui_tab
|
||||||
@ -20,14 +22,14 @@ from library.common_gui import (
|
|||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
get_saveasfile_path
|
get_saveasfile_path,
|
||||||
)
|
)
|
||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
|
|
||||||
def save_configuration(
|
def save_configuration(
|
||||||
@ -60,7 +62,11 @@ def save_configuration(
|
|||||||
stop_text_encoder_training,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
save_model_as
|
save_model_as,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -68,22 +74,14 @@ def save_configuration(
|
|||||||
|
|
||||||
if save_as_bool:
|
if save_as_bool:
|
||||||
print('Save as...')
|
print('Save as...')
|
||||||
# file_path = filesavebox(
|
|
||||||
# 'Select the config file to save',
|
|
||||||
# default='finetune.json',
|
|
||||||
# filetypes='*.json',
|
|
||||||
# )
|
|
||||||
file_path = get_saveasfile_path(file_path)
|
file_path = get_saveasfile_path(file_path)
|
||||||
else:
|
else:
|
||||||
print('Save...')
|
print('Save...')
|
||||||
if file_path == None or file_path == '':
|
if file_path == None or file_path == '':
|
||||||
# file_path = filesavebox(
|
|
||||||
# 'Select the config file to save',
|
|
||||||
# default='finetune.json',
|
|
||||||
# filetypes='*.json',
|
|
||||||
# )
|
|
||||||
file_path = get_saveasfile_path(file_path)
|
file_path = get_saveasfile_path(file_path)
|
||||||
|
|
||||||
|
# print(file_path)
|
||||||
|
|
||||||
if file_path == None or file_path == '':
|
if file_path == None or file_path == '':
|
||||||
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||||
|
|
||||||
@ -116,7 +114,11 @@ def save_configuration(
|
|||||||
'stop_text_encoder_training': stop_text_encoder_training,
|
'stop_text_encoder_training': stop_text_encoder_training,
|
||||||
'use_8bit_adam': use_8bit_adam,
|
'use_8bit_adam': use_8bit_adam,
|
||||||
'xformers': xformers,
|
'xformers': xformers,
|
||||||
'save_model_as': save_model_as
|
'save_model_as': save_model_as,
|
||||||
|
'shuffle_caption': shuffle_caption,
|
||||||
|
'save_state': save_state,
|
||||||
|
'resume': resume,
|
||||||
|
'prior_loss_weight': prior_loss_weight,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the data to the selected file
|
# Save the data to the selected file
|
||||||
@ -155,14 +157,18 @@ def open_configuration(
|
|||||||
stop_text_encoder_training,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
save_model_as
|
save_model_as,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
):
|
):
|
||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
file_path = get_file_path(file_path)
|
file_path = get_file_path(file_path)
|
||||||
|
# print(file_path)
|
||||||
|
|
||||||
if file_path != '' and file_path != None:
|
if not file_path == '' and not file_path == None:
|
||||||
print(file_path)
|
|
||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data = json.load(f)
|
my_data = json.load(f)
|
||||||
@ -204,7 +210,11 @@ def open_configuration(
|
|||||||
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
|
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
|
||||||
my_data.get('use_8bit_adam', use_8bit_adam),
|
my_data.get('use_8bit_adam', use_8bit_adam),
|
||||||
my_data.get('xformers', xformers),
|
my_data.get('xformers', xformers),
|
||||||
my_data.get('save_model_as', save_model_as)
|
my_data.get('save_model_as', save_model_as),
|
||||||
|
my_data.get('shuffle_caption', shuffle_caption),
|
||||||
|
my_data.get('save_state', save_state),
|
||||||
|
my_data.get('resume', resume),
|
||||||
|
my_data.get('prior_loss_weight', prior_loss_weight),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -236,7 +246,11 @@ def train_model(
|
|||||||
stop_text_encoder_training_pct,
|
stop_text_encoder_training_pct,
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
save_model_as
|
save_model_as,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
def save_inference_file(output_dir, v2, v_parameterization):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -360,6 +374,10 @@ def train_model(
|
|||||||
run_cmd += ' --use_8bit_adam'
|
run_cmd += ' --use_8bit_adam'
|
||||||
if xformers:
|
if xformers:
|
||||||
run_cmd += ' --xformers'
|
run_cmd += ' --xformers'
|
||||||
|
if shuffle_caption:
|
||||||
|
run_cmd += ' --shuffle_caption'
|
||||||
|
if save_state:
|
||||||
|
run_cmd += ' --save_state'
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||||
)
|
)
|
||||||
@ -382,9 +400,15 @@ def train_model(
|
|||||||
run_cmd += f' --logging_dir={logging_dir}'
|
run_cmd += f' --logging_dir={logging_dir}'
|
||||||
run_cmd += f' --caption_extention={caption_extention}'
|
run_cmd += f' --caption_extention={caption_extention}'
|
||||||
if not stop_text_encoder_training == 0:
|
if not stop_text_encoder_training == 0:
|
||||||
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
|
run_cmd += (
|
||||||
|
f' --stop_text_encoder_training={stop_text_encoder_training}'
|
||||||
|
)
|
||||||
if not save_model_as == 'same as source model':
|
if not save_model_as == 'same as source model':
|
||||||
run_cmd += f' --save_model_as={save_model_as}'
|
run_cmd += f' --save_model_as={save_model_as}'
|
||||||
|
if not resume == '':
|
||||||
|
run_cmd += f' --resume={resume}'
|
||||||
|
if not float(prior_loss_weight) == 1.0:
|
||||||
|
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -392,7 +416,7 @@ def train_model(
|
|||||||
|
|
||||||
# check if output_dir/last is a folder... therefore it is a diffuser model
|
# check if output_dir/last is a folder... therefore it is a diffuser model
|
||||||
last_dir = pathlib.Path(f'{output_dir}/last')
|
last_dir = pathlib.Path(f'{output_dir}/last')
|
||||||
|
|
||||||
if not last_dir.is_dir():
|
if not last_dir.is_dir():
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
save_inference_file(output_dir, v2, v_parameterization)
|
save_inference_file(output_dir, v2, v_parameterization)
|
||||||
@ -472,8 +496,8 @@ with interface:
|
|||||||
)
|
)
|
||||||
config_file_name = gr.Textbox(
|
config_file_name = gr.Textbox(
|
||||||
label='',
|
label='',
|
||||||
# placeholder="type the configuration file path or use the 'Open' button above to select it...",
|
placeholder="type the configuration file path or use the 'Open' button above to select it...",
|
||||||
interactive=False
|
interactive=True,
|
||||||
)
|
)
|
||||||
# config_file_name.change(
|
# config_file_name.change(
|
||||||
# remove_doublequote,
|
# remove_doublequote,
|
||||||
@ -491,13 +515,16 @@ with interface:
|
|||||||
document_symbol, elem_id='open_folder_small'
|
document_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_fille.click(
|
pretrained_model_name_or_path_fille.click(
|
||||||
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
|
get_file_path,
|
||||||
|
inputs=[pretrained_model_name_or_path_input],
|
||||||
|
outputs=pretrained_model_name_or_path_input,
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_folder = gr.Button(
|
pretrained_model_name_or_path_folder = gr.Button(
|
||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_folder.click(
|
pretrained_model_name_or_path_folder.click(
|
||||||
get_folder_path, outputs=pretrained_model_name_or_path_input
|
get_folder_path,
|
||||||
|
outputs=pretrained_model_name_or_path_input,
|
||||||
)
|
)
|
||||||
model_list = gr.Dropdown(
|
model_list = gr.Dropdown(
|
||||||
label='(Optional) Model Quick Pick',
|
label='(Optional) Model Quick Pick',
|
||||||
@ -517,10 +544,10 @@ with interface:
|
|||||||
'same as source model',
|
'same as source model',
|
||||||
'ckpt',
|
'ckpt',
|
||||||
'diffusers',
|
'diffusers',
|
||||||
"diffusers_safetensors",
|
'diffusers_safetensors',
|
||||||
'safetensors',
|
'safetensors',
|
||||||
],
|
],
|
||||||
value='same as source model'
|
value='same as source model',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v2_input = gr.Checkbox(label='v2', value=True)
|
v2_input = gr.Checkbox(label='v2', value=True)
|
||||||
@ -607,7 +634,9 @@ with interface:
|
|||||||
)
|
)
|
||||||
with gr.Tab('Training parameters'):
|
with gr.Tab('Training parameters'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
learning_rate_input = gr.Textbox(
|
||||||
|
label='Learning rate', value=1e-6
|
||||||
|
)
|
||||||
lr_scheduler_input = gr.Dropdown(
|
lr_scheduler_input = gr.Dropdown(
|
||||||
label='LR Scheduler',
|
label='LR Scheduler',
|
||||||
choices=[
|
choices=[
|
||||||
@ -662,7 +691,9 @@ with interface:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
seed_input = gr.Textbox(label='Seed', value=1234)
|
||||||
max_resolution_input = gr.Textbox(
|
max_resolution_input = gr.Textbox(
|
||||||
label='Max resolution', value='512,512', placeholder='512,512'
|
label='Max resolution',
|
||||||
|
value='512,512',
|
||||||
|
placeholder='512,512',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
caption_extention_input = gr.Textbox(
|
caption_extention_input = gr.Textbox(
|
||||||
@ -676,27 +707,45 @@ with interface:
|
|||||||
step=1,
|
step=1,
|
||||||
label='Stop text encoder training',
|
label='Stop text encoder training',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
|
||||||
full_fp16_input = gr.Checkbox(
|
|
||||||
label='Full fp16 training (experimental)', value=False
|
|
||||||
)
|
|
||||||
no_token_padding_input = gr.Checkbox(
|
|
||||||
label='No token padding', value=False
|
|
||||||
)
|
|
||||||
|
|
||||||
gradient_checkpointing_input = gr.Checkbox(
|
|
||||||
label='Gradient checkpointing', value=False
|
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_bucket_input = gr.Checkbox(
|
enable_bucket_input = gr.Checkbox(
|
||||||
label='Enable buckets', value=True
|
label='Enable buckets', value=True
|
||||||
)
|
)
|
||||||
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
|
cache_latent_input = gr.Checkbox(
|
||||||
|
label='Cache latent', value=True
|
||||||
|
)
|
||||||
use_8bit_adam_input = gr.Checkbox(
|
use_8bit_adam_input = gr.Checkbox(
|
||||||
label='Use 8bit adam', value=True
|
label='Use 8bit adam', value=True
|
||||||
)
|
)
|
||||||
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
||||||
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
|
with gr.Row():
|
||||||
|
full_fp16_input = gr.Checkbox(
|
||||||
|
label='Full fp16 training (experimental)', value=False
|
||||||
|
)
|
||||||
|
no_token_padding_input = gr.Checkbox(
|
||||||
|
label='No token padding', value=False
|
||||||
|
)
|
||||||
|
|
||||||
|
gradient_checkpointing_input = gr.Checkbox(
|
||||||
|
label='Gradient checkpointing', value=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shuffle_caption = gr.Checkbox(
|
||||||
|
label='Shuffle caption', value=False
|
||||||
|
)
|
||||||
|
save_state = gr.Checkbox(label='Save state', value=False)
|
||||||
|
with gr.Row():
|
||||||
|
resume = gr.Textbox(
|
||||||
|
label='Resume',
|
||||||
|
placeholder='path to "last-state" state folder to resume from',
|
||||||
|
)
|
||||||
|
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||||
|
resume_button.click(get_folder_path, outputs=resume)
|
||||||
|
prior_loss_weight = gr.Number(
|
||||||
|
label='Prior loss weight', value=1.0
|
||||||
|
)
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
with gr.Tab('Utilities'):
|
with gr.Tab('Utilities'):
|
||||||
@ -713,8 +762,6 @@ with interface:
|
|||||||
gradio_dataset_balancing_tab()
|
gradio_dataset_balancing_tab()
|
||||||
gradio_convert_model_tab()
|
gradio_convert_model_tab()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_configuration,
|
open_configuration,
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -746,7 +793,11 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
save_model_as_dropdown
|
save_model_as_dropdown,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
config_file_name,
|
config_file_name,
|
||||||
@ -777,7 +828,11 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
save_model_as_dropdown
|
save_model_as_dropdown,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -815,7 +870,11 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
save_model_as_dropdown
|
save_model_as_dropdown,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -852,7 +911,11 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
save_model_as_dropdown
|
save_model_as_dropdown,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -887,7 +950,11 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
save_model_as_dropdown
|
save_model_as_dropdown,
|
||||||
|
shuffle_caption,
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
prior_loss_weight,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,37 +1,52 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
import subprocess
|
import subprocess
|
||||||
from .common_gui import get_folder_path
|
from .common_gui import get_folder_path, add_pre_postfix
|
||||||
|
|
||||||
|
|
||||||
def caption_images(
|
def caption_images(
|
||||||
caption_text_input, images_dir_input, overwrite_input, caption_file_ext
|
caption_text_input,
|
||||||
|
images_dir_input,
|
||||||
|
overwrite_input,
|
||||||
|
caption_file_ext,
|
||||||
|
prefix,
|
||||||
|
postfix,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
|
||||||
if caption_text_input == '':
|
|
||||||
msgbox('Caption text is missing...')
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check for images_dir_input
|
# Check for images_dir_input
|
||||||
if images_dir_input == '':
|
if images_dir_input == '':
|
||||||
msgbox('Image folder is missing...')
|
msgbox('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
print(
|
if not caption_text_input == '':
|
||||||
f'Captioning files in {images_dir_input} with {caption_text_input}...'
|
print(
|
||||||
)
|
f'Captioning files in {images_dir_input} with {caption_text_input}...'
|
||||||
run_cmd = f'python "tools/caption.py"'
|
)
|
||||||
run_cmd += f' --caption_text="{caption_text_input}"'
|
run_cmd = f'python "tools/caption.py"'
|
||||||
|
run_cmd += f' --caption_text="{caption_text_input}"'
|
||||||
|
if overwrite_input:
|
||||||
|
run_cmd += f' --overwrite'
|
||||||
|
if caption_file_ext != '':
|
||||||
|
run_cmd += f' --caption_file_ext="{caption_file_ext}"'
|
||||||
|
run_cmd += f' "{images_dir_input}"'
|
||||||
|
|
||||||
|
print(run_cmd)
|
||||||
|
|
||||||
|
# Run the command
|
||||||
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
if overwrite_input:
|
if overwrite_input:
|
||||||
run_cmd += f' --overwrite'
|
# Add prefix and postfix
|
||||||
if caption_file_ext != '':
|
add_pre_postfix(
|
||||||
run_cmd += f' --caption_file_ext="{caption_file_ext}"'
|
folder=images_dir_input,
|
||||||
run_cmd += f' "{images_dir_input}"'
|
caption_file_ext=caption_file_ext,
|
||||||
|
prefix=prefix,
|
||||||
print(run_cmd)
|
postfix=postfix,
|
||||||
|
)
|
||||||
# Run the command
|
else:
|
||||||
subprocess.run(run_cmd)
|
if not prefix == '' or not postfix == '':
|
||||||
|
msgbox(
|
||||||
|
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
|
||||||
|
)
|
||||||
|
|
||||||
print('...captioning done')
|
print('...captioning done')
|
||||||
|
|
||||||
@ -46,22 +61,6 @@ def gradio_basic_caption_gui_tab():
|
|||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This utility will allow the creation of simple caption files for each images in a folder.'
|
'This utility will allow the creation of simple caption files for each images in a folder.'
|
||||||
)
|
)
|
||||||
with gr.Row():
|
|
||||||
caption_text_input = gr.Textbox(
|
|
||||||
label='Caption text',
|
|
||||||
placeholder='Eg: , by some artist',
|
|
||||||
interactive=True,
|
|
||||||
)
|
|
||||||
overwrite_input = gr.Checkbox(
|
|
||||||
label='Overwrite existing captions in folder',
|
|
||||||
interactive=True,
|
|
||||||
value=False,
|
|
||||||
)
|
|
||||||
caption_file_ext = gr.Textbox(
|
|
||||||
label='Caption file extension',
|
|
||||||
placeholder='(Optional) Default: .caption',
|
|
||||||
interactive=True,
|
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
images_dir_input = gr.Textbox(
|
images_dir_input = gr.Textbox(
|
||||||
label='Image folder to caption',
|
label='Image folder to caption',
|
||||||
@ -74,6 +73,33 @@ def gradio_basic_caption_gui_tab():
|
|||||||
button_images_dir_input.click(
|
button_images_dir_input.click(
|
||||||
get_folder_path, outputs=images_dir_input
|
get_folder_path, outputs=images_dir_input
|
||||||
)
|
)
|
||||||
|
with gr.Row():
|
||||||
|
prefix = gr.Textbox(
|
||||||
|
label='Prefix to add to txt caption',
|
||||||
|
placeholder='(Optional)',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
caption_text_input = gr.Textbox(
|
||||||
|
label='Caption text',
|
||||||
|
placeholder='Eg: , by some artist. Leave empti if you just want to add pre or postfix',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
postfix = gr.Textbox(
|
||||||
|
label='Postfix to add to txt caption',
|
||||||
|
placeholder='(Optional)',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
overwrite_input = gr.Checkbox(
|
||||||
|
label='Overwrite existing captions in folder',
|
||||||
|
interactive=True,
|
||||||
|
value=False,
|
||||||
|
)
|
||||||
|
caption_file_ext = gr.Textbox(
|
||||||
|
label='Caption file extension',
|
||||||
|
placeholder='(Optional) Default: .caption',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
caption_button = gr.Button('Caption images')
|
caption_button = gr.Button('Caption images')
|
||||||
|
|
||||||
caption_button.click(
|
caption_button.click(
|
||||||
@ -83,5 +109,7 @@ def gradio_basic_caption_gui_tab():
|
|||||||
images_dir_input,
|
images_dir_input,
|
||||||
overwrite_input,
|
overwrite_input,
|
||||||
caption_file_ext,
|
caption_file_ext,
|
||||||
|
prefix,
|
||||||
|
postfix,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
import subprocess
|
import subprocess
|
||||||
from .common_gui import get_folder_path
|
import os
|
||||||
|
from .common_gui import get_folder_path, add_pre_postfix
|
||||||
|
|
||||||
|
|
||||||
def caption_images(
|
def caption_images(
|
||||||
@ -13,6 +14,8 @@ def caption_images(
|
|||||||
max_length,
|
max_length,
|
||||||
min_length,
|
min_length,
|
||||||
beam_search,
|
beam_search,
|
||||||
|
prefix,
|
||||||
|
postfix,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
# if caption_text_input == "":
|
# if caption_text_input == "":
|
||||||
@ -43,6 +46,14 @@ def caption_images(
|
|||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
|
# Add prefix and postfix
|
||||||
|
add_pre_postfix(
|
||||||
|
folder=train_data_dir,
|
||||||
|
caption_file_ext=caption_file_ext,
|
||||||
|
prefix=prefix,
|
||||||
|
postfix=postfix,
|
||||||
|
)
|
||||||
|
|
||||||
print('...captioning done')
|
print('...captioning done')
|
||||||
|
|
||||||
|
|
||||||
@ -68,13 +79,25 @@ def gradio_blip_caption_gui_tab():
|
|||||||
button_train_data_dir_input.click(
|
button_train_data_dir_input.click(
|
||||||
get_folder_path, outputs=train_data_dir
|
get_folder_path, outputs=train_data_dir
|
||||||
)
|
)
|
||||||
|
with gr.Row():
|
||||||
caption_file_ext = gr.Textbox(
|
caption_file_ext = gr.Textbox(
|
||||||
label='Caption file extension',
|
label='Caption file extension',
|
||||||
placeholder='(Optional) Default: .caption',
|
placeholder='(Optional) Default: .caption',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prefix = gr.Textbox(
|
||||||
|
label='Prefix to add to BLIP caption',
|
||||||
|
placeholder='(Optional)',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
postfix = gr.Textbox(
|
||||||
|
label='Postfix to add to BLIP caption',
|
||||||
|
placeholder='(Optional)',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
batch_size = gr.Number(
|
batch_size = gr.Number(
|
||||||
value=1, label='Batch size', interactive=True
|
value=1, label='Batch size', interactive=True
|
||||||
)
|
)
|
||||||
@ -107,5 +130,7 @@ def gradio_blip_caption_gui_tab():
|
|||||||
max_length,
|
max_length,
|
||||||
min_length,
|
min_length,
|
||||||
beam_search,
|
beam_search,
|
||||||
|
prefix,
|
||||||
|
postfix,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
from tkinter import filedialog, Tk
|
from tkinter import filedialog, Tk
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def get_file_path(file_path='', defaultextension='.json'):
|
def get_file_path(file_path='', defaultextension='.json'):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
# print(f'current file path: {current_file_path}')
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
root = Tk()
|
root = Tk()
|
||||||
root.wm_attributes('-topmost', 1)
|
root.wm_attributes('-topmost', 1)
|
||||||
root.withdraw()
|
root.withdraw()
|
||||||
file_path = filedialog.askopenfilename(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
|
file_path = filedialog.askopenfilename(
|
||||||
|
filetypes=(('Config files', '*.json'), ('All files', '*')),
|
||||||
|
defaultextension=defaultextension,
|
||||||
|
)
|
||||||
root.destroy()
|
root.destroy()
|
||||||
|
|
||||||
if file_path == '':
|
if file_path == '':
|
||||||
file_path = current_file_path
|
file_path = current_file_path
|
||||||
|
|
||||||
@ -25,35 +30,58 @@ def remove_doublequote(file_path):
|
|||||||
|
|
||||||
def get_folder_path(folder_path=''):
|
def get_folder_path(folder_path=''):
|
||||||
current_folder_path = folder_path
|
current_folder_path = folder_path
|
||||||
|
|
||||||
root = Tk()
|
root = Tk()
|
||||||
root.wm_attributes('-topmost', 1)
|
root.wm_attributes('-topmost', 1)
|
||||||
root.withdraw()
|
root.withdraw()
|
||||||
folder_path = filedialog.askdirectory()
|
folder_path = filedialog.askdirectory()
|
||||||
root.destroy()
|
root.destroy()
|
||||||
|
|
||||||
if folder_path == '':
|
if folder_path == '':
|
||||||
folder_path = current_folder_path
|
folder_path = current_folder_path
|
||||||
|
|
||||||
return folder_path
|
return folder_path
|
||||||
|
|
||||||
|
|
||||||
def get_saveasfile_path(file_path='', defaultextension='.json'):
|
def get_saveasfile_path(file_path='', defaultextension='.json'):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
# print(f'current file path: {current_file_path}')
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
root = Tk()
|
root = Tk()
|
||||||
root.wm_attributes('-topmost', 1)
|
root.wm_attributes('-topmost', 1)
|
||||||
root.withdraw()
|
root.withdraw()
|
||||||
save_file_path = filedialog.asksaveasfile(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
|
save_file_path = filedialog.asksaveasfile(
|
||||||
|
filetypes=(('Config files', '*.json'), ('All files', '*')),
|
||||||
|
defaultextension=defaultextension,
|
||||||
|
)
|
||||||
root.destroy()
|
root.destroy()
|
||||||
|
|
||||||
# file_path = file_path.name
|
# print(save_file_path)
|
||||||
if file_path == '':
|
|
||||||
|
if save_file_path == None:
|
||||||
file_path = current_file_path
|
file_path = current_file_path
|
||||||
else:
|
else:
|
||||||
print(save_file_path.name)
|
print(save_file_path.name)
|
||||||
file_path = save_file_path.name
|
file_path = save_file_path.name
|
||||||
|
|
||||||
print(file_path)
|
# print(file_path)
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
|
def add_pre_postfix(
|
||||||
|
folder='', prefix='', postfix='', caption_file_ext='.caption'
|
||||||
|
):
|
||||||
|
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
|
||||||
|
if not prefix == '':
|
||||||
|
prefix = f'{prefix} '
|
||||||
|
if not postfix == '':
|
||||||
|
postfix = f' {postfix}'
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
with open(os.path.join(folder, file), 'r+') as f:
|
||||||
|
content = f.read()
|
||||||
|
content = content.rstrip()
|
||||||
|
f.seek(0, 0)
|
||||||
|
f.write(f'{prefix}{content}{postfix}')
|
||||||
|
f.close()
|
||||||
|
@ -8,37 +8,45 @@ from .common_gui import get_folder_path, get_file_path
|
|||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
def convert_model(source_model_input, source_model_type, target_model_folder_input, target_model_name_input, target_model_type, target_save_precision_type):
|
|
||||||
|
def convert_model(
|
||||||
|
source_model_input,
|
||||||
|
source_model_type,
|
||||||
|
target_model_folder_input,
|
||||||
|
target_model_name_input,
|
||||||
|
target_model_type,
|
||||||
|
target_save_precision_type,
|
||||||
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if source_model_type == "":
|
if source_model_type == '':
|
||||||
msgbox("Invalid source model type")
|
msgbox('Invalid source model type')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if os.path.isfile(source_model_input):
|
if os.path.isfile(source_model_input):
|
||||||
print('The provided source model is a file')
|
print('The provided source model is a file')
|
||||||
elif os.path.isdir(source_model_input):
|
elif os.path.isdir(source_model_input):
|
||||||
print('The provided model is a folder')
|
print('The provided model is a folder')
|
||||||
else:
|
else:
|
||||||
msgbox("The provided source model is neither a file nor a folder")
|
msgbox('The provided source model is neither a file nor a folder')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if source model exist
|
# Check if source model exist
|
||||||
if os.path.isdir(target_model_folder_input):
|
if os.path.isdir(target_model_folder_input):
|
||||||
print('The provided model folder exist')
|
print('The provided model folder exist')
|
||||||
else:
|
else:
|
||||||
msgbox("The provided target folder does not exist")
|
msgbox('The provided target folder does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = f'.\\venv\Scripts\python.exe "tools/convert_diffusers20_original_sd.py"'
|
run_cmd = f'.\\venv\Scripts\python.exe "tools/convert_diffusers20_original_sd.py"'
|
||||||
|
|
||||||
v1_models = [
|
v1_models = [
|
||||||
'runwayml/stable-diffusion-v1-5',
|
'runwayml/stable-diffusion-v1-5',
|
||||||
'CompVis/stable-diffusion-v1-4',
|
'CompVis/stable-diffusion-v1-4',
|
||||||
]
|
]
|
||||||
|
|
||||||
# check if v1 models
|
# check if v1 models
|
||||||
if str(source_model_type) in v1_models:
|
if str(source_model_type) in v1_models:
|
||||||
print('SD v1 model specified. Setting --v1 parameter')
|
print('SD v1 model specified. Setting --v1 parameter')
|
||||||
@ -46,54 +54,76 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
|
|||||||
else:
|
else:
|
||||||
print('SD v2 model specified. Setting --v2 parameter')
|
print('SD v2 model specified. Setting --v2 parameter')
|
||||||
run_cmd += ' --v2'
|
run_cmd += ' --v2'
|
||||||
|
|
||||||
if not target_save_precision_type == 'unspecified':
|
if not target_save_precision_type == 'unspecified':
|
||||||
run_cmd += f' --{target_save_precision_type}'
|
run_cmd += f' --{target_save_precision_type}'
|
||||||
|
|
||||||
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
if (
|
||||||
|
target_model_type == 'diffuser'
|
||||||
|
or target_model_type == 'diffuser_safetensors'
|
||||||
|
):
|
||||||
run_cmd += f' --reference_model="{source_model_type}"'
|
run_cmd += f' --reference_model="{source_model_type}"'
|
||||||
|
|
||||||
if target_model_type == 'diffuser_safetensors':
|
if target_model_type == 'diffuser_safetensors':
|
||||||
run_cmd += ' --use_safetensors'
|
run_cmd += ' --use_safetensors'
|
||||||
|
|
||||||
run_cmd += f' "{source_model_input}"'
|
run_cmd += f' "{source_model_input}"'
|
||||||
|
|
||||||
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
if (
|
||||||
target_model_path = os.path.join(target_model_folder_input, target_model_name_input)
|
target_model_type == 'diffuser'
|
||||||
|
or target_model_type == 'diffuser_safetensors'
|
||||||
|
):
|
||||||
|
target_model_path = os.path.join(
|
||||||
|
target_model_folder_input, target_model_name_input
|
||||||
|
)
|
||||||
run_cmd += f' "{target_model_path}"'
|
run_cmd += f' "{target_model_path}"'
|
||||||
else:
|
else:
|
||||||
target_model_path = os.path.join(target_model_folder_input, f'{target_model_name_input}.{target_model_type}')
|
target_model_path = os.path.join(
|
||||||
|
target_model_folder_input,
|
||||||
|
f'{target_model_name_input}.{target_model_type}',
|
||||||
|
)
|
||||||
run_cmd += f' "{target_model_path}"'
|
run_cmd += f' "{target_model_path}"'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
|
|
||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
if not target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
if (
|
||||||
|
not target_model_type == 'diffuser'
|
||||||
v2_models = ['stabilityai/stable-diffusion-2-1-base',
|
or target_model_type == 'diffuser_safetensors'
|
||||||
'stabilityai/stable-diffusion-2-base',]
|
):
|
||||||
v_parameterization =[
|
|
||||||
'stabilityai/stable-diffusion-2-1',
|
v2_models = [
|
||||||
'stabilityai/stable-diffusion-2',]
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
|
'stabilityai/stable-diffusion-2-base',
|
||||||
|
]
|
||||||
|
v_parameterization = [
|
||||||
|
'stabilityai/stable-diffusion-2-1',
|
||||||
|
'stabilityai/stable-diffusion-2',
|
||||||
|
]
|
||||||
|
|
||||||
if str(source_model_type) in v2_models:
|
if str(source_model_type) in v2_models:
|
||||||
inference_file = os.path.join(target_model_folder_input, f'{target_model_name_input}.yaml')
|
inference_file = os.path.join(
|
||||||
|
target_model_folder_input, f'{target_model_name_input}.yaml'
|
||||||
|
)
|
||||||
print(f'Saving v2-inference.yaml as {inference_file}')
|
print(f'Saving v2-inference.yaml as {inference_file}')
|
||||||
shutil.copy(
|
shutil.copy(
|
||||||
f'./v2_inference/v2-inference.yaml',
|
f'./v2_inference/v2-inference.yaml',
|
||||||
f'{inference_file}',
|
f'{inference_file}',
|
||||||
)
|
)
|
||||||
|
|
||||||
if str(source_model_type) in v_parameterization:
|
if str(source_model_type) in v_parameterization:
|
||||||
inference_file = os.path.join(target_model_folder_input, f'{target_model_name_input}.yaml')
|
inference_file = os.path.join(
|
||||||
|
target_model_folder_input, f'{target_model_name_input}.yaml'
|
||||||
|
)
|
||||||
print(f'Saving v2-inference-v.yaml as {inference_file}')
|
print(f'Saving v2-inference-v.yaml as {inference_file}')
|
||||||
shutil.copy(
|
shutil.copy(
|
||||||
f'./v2_inference/v2-inference-v.yaml',
|
f'./v2_inference/v2-inference-v.yaml',
|
||||||
f'{inference_file}',
|
f'{inference_file}',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# parser = argparse.ArgumentParser()
|
# parser = argparse.ArgumentParser()
|
||||||
# parser.add_argument("--v1", action='store_true',
|
# parser.add_argument("--v1", action='store_true',
|
||||||
# help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
# help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
||||||
@ -138,22 +168,27 @@ def gradio_convert_model_tab():
|
|||||||
button_source_model_dir.click(
|
button_source_model_dir.click(
|
||||||
get_folder_path, outputs=source_model_input
|
get_folder_path, outputs=source_model_input
|
||||||
)
|
)
|
||||||
|
|
||||||
button_source_model_file = gr.Button(
|
button_source_model_file = gr.Button(
|
||||||
document_symbol, elem_id='open_folder_small'
|
document_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_source_model_file.click(
|
button_source_model_file.click(
|
||||||
get_file_path, inputs=[source_model_input], outputs=source_model_input
|
get_file_path,
|
||||||
|
inputs=[source_model_input],
|
||||||
|
outputs=source_model_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
source_model_type = gr.Dropdown(label="Source model type", choices=[
|
source_model_type = gr.Dropdown(
|
||||||
|
label='Source model type',
|
||||||
|
choices=[
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
'stabilityai/stable-diffusion-2-base',
|
'stabilityai/stable-diffusion-2-base',
|
||||||
'stabilityai/stable-diffusion-2-1',
|
'stabilityai/stable-diffusion-2-1',
|
||||||
'stabilityai/stable-diffusion-2',
|
'stabilityai/stable-diffusion-2',
|
||||||
'runwayml/stable-diffusion-v1-5',
|
'runwayml/stable-diffusion-v1-5',
|
||||||
'CompVis/stable-diffusion-v1-4',
|
'CompVis/stable-diffusion-v1-4',
|
||||||
],)
|
],
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
target_model_folder_input = gr.Textbox(
|
target_model_folder_input = gr.Textbox(
|
||||||
label='Target model folder',
|
label='Target model folder',
|
||||||
@ -166,30 +201,37 @@ def gradio_convert_model_tab():
|
|||||||
button_target_model_folder.click(
|
button_target_model_folder.click(
|
||||||
get_folder_path, outputs=target_model_folder_input
|
get_folder_path, outputs=target_model_folder_input
|
||||||
)
|
)
|
||||||
|
|
||||||
target_model_name_input = gr.Textbox(
|
target_model_name_input = gr.Textbox(
|
||||||
label='Target model name',
|
label='Target model name',
|
||||||
placeholder='target model name...',
|
placeholder='target model name...',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
target_model_type = gr.Dropdown(label="Target model type", choices=[
|
target_model_type = gr.Dropdown(
|
||||||
|
label='Target model type',
|
||||||
|
choices=[
|
||||||
'diffuser',
|
'diffuser',
|
||||||
'diffuser_safetensors',
|
'diffuser_safetensors',
|
||||||
'ckpt',
|
'ckpt',
|
||||||
'safetensors',
|
'safetensors',
|
||||||
],)
|
],
|
||||||
target_save_precision_type = gr.Dropdown(label="Target model precison", choices=[
|
)
|
||||||
'unspecified',
|
target_save_precision_type = gr.Dropdown(
|
||||||
'fp16',
|
label='Target model precison',
|
||||||
'bf16',
|
choices=['unspecified', 'fp16', 'bf16', 'float'],
|
||||||
'float'
|
value='unspecified',
|
||||||
], value='unspecified')
|
)
|
||||||
|
|
||||||
|
|
||||||
convert_button = gr.Button('Convert model')
|
convert_button = gr.Button('Convert model')
|
||||||
|
|
||||||
convert_button.click(
|
convert_button.click(
|
||||||
convert_model,
|
convert_model,
|
||||||
inputs=[source_model_input, source_model_type, target_model_folder_input, target_model_name_input, target_model_type, target_save_precision_type
|
inputs=[
|
||||||
|
source_model_input,
|
||||||
|
source_model_type,
|
||||||
|
target_model_folder_input,
|
||||||
|
target_model_name_input,
|
||||||
|
target_model_type,
|
||||||
|
target_save_precision_type,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -13,7 +13,7 @@ from .common_gui import get_folder_path
|
|||||||
|
|
||||||
|
|
||||||
def dataset_balancing(concept_repeats, folder, insecure):
|
def dataset_balancing(concept_repeats, folder, insecure):
|
||||||
|
|
||||||
if not concept_repeats > 0:
|
if not concept_repeats > 0:
|
||||||
# Display an error message if the total number of repeats is not a valid integer
|
# Display an error message if the total number of repeats is not a valid integer
|
||||||
msgbox('Please enter a valid integer for the total number of repeats.')
|
msgbox('Please enter a valid integer for the total number of repeats.')
|
||||||
@ -72,23 +72,35 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||||||
|
|
||||||
os.rename(old_name, new_name)
|
os.rename(old_name, new_name)
|
||||||
else:
|
else:
|
||||||
print(f"Skipping folder {subdir} because it does not match kohya_ss expected syntax...")
|
print(
|
||||||
|
f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
|
||||||
|
)
|
||||||
|
|
||||||
msgbox('Dataset balancing completed...')
|
msgbox('Dataset balancing completed...')
|
||||||
|
|
||||||
|
|
||||||
def warning(insecure):
|
def warning(insecure):
|
||||||
if insecure:
|
if insecure:
|
||||||
if boolbox(f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?', choices=("Yes, I like danger", "No, get me out of here")):
|
if boolbox(
|
||||||
|
f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?',
|
||||||
|
choices=('Yes, I like danger', 'No, get me out of here'),
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def gradio_dataset_balancing_tab():
|
def gradio_dataset_balancing_tab():
|
||||||
with gr.Tab('Dataset balancing'):
|
with gr.Tab('Dataset balancing'):
|
||||||
gr.Markdown('This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.')
|
gr.Markdown(
|
||||||
gr.Markdown('WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!')
|
'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.'
|
||||||
|
)
|
||||||
|
gr.Markdown(
|
||||||
|
'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!'
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
select_dataset_folder_input = gr.Textbox(label="Dataset folder",
|
select_dataset_folder_input = gr.Textbox(
|
||||||
|
label='Dataset folder',
|
||||||
placeholder='Folder containing the concepts folders to balance...',
|
placeholder='Folder containing the concepts folders to balance...',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
@ -106,10 +118,17 @@ def gradio_dataset_balancing_tab():
|
|||||||
label='Training steps per concept per epoch',
|
label='Training steps per concept per epoch',
|
||||||
)
|
)
|
||||||
with gr.Accordion('Advanced options', open=False):
|
with gr.Accordion('Advanced options', open=False):
|
||||||
insecure = gr.Checkbox(value=False, label="DANGER!!! -- Insecure folder renaming -- DANGER!!!")
|
insecure = gr.Checkbox(
|
||||||
|
value=False,
|
||||||
|
label='DANGER!!! -- Insecure folder renaming -- DANGER!!!',
|
||||||
|
)
|
||||||
insecure.change(warning, inputs=insecure, outputs=insecure)
|
insecure.change(warning, inputs=insecure, outputs=insecure)
|
||||||
balance_button = gr.Button('Balance dataset')
|
balance_button = gr.Button('Balance dataset')
|
||||||
balance_button.click(
|
balance_button.click(
|
||||||
dataset_balancing,
|
dataset_balancing,
|
||||||
inputs=[total_repeats_number, select_dataset_folder_input, insecure],
|
inputs=[
|
||||||
|
total_repeats_number,
|
||||||
|
select_dataset_folder_input,
|
||||||
|
insecure,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user