Add support for:

- shuffle_caption
- save_state
- resume
- prior_loss_weight

Fix issue with config open and save
bmaltais 2022-12-19 21:50:05 -05:00
parent 1d412726b3
commit 1f1dd5c4de
7 changed files with 1594 additions and 1073 deletions
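
Note: the config open/save fix and the four new fields all pass through the same JSON round-trip shown in the dreambooth_gui.py hunks below. The following is a minimal, self-contained sketch of that pattern only; the file name 'finetune.json' and the helper names are illustrative, not the GUI's own functions.

import json

def save_config(file_path, **params):
    # params carries the new keys added in this commit:
    # shuffle_caption, save_state, resume, prior_loss_weight
    with open(file_path, 'w') as f:
        json.dump(params, f, indent=2)

def open_config(file_path, **current_values):
    with open(file_path, 'r') as f:
        my_data = json.load(f)
    # Older config files may not contain the new keys yet; fall back to the
    # value currently shown in the UI, mirroring my_data.get(key, default).
    return {k: my_data.get(k, v) for k, v in current_values.items()}

save_config('finetune.json', shuffle_caption=True, save_state=False,
            resume='', prior_loss_weight=1.0)
print(open_config('finetune.json', shuffle_caption=False, save_state=False,
                  resume='', prior_loss_weight=1.0))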

View File

@ -10,7 +10,9 @@ import os
import subprocess
import pathlib
import shutil
from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.basic_caption_gui import gradio_basic_caption_gui_tab
from library.convert_model_gui import gradio_convert_model_tab
from library.blip_caption_gui import gradio_blip_caption_gui_tab
@ -20,7 +22,7 @@ from library.common_gui import (
get_folder_path,
remove_doublequote,
get_file_path,
get_saveasfile_path
get_saveasfile_path,
)
from easygui import msgbox
@ -60,7 +62,11 @@ def save_configuration(
stop_text_encoder_training,
use_8bit_adam,
xformers,
save_model_as
save_model_as,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
):
original_file_path = file_path
@ -68,22 +74,14 @@ def save_configuration(
if save_as_bool:
print('Save as...')
# file_path = filesavebox(
# 'Select the config file to save',
# default='finetune.json',
# filetypes='*.json',
# )
file_path = get_saveasfile_path(file_path)
else:
print('Save...')
if file_path == None or file_path == '':
# file_path = filesavebox(
# 'Select the config file to save',
# default='finetune.json',
# filetypes='*.json',
# )
file_path = get_saveasfile_path(file_path)
# print(file_path)
if file_path == None or file_path == '':
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
@ -116,7 +114,11 @@ def save_configuration(
'stop_text_encoder_training': stop_text_encoder_training,
'use_8bit_adam': use_8bit_adam,
'xformers': xformers,
'save_model_as': save_model_as
'save_model_as': save_model_as,
'shuffle_caption': shuffle_caption,
'save_state': save_state,
'resume': resume,
'prior_loss_weight': prior_loss_weight,
}
# Save the data to the selected file
@ -155,14 +157,18 @@ def open_configuration(
stop_text_encoder_training,
use_8bit_adam,
xformers,
save_model_as
save_model_as,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
):
original_file_path = file_path
file_path = get_file_path(file_path)
# print(file_path)
if file_path != '' and file_path != None:
print(file_path)
if not file_path == '' and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
my_data = json.load(f)
@ -204,7 +210,11 @@ def open_configuration(
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
my_data.get('use_8bit_adam', use_8bit_adam),
my_data.get('xformers', xformers),
my_data.get('save_model_as', save_model_as)
my_data.get('save_model_as', save_model_as),
my_data.get('shuffle_caption', shuffle_caption),
my_data.get('save_state', save_state),
my_data.get('resume', resume),
my_data.get('prior_loss_weight', prior_loss_weight),
)
@ -236,7 +246,11 @@ def train_model(
stop_text_encoder_training_pct,
use_8bit_adam,
xformers,
save_model_as
save_model_as,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
@ -360,6 +374,10 @@ def train_model(
run_cmd += ' --use_8bit_adam'
if xformers:
run_cmd += ' --xformers'
if shuffle_caption:
run_cmd += ' --shuffle_caption'
if save_state:
run_cmd += ' --save_state'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
)
@ -382,9 +400,15 @@ def train_model(
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --caption_extention={caption_extention}'
if not stop_text_encoder_training == 0:
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
run_cmd += (
f' --stop_text_encoder_training={stop_text_encoder_training}'
)
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
if not resume == '':
run_cmd += f' --resume={resume}'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
print(run_cmd)
# Run the command
@ -472,8 +496,8 @@ with interface:
)
config_file_name = gr.Textbox(
label='',
# placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=False
placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=True,
)
# config_file_name.change(
# remove_doublequote,
@ -491,13 +515,16 @@ with interface:
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_fille.click(
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
get_file_path,
inputs=[pretrained_model_name_or_path_input],
outputs=pretrained_model_name_or_path_input,
)
pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_folder.click(
get_folder_path, outputs=pretrained_model_name_or_path_input
get_folder_path,
outputs=pretrained_model_name_or_path_input,
)
model_list = gr.Dropdown(
label='(Optional) Model Quick Pick',
@ -517,10 +544,10 @@ with interface:
'same as source model',
'ckpt',
'diffusers',
"diffusers_safetensors",
'diffusers_safetensors',
'safetensors',
],
value='same as source model'
value='same as source model',
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
@ -607,7 +634,9 @@ with interface:
)
with gr.Tab('Training parameters'):
with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
learning_rate_input = gr.Textbox(
label='Learning rate', value=1e-6
)
lr_scheduler_input = gr.Dropdown(
label='LR Scheduler',
choices=[
@ -662,7 +691,9 @@ with interface:
with gr.Row():
seed_input = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox(
label='Max resolution', value='512,512', placeholder='512,512'
label='Max resolution',
value='512,512',
placeholder='512,512',
)
with gr.Row():
caption_extention_input = gr.Textbox(
@ -676,6 +707,18 @@ with interface:
step=1,
label='Stop text encoder training',
)
with gr.Row():
enable_bucket_input = gr.Checkbox(
label='Enable buckets', value=True
)
cache_latent_input = gr.Checkbox(
label='Cache latent', value=True
)
use_8bit_adam_input = gr.Checkbox(
label='Use 8bit adam', value=True
)
xformers_input = gr.Checkbox(label='Use xformers', value=True)
with gr.Accordion('Advanced Configuration', open=False):
with gr.Row():
full_fp16_input = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
@ -687,15 +730,21 @@ with interface:
gradient_checkpointing_input = gr.Checkbox(
label='Gradient checkpointing', value=False
)
shuffle_caption = gr.Checkbox(
label='Shuffle caption', value=False
)
save_state = gr.Checkbox(label='Save state', value=False)
with gr.Row():
enable_bucket_input = gr.Checkbox(
label='Enable buckets', value=True
resume = gr.Textbox(
label='Resume',
placeholder='path to "last-state" state folder to resume from',
)
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam_input = gr.Checkbox(
label='Use 8bit adam', value=True
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
prior_loss_weight = gr.Number(
label='Prior loss weight', value=1.0
)
xformers_input = gr.Checkbox(label='Use xformers', value=True)
button_run = gr.Button('Train model')
@ -713,8 +762,6 @@ with interface:
gradio_dataset_balancing_tab()
gradio_convert_model_tab()
button_open_config.click(
open_configuration,
inputs=[
@ -746,7 +793,11 @@ with interface:
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[
config_file_name,
@ -777,7 +828,11 @@ with interface:
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
)
@ -815,7 +870,11 @@ with interface:
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[config_file_name],
)
@ -852,7 +911,11 @@ with interface:
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[config_file_name],
)
@ -887,7 +950,11 @@ with interface:
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
)
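
Note: taken together, the train_model hunks above only append the new options to the training command when they deviate from their defaults. A hedged, stand-alone sketch of that pattern follows; the base command and example values are placeholders, not the GUI's actual launch line.

def append_new_options(run_cmd, shuffle_caption, save_state, resume, prior_loss_weight):
    # Boolean checkboxes become bare flags.
    if shuffle_caption:
        run_cmd += ' --shuffle_caption'
    if save_state:
        run_cmd += ' --save_state'
    # Resume is only passed when a state folder was provided.
    if resume != '':
        run_cmd += f' --resume={resume}'
    # prior_loss_weight defaults to 1.0, so only non-default values are passed.
    if float(prior_loss_weight) != 1.0:
        run_cmd += f' --prior_loss_weight={prior_loss_weight}'
    return run_cmd

print(append_new_options('python train_db.py', True, False, '', 0.5))
# -> python train_db.py --shuffle_caption --prior_loss_weight=0.5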

View File

@ -5,7 +5,12 @@ from .common_gui import get_folder_path, add_pre_postfix
def caption_images(
caption_text_input, images_dir_input, overwrite_input, caption_file_ext, prefix, postfix
caption_text_input,
images_dir_input,
overwrite_input,
caption_file_ext,
prefix,
postfix,
):
# Check for images_dir_input
if images_dir_input == '':
@ -31,10 +36,17 @@ def caption_images(
if overwrite_input:
# Add prefix and postfix
add_pre_postfix(folder=images_dir_input, caption_file_ext=caption_file_ext, prefix=prefix, postfix=postfix)
add_pre_postfix(
folder=images_dir_input,
caption_file_ext=caption_file_ext,
prefix=prefix,
postfix=postfix,
)
else:
if not prefix == '' or not postfix == '':
msgbox('Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...')
msgbox(
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
)
print('...captioning done')
@ -97,6 +109,7 @@ def gradio_basic_caption_gui_tab():
images_dir_input,
overwrite_input,
caption_file_ext,
prefix, postfix
prefix,
postfix,
],
)

View File

@ -4,6 +4,7 @@ import subprocess
import os
from .common_gui import get_folder_path, add_pre_postfix
def caption_images(
train_data_dir,
caption_file_ext,
@ -14,7 +15,7 @@ def caption_images(
min_length,
beam_search,
prefix,
postfix
postfix,
):
# Check for caption_text_input
# if caption_text_input == "":
@ -46,7 +47,12 @@ def caption_images(
subprocess.run(run_cmd)
# Add prefix and postfix
add_pre_postfix(folder=train_data_dir, caption_file_ext=caption_file_ext, prefix=prefix, postfix=postfix)
add_pre_postfix(
folder=train_data_dir,
caption_file_ext=caption_file_ext,
prefix=prefix,
postfix=postfix,
)
print('...captioning done')
@ -125,6 +131,6 @@ def gradio_blip_caption_gui_tab():
min_length,
beam_search,
prefix,
postfix
postfix,
],
)

View File

@ -1,6 +1,7 @@
from tkinter import filedialog, Tk
import os
def get_file_path(file_path='', defaultextension='.json'):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
@ -8,7 +9,10 @@ def get_file_path(file_path='', defaultextension='.json'):
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.askopenfilename(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
file_path = filedialog.askopenfilename(
filetypes=(('Config files', '*.json'), ('All files', '*')),
defaultextension=defaultextension,
)
root.destroy()
if file_path == '':
@ -38,6 +42,7 @@ def get_folder_path(folder_path=''):
return folder_path
def get_saveasfile_path(file_path='', defaultextension='.json'):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
@ -45,21 +50,28 @@ def get_saveasfile_path(file_path='', defaultextension='.json'):
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
save_file_path = filedialog.asksaveasfile(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
save_file_path = filedialog.asksaveasfile(
filetypes=(('Config files', '*.json'), ('All files', '*')),
defaultextension=defaultextension,
)
root.destroy()
# file_path = file_path.name
if file_path == '':
# print(save_file_path)
if save_file_path == None:
file_path = current_file_path
else:
print(save_file_path.name)
file_path = save_file_path.name
print(file_path)
# print(file_path)
return file_path
def add_pre_postfix(folder='', prefix='', postfix='', caption_file_ext='.caption'):
def add_pre_postfix(
folder='', prefix='', postfix='', caption_file_ext='.caption'
):
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
if not prefix == '':
prefix = f'{prefix} '
@ -70,6 +82,6 @@ def add_pre_postfix(folder='', prefix='', postfix='', caption_file_ext='.caption
with open(os.path.join(folder, file), 'r+') as f:
content = f.read()
content = content.rstrip()
f.seek(0,0)
f.seek(0, 0)
f.write(f'{prefix}{content}{postfix}')
f.close()
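
Note: for reference, a hedged usage sketch of the add_pre_postfix helper reformatted above: it rewrites each caption file in place, keeping the existing text and writing the prefix around it. The temporary folder, file name, and caption text below are made up for illustration.

import os
import tempfile

folder = tempfile.mkdtemp()
with open(os.path.join(folder, 'img1.caption'), 'w') as f:
    f.write('a photo of a dog\n')

prefix, postfix = 'sks', ''
for name in [f for f in os.listdir(folder) if f.endswith('.caption')]:
    with open(os.path.join(folder, name), 'r+') as f:
        content = f.read().rstrip()
        f.seek(0, 0)
        # Safe here because the rewritten caption is longer than the original.
        f.write(f'{prefix} {content}{postfix}')

with open(os.path.join(folder, 'img1.caption')) as f:
    print(f.read())  # -> sks a photo of a dog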

View File

@ -10,10 +10,18 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def convert_model(source_model_input, source_model_type, target_model_folder_input, target_model_name_input, target_model_type, target_save_precision_type):
def convert_model(
source_model_input,
source_model_type,
target_model_folder_input,
target_model_name_input,
target_model_type,
target_save_precision_type,
):
# Check for caption_text_input
if source_model_type == "":
msgbox("Invalid source model type")
if source_model_type == '':
msgbox('Invalid source model type')
return
# Check if source model exist
@ -22,14 +30,14 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
elif os.path.isdir(source_model_input):
print('The provided model is a folder')
else:
msgbox("The provided source model is neither a file nor a folder")
msgbox('The provided source model is neither a file nor a folder')
return
# Check if source model exist
if os.path.isdir(target_model_folder_input):
print('The provided model folder exist')
else:
msgbox("The provided target folder does not exist")
msgbox('The provided target folder does not exist')
return
run_cmd = f'.\\venv\Scripts\python.exe "tools/convert_diffusers20_original_sd.py"'
@ -50,7 +58,10 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
if not target_save_precision_type == 'unspecified':
run_cmd += f' --{target_save_precision_type}'
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
if (
target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors'
):
run_cmd += f' --reference_model="{source_model_type}"'
if target_model_type == 'diffuser_safetensors':
@ -58,11 +69,19 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
run_cmd += f' "{source_model_input}"'
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
target_model_path = os.path.join(target_model_folder_input, target_model_name_input)
if (
target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors'
):
target_model_path = os.path.join(
target_model_folder_input, target_model_name_input
)
run_cmd += f' "{target_model_path}"'
else:
target_model_path = os.path.join(target_model_folder_input, f'{target_model_name_input}.{target_model_type}')
target_model_path = os.path.join(
target_model_folder_input,
f'{target_model_name_input}.{target_model_type}',
)
run_cmd += f' "{target_model_path}"'
print(run_cmd)
@ -70,16 +89,24 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
# Run the command
subprocess.run(run_cmd)
if not target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
if (
not target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors'
):
v2_models = ['stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',]
v_parameterization =[
v2_models = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
v_parameterization = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',]
'stabilityai/stable-diffusion-2',
]
if str(source_model_type) in v2_models:
inference_file = os.path.join(target_model_folder_input, f'{target_model_name_input}.yaml')
inference_file = os.path.join(
target_model_folder_input, f'{target_model_name_input}.yaml'
)
print(f'Saving v2-inference.yaml as {inference_file}')
shutil.copy(
f'./v2_inference/v2-inference.yaml',
@ -87,13 +114,16 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
)
if str(source_model_type) in v_parameterization:
inference_file = os.path.join(target_model_folder_input, f'{target_model_name_input}.yaml')
inference_file = os.path.join(
target_model_folder_input, f'{target_model_name_input}.yaml'
)
print(f'Saving v2-inference-v.yaml as {inference_file}')
shutil.copy(
f'./v2_inference/v2-inference-v.yaml',
f'{inference_file}',
)
# parser = argparse.ArgumentParser()
# parser.add_argument("--v1", action='store_true',
# help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
@ -143,17 +173,22 @@ def gradio_convert_model_tab():
document_symbol, elem_id='open_folder_small'
)
button_source_model_file.click(
get_file_path, inputs=[source_model_input], outputs=source_model_input
get_file_path,
inputs=[source_model_input],
outputs=source_model_input,
)
source_model_type = gr.Dropdown(label="Source model type", choices=[
source_model_type = gr.Dropdown(
label='Source model type',
choices=[
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
],)
],
)
with gr.Row():
target_model_folder_input = gr.Textbox(
label='Target model folder',
@ -172,24 +207,31 @@ def gradio_convert_model_tab():
placeholder='target model name...',
interactive=True,
)
target_model_type = gr.Dropdown(label="Target model type", choices=[
target_model_type = gr.Dropdown(
label='Target model type',
choices=[
'diffuser',
'diffuser_safetensors',
'ckpt',
'safetensors',
],)
target_save_precision_type = gr.Dropdown(label="Target model precison", choices=[
'unspecified',
'fp16',
'bf16',
'float'
], value='unspecified')
],
)
target_save_precision_type = gr.Dropdown(
label='Target model precison',
choices=['unspecified', 'fp16', 'bf16', 'float'],
value='unspecified',
)
convert_button = gr.Button('Convert model')
convert_button.click(
convert_model,
inputs=[source_model_input, source_model_type, target_model_folder_input, target_model_name_input, target_model_type, target_save_precision_type
inputs=[
source_model_input,
source_model_type,
target_model_folder_input,
target_model_name_input,
target_model_type,
target_save_precision_type,
],
)

View File

@ -72,23 +72,35 @@ def dataset_balancing(concept_repeats, folder, insecure):
os.rename(old_name, new_name)
else:
print(f"Skipping folder {subdir} because it does not match kohya_ss expected syntax...")
print(
f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
)
msgbox('Dataset balancing completed...')
def warning(insecure):
if insecure:
if boolbox(f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?', choices=("Yes, I like danger", "No, get me out of here")):
if boolbox(
f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?',
choices=('Yes, I like danger', 'No, get me out of here'),
):
return True
else:
return False
def gradio_dataset_balancing_tab():
with gr.Tab('Dataset balancing'):
gr.Markdown('This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.')
gr.Markdown('WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!')
gr.Markdown(
'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.'
)
gr.Markdown(
'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!'
)
with gr.Row():
select_dataset_folder_input = gr.Textbox(label="Dataset folder",
select_dataset_folder_input = gr.Textbox(
label='Dataset folder',
placeholder='Folder containing the concepts folders to balance...',
interactive=True,
)
@ -106,10 +118,17 @@ def gradio_dataset_balancing_tab():
label='Training steps per concept per epoch',
)
with gr.Accordion('Advanced options', open=False):
insecure = gr.Checkbox(value=False, label="DANGER!!! -- Insecure folder renaming -- DANGER!!!")
insecure = gr.Checkbox(
value=False,
label='DANGER!!! -- Insecure folder renaming -- DANGER!!!',
)
insecure.change(warning, inputs=insecure, outputs=insecure)
balance_button = gr.Button('Balance dataset')
balance_button.click(
dataset_balancing,
inputs=[total_repeats_number, select_dataset_folder_input, insecure],
inputs=[
total_repeats_number,
select_dataset_folder_input,
insecure,
],
)

File diff suppressed because it is too large