Move save_state and resume to common gui

Format code
This commit is contained in:
bmaltais 2023-01-15 12:01:59 -05:00
parent 6aed2bb402
commit abccecb093
4 changed files with 179 additions and 172 deletions

View File

@ -334,8 +334,8 @@ def train_model(
run_cmd += ' --xformers'
if shuffle_caption:
run_cmd += ' --shuffle_caption'
if save_state:
run_cmd += ' --save_state'
# if save_state:
# run_cmd += ' --save_state'
if color_aug:
run_cmd += ' --color_aug'
if flip_aug:
@ -368,8 +368,8 @@ def train_model(
)
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
if not resume == '':
run_cmd += f' --resume={resume}'
# if not resume == '':
# run_cmd += f' --resume={resume}'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
if int(clip_skip) > 1:
@ -384,7 +384,13 @@ def train_model(
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
if not max_data_loader_n_workers == '':
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
run_cmd += run_cmd_advanced_training(
max_train_epochs=max_train_epochs,
max_data_loader_n_workers=max_data_loader_n_workers,
max_token_length=max_token_length,
resume=resume,
save_state=save_state,
)
print(run_cmd)
# Run the command
@ -681,9 +687,6 @@ def dreambooth_tab(
label='Shuffle caption', value=False
)
with gr.Row():
save_state = gr.Checkbox(
label='Save training state', value=False
)
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
@ -697,12 +700,6 @@ def dreambooth_tab(
label='Clip skip', value='1', minimum=1, maximum=12, step=1
)
with gr.Row():
resume = gr.Textbox(
label='Resume from saved training state',
placeholder='path to "last-state" state folder to resume from',
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
prior_loss_weight = gr.Number(
label='Prior loss weight', value=1.0
)
@ -712,25 +709,7 @@ def dreambooth_tab(
)
vae_button = gr.Button('📂', elem_id='open_folder_small')
vae_button.click(get_any_file_path, outputs=vae)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
# with gr.Row():
# max_train_epochs = gr.Textbox(
# label='Max train epoch',
# placeholder='(Optional) Override number of epoch',
# )
# max_data_loader_n_workers = gr.Textbox(
# label='Max num workers for DataLoader',
# placeholder='(Optional) Override number of epoch. Default: 8',
# )
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
with gr.Tab('Tools'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'

View File

@ -335,15 +335,21 @@ def train_model(
run_cmd += f' --clip_skip={str(clip_skip)}'
if int(gradient_accumulation_steps) > 1:
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
if save_state:
run_cmd += ' --save_state'
if not resume == '':
run_cmd += f' --resume={resume}'
# if save_state:
# run_cmd += ' --save_state'
# if not resume == '':
# run_cmd += f' --resume={resume}'
if not output_name == '':
run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
run_cmd += run_cmd_advanced_training(
max_train_epochs=max_train_epochs,
max_data_loader_n_workers=max_data_loader_n_workers,
max_token_length=max_token_length,
resume=resume,
save_state=save_state,
)
print(run_cmd)
# Run the command
@ -640,31 +646,13 @@ def finetune_tab():
label='Shuffle caption', value=False
)
with gr.Row():
save_state = gr.Checkbox(
label='Save training state', value=False
)
resume = gr.Textbox(
label='Resume from saved training state',
placeholder='path to "last-state" state folder to resume from',
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False
)
gradient_accumulation_steps = gr.Number(
label='Gradient accumulate steps', value='1'
)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
with gr.Box():
with gr.Row():
create_caption = gr.Checkbox(

View File

@ -4,10 +4,12 @@ import gradio as gr
from easygui import msgbox
import shutil
def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name)
def has_ext_files(directory, extension):
# Iterate through all the files in the directory
for file in os.listdir(directory):
@ -17,7 +19,10 @@ def has_ext_files(directory, extension):
# If no extension files were found, return False
return False
def get_file_path(file_path='', defaultextension='.json', extension_name='Config files'):
def get_file_path(
file_path='', defaultextension='.json', extension_name='Config files'
):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
@ -27,8 +32,13 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.askopenfilename(
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
defaultextension=defaultextension, initialfile=initial_file, initialdir=initial_dir
filetypes=(
(f'{extension_name}', f'{defaultextension}'),
('All files', '*'),
),
defaultextension=defaultextension,
initialfile=initial_file,
initialdir=initial_dir,
)
root.destroy()
@ -37,6 +47,7 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config
return file_path
def get_any_file_path(file_path=''):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
@ -46,8 +57,10 @@ def get_any_file_path(file_path=''):
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.askopenfilename(initialdir=initial_dir,
initialfile=initial_file,)
file_path = filedialog.askopenfilename(
initialdir=initial_dir,
initialfile=initial_file,
)
root.destroy()
if file_path == '':
@ -80,7 +93,9 @@ def get_folder_path(folder_path=''):
return folder_path
def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='Config files'):
def get_saveasfile_path(
file_path='', defaultextension='.json', extension_name='Config files'
):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
@ -90,7 +105,10 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='
root.wm_attributes('-topmost', 1)
root.withdraw()
save_file_path = filedialog.asksaveasfile(
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
filetypes=(
(f'{extension_name}', f'{defaultextension}'),
('All files', '*'),
),
defaultextension=defaultextension,
initialdir=initial_dir,
initialfile=initial_file,
@ -109,7 +127,10 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='
return file_path
def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config files'):
def get_saveasfilename_path(
file_path='', extensions='*', extension_name='Config files'
):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
@ -118,7 +139,8 @@ def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
save_file_path = filedialog.asksaveasfilename(filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
save_file_path = filedialog.asksaveasfilename(
filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
defaultextension=extensions,
initialdir=initial_dir,
initialfile=initial_file,
@ -138,7 +160,9 @@ def add_pre_postfix(
folder='', prefix='', postfix='', caption_file_ext='.caption'
):
if not has_ext_files(folder, caption_file_ext):
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...')
msgbox(
f'No files with extension {caption_file_ext} were found in {folder}...'
)
return
if prefix == '' and postfix == '':
@ -158,12 +182,13 @@ def add_pre_postfix(
f.write(f'{prefix}{content}{postfix}')
f.close()
def find_replace(
folder='', caption_file_ext='.caption', find='', replace=''
):
def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
print('Running caption find/replace')
if not has_ext_files(folder, caption_file_ext):
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...')
msgbox(
f'No files with extension {caption_file_ext} were found in {folder}...'
)
return
if find == '':
@ -179,13 +204,17 @@ def find_replace(
f.write(content)
f.close()
def color_aug_changed(color_aug):
if color_aug:
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
msgbox(
'Disabling "Cache latent" because "Color augmentation" has been selected...'
)
return gr.Checkbox.update(value=False, interactive=False)
else:
return gr.Checkbox.update(value=True, interactive=True)
def save_inference_file(output_dir, v2, v_parameterization, output_name):
# List all files in the directory
files = os.listdir(output_dir)
@ -201,18 +230,23 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
# Copy the v2-inference-v.yaml file to the current file, with a .yaml extension
if v2 and v_parameterization:
print(f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml')
print(
f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml'
)
shutil.copy(
f'./v2_inference/v2-inference-v.yaml',
f'{output_dir}/{file_name}.yaml',
)
elif v2:
print(f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml')
print(
f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml'
)
shutil.copy(
f'./v2_inference/v2-inference.yaml',
f'{output_dir}/{file_name}.yaml',
)
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to search for
substrings_v2 = [
@ -267,7 +301,25 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
### Gradio common GUI section
###
def gradio_advanced_training():
with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False)
resume = gr.Textbox(
label='Resume from saved training state',
placeholder='path to "last-state" state folder to resume from',
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
with gr.Row():
max_train_epochs = gr.Textbox(
label='Max train epoch',
@ -277,15 +329,30 @@ def gradio_advanced_training():
label='Max num workers for DataLoader',
placeholder='(Optional) Override number of epoch. Default: 8',
)
return max_train_epochs, max_data_loader_n_workers
return (
save_state,
resume,
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
)
def run_cmd_advanced_training(**kwargs):
run_cmd = ''
max_train_epochs = kwargs.get('max_train_epochs', '')
max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '')
if not max_train_epochs == '':
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
if not max_data_loader_n_workers == '':
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
options = [
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
if kwargs.get('max_train_epochs')
else '',
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
if kwargs.get('max_data_loader_n_workers')
else '',
f' --max_token_length={kwargs.get("max_token_length", "")}'
if int(kwargs.get('max_token_length', 0)) > 75
else '',
f' --resume="{kwargs.get("resume", "")}"'
if kwargs.get('resume')
else '',
' --save_state' if kwargs.get('save_state') else '',
]
run_cmd = ''.join(options)
return run_cmd

View File

@ -19,7 +19,9 @@ from library.common_gui import (
get_saveasfile_path,
color_aug_changed,
save_inference_file,
set_pretrained_model_name_or_path_input, gradio_advanced_training,run_cmd_advanced_training,
set_pretrained_model_name_or_path_input,
gradio_advanced_training,
run_cmd_advanced_training,
)
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
@ -180,7 +182,7 @@ def open_configuration(
# load variables from JSON file
with open(file_path, 'r') as f:
my_data = json.load(f)
print("Loading config...")
print('Loading config...')
else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
my_data = {}
@ -350,8 +352,8 @@ def train_model(
run_cmd += ' --xformers'
if shuffle_caption:
run_cmd += ' --shuffle_caption'
if save_state:
run_cmd += ' --save_state'
# if save_state:
# run_cmd += ' --save_state'
if color_aug:
run_cmd += ' --color_aug'
if flip_aug:
@ -386,8 +388,8 @@ def train_model(
)
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
if not resume == '':
run_cmd += f' --resume="{resume}"'
# if not resume == '':
# run_cmd += f' --resume="{resume}"'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora'
@ -414,9 +416,15 @@ def train_model(
# run_cmd += f' --vae="{vae}"'
if not output_name == '':
run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
# if (int(max_token_length) > 75):
# run_cmd += f' --max_token_length={max_token_length}'
run_cmd += run_cmd_advanced_training(
max_train_epochs=max_train_epochs,
max_data_loader_n_workers=max_data_loader_n_workers,
max_token_length=max_token_length,
resume=resume,
save_state=save_state,
)
print(run_cmd)
# Run the command
@ -564,9 +572,7 @@ def lora_tab(
label='Image folder',
placeholder='Folder where the training folders containing the images are located',
)
train_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
train_data_dir_folder.click(
get_folder_path, outputs=train_data_dir
)
@ -574,33 +580,21 @@ def lora_tab(
label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
)
reg_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
reg_data_dir_folder.click(
get_folder_path, outputs=reg_data_dir
)
reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
reg_data_dir_folder.click(get_folder_path, outputs=reg_data_dir)
with gr.Row():
output_dir = gr.Textbox(
label='Output folder',
placeholder='Folder to output trained model',
)
output_dir_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
output_dir_folder.click(
get_folder_path, outputs=output_dir
)
output_dir_folder = gr.Button('📂', elem_id='open_folder_small')
output_dir_folder.click(get_folder_path, outputs=output_dir)
logging_dir = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
logging_dir_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
logging_dir_folder.click(
get_folder_path, outputs=logging_dir
)
logging_dir_folder = gr.Button('📂', elem_id='open_folder_small')
logging_dir_folder.click(get_folder_path, outputs=logging_dir)
with gr.Row():
output_name = gr.Textbox(
label='Model output name',
@ -659,11 +653,13 @@ def lora_tab(
with gr.Row():
text_encoder_lr = gr.Textbox(
label='Text Encoder learning rate',
value="5e-5",
value='5e-5',
placeholder='Optional',
)
unet_lr = gr.Textbox(
label='Unet learning rate', value="1e-3", placeholder='Optional'
label='Unet learning rate',
value='1e-3',
placeholder='Optional',
)
network_dim = gr.Slider(
minimum=1,
@ -731,13 +727,9 @@ def lora_tab(
label='Stop text encoder training',
)
with gr.Row():
enable_bucket = gr.Checkbox(
label='Enable buckets', value=True
)
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
cache_latent = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam = gr.Checkbox(
label='Use 8bit adam', value=True
)
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Accordion('Advanced Configuration', open=False):
with gr.Row():
@ -777,32 +769,13 @@ def lora_tab(
mem_eff_attn = gr.Checkbox(
label='Memory efficient attention', value=False
)
with gr.Row():
save_state = gr.Checkbox(
label='Save training state', value=False
)
resume = gr.Textbox(
label='Resume from saved training state',
placeholder='path to "last-state" state folder to resume from',
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
# vae = gr.Textbox(
# label='VAE',
# placeholder='(Optiona) path to checkpoint of vae to replace for training',
# )
# vae_button = gr.Button('📂', elem_id='open_folder_small')
# vae_button.click(get_any_file_path, outputs=vae)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
(
save_state,
resume,
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
) = gradio_advanced_training()
with gr.Tab('Tools'):
gr.Markdown(