Move save_state and resume to common gui

Format code
This commit is contained in:
bmaltais 2023-01-15 12:01:59 -05:00
parent 6aed2bb402
commit abccecb093
4 changed files with 179 additions and 172 deletions

View File

@ -334,8 +334,8 @@ def train_model(
run_cmd += ' --xformers' run_cmd += ' --xformers'
if shuffle_caption: if shuffle_caption:
run_cmd += ' --shuffle_caption' run_cmd += ' --shuffle_caption'
if save_state: # if save_state:
run_cmd += ' --save_state' # run_cmd += ' --save_state'
if color_aug: if color_aug:
run_cmd += ' --color_aug' run_cmd += ' --color_aug'
if flip_aug: if flip_aug:
@ -368,8 +368,8 @@ def train_model(
) )
if not save_model_as == 'same as source model': if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}' run_cmd += f' --save_model_as={save_model_as}'
if not resume == '': # if not resume == '':
run_cmd += f' --resume={resume}' # run_cmd += f' --resume={resume}'
if not float(prior_loss_weight) == 1.0: if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --prior_loss_weight={prior_loss_weight}'
if int(clip_skip) > 1: if int(clip_skip) > 1:
@ -384,7 +384,13 @@ def train_model(
run_cmd += f' --max_train_epochs="{max_train_epochs}"' run_cmd += f' --max_train_epochs="{max_train_epochs}"'
if not max_data_loader_n_workers == '': if not max_data_loader_n_workers == '':
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers) run_cmd += run_cmd_advanced_training(
max_train_epochs=max_train_epochs,
max_data_loader_n_workers=max_data_loader_n_workers,
max_token_length=max_token_length,
resume=resume,
save_state=save_state,
)
print(run_cmd) print(run_cmd)
# Run the command # Run the command
@ -681,9 +687,6 @@ def dreambooth_tab(
label='Shuffle caption', value=False label='Shuffle caption', value=False
) )
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(
label='Save training state', value=False
)
color_aug = gr.Checkbox( color_aug = gr.Checkbox(
label='Color augmentation', value=False label='Color augmentation', value=False
) )
@ -697,12 +700,6 @@ def dreambooth_tab(
label='Clip skip', value='1', minimum=1, maximum=12, step=1 label='Clip skip', value='1', minimum=1, maximum=12, step=1
) )
with gr.Row(): with gr.Row():
resume = gr.Textbox(
label='Resume from saved training state',
placeholder='path to "last-state" state folder to resume from',
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
prior_loss_weight = gr.Number( prior_loss_weight = gr.Number(
label='Prior loss weight', value=1.0 label='Prior loss weight', value=1.0
) )
@ -712,25 +709,7 @@ def dreambooth_tab(
) )
vae_button = gr.Button('📂', elem_id='open_folder_small') vae_button = gr.Button('📂', elem_id='open_folder_small')
vae_button.click(get_any_file_path, outputs=vae) vae_button.click(get_any_file_path, outputs=vae)
max_token_length = gr.Dropdown( save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
# with gr.Row():
# max_train_epochs = gr.Textbox(
# label='Max train epoch',
# placeholder='(Optional) Override number of epoch',
# )
# max_data_loader_n_workers = gr.Textbox(
# label='Max num workers for DataLoader',
# placeholder='(Optional) Override number of epoch. Default: 8',
# )
with gr.Tab('Tools'): with gr.Tab('Tools'):
gr.Markdown( gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...' 'This section provide Dreambooth tools to help setup your dataset...'

View File

@ -335,15 +335,21 @@ def train_model(
run_cmd += f' --clip_skip={str(clip_skip)}' run_cmd += f' --clip_skip={str(clip_skip)}'
if int(gradient_accumulation_steps) > 1: if int(gradient_accumulation_steps) > 1:
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
if save_state: # if save_state:
run_cmd += ' --save_state' # run_cmd += ' --save_state'
if not resume == '': # if not resume == '':
run_cmd += f' --resume={resume}' # run_cmd += f' --resume={resume}'
if not output_name == '': if not output_name == '':
run_cmd += f' --output_name="{output_name}"' run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75): if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}' run_cmd += f' --max_token_length={max_token_length}'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers) run_cmd += run_cmd_advanced_training(
max_train_epochs=max_train_epochs,
max_data_loader_n_workers=max_data_loader_n_workers,
max_token_length=max_token_length,
resume=resume,
save_state=save_state,
)
print(run_cmd) print(run_cmd)
# Run the command # Run the command
@ -640,31 +646,13 @@ def finetune_tab():
label='Shuffle caption', value=False label='Shuffle caption', value=False
) )
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(
label='Save training state', value=False
)
resume = gr.Textbox(
label='Resume from saved training state',
placeholder='path to "last-state" state folder to resume from',
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
gradient_checkpointing = gr.Checkbox( gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False label='Gradient checkpointing', value=False
) )
gradient_accumulation_steps = gr.Number( gradient_accumulation_steps = gr.Number(
label='Gradient accumulate steps', value='1' label='Gradient accumulate steps', value='1'
) )
max_token_length = gr.Dropdown( save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
create_caption = gr.Checkbox( create_caption = gr.Checkbox(

View File

@ -4,10 +4,12 @@ import gradio as gr
from easygui import msgbox from easygui import msgbox
import shutil import shutil
def get_dir_and_file(file_path): def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path) dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name) return (dir_path, file_name)
def has_ext_files(directory, extension): def has_ext_files(directory, extension):
# Iterate through all the files in the directory # Iterate through all the files in the directory
for file in os.listdir(directory): for file in os.listdir(directory):
@ -17,18 +19,26 @@ def has_ext_files(directory, extension):
# If no extension files were found, return False # If no extension files were found, return False
return False return False
def get_file_path(file_path='', defaultextension='.json', extension_name='Config files'):
def get_file_path(
file_path='', defaultextension='.json', extension_name='Config files'
):
current_file_path = file_path current_file_path = file_path
# print(f'current file path: {current_file_path}') # print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
file_path = filedialog.askopenfilename( file_path = filedialog.askopenfilename(
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')), filetypes=(
defaultextension=defaultextension, initialfile=initial_file, initialdir=initial_dir (f'{extension_name}', f'{defaultextension}'),
('All files', '*'),
),
defaultextension=defaultextension,
initialfile=initial_file,
initialdir=initial_dir,
) )
root.destroy() root.destroy()
@ -37,17 +47,20 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config
return file_path return file_path
def get_any_file_path(file_path=''): def get_any_file_path(file_path=''):
current_file_path = file_path current_file_path = file_path
# print(f'current file path: {current_file_path}') # print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
file_path = filedialog.askopenfilename(initialdir=initial_dir, file_path = filedialog.askopenfilename(
initialfile=initial_file,) initialdir=initial_dir,
initialfile=initial_file,
)
root.destroy() root.destroy()
if file_path == '': if file_path == '':
@ -65,7 +78,7 @@ def remove_doublequote(file_path):
def get_folder_path(folder_path=''): def get_folder_path(folder_path=''):
current_folder_path = folder_path current_folder_path = folder_path
initial_dir, initial_file = get_dir_and_file(folder_path) initial_dir, initial_file = get_dir_and_file(folder_path)
root = Tk() root = Tk()
@ -80,17 +93,22 @@ def get_folder_path(folder_path=''):
return folder_path return folder_path
def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='Config files'): def get_saveasfile_path(
file_path='', defaultextension='.json', extension_name='Config files'
):
current_file_path = file_path current_file_path = file_path
# print(f'current file path: {current_file_path}') # print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
save_file_path = filedialog.asksaveasfile( save_file_path = filedialog.asksaveasfile(
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')), filetypes=(
(f'{extension_name}', f'{defaultextension}'),
('All files', '*'),
),
defaultextension=defaultextension, defaultextension=defaultextension,
initialdir=initial_dir, initialdir=initial_dir,
initialfile=initial_file, initialfile=initial_file,
@ -109,16 +127,20 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='
return file_path return file_path
def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config files'):
def get_saveasfilename_path(
file_path='', extensions='*', extension_name='Config files'
):
current_file_path = file_path current_file_path = file_path
# print(f'current file path: {current_file_path}') # print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
save_file_path = filedialog.asksaveasfilename(filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')), save_file_path = filedialog.asksaveasfilename(
filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
defaultextension=extensions, defaultextension=extensions,
initialdir=initial_dir, initialdir=initial_dir,
initialfile=initial_file, initialfile=initial_file,
@ -138,9 +160,11 @@ def add_pre_postfix(
folder='', prefix='', postfix='', caption_file_ext='.caption' folder='', prefix='', postfix='', caption_file_ext='.caption'
): ):
if not has_ext_files(folder, caption_file_ext): if not has_ext_files(folder, caption_file_ext):
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...') msgbox(
f'No files with extension {caption_file_ext} were found in {folder}...'
)
return return
if prefix == '' and postfix == '': if prefix == '' and postfix == '':
return return
@ -157,15 +181,16 @@ def add_pre_postfix(
f.seek(0, 0) f.seek(0, 0)
f.write(f'{prefix}{content}{postfix}') f.write(f'{prefix}{content}{postfix}')
f.close() f.close()
def find_replace(
folder='', caption_file_ext='.caption', find='', replace='' def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
):
print('Running caption find/replace') print('Running caption find/replace')
if not has_ext_files(folder, caption_file_ext): if not has_ext_files(folder, caption_file_ext):
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...') msgbox(
f'No files with extension {caption_file_ext} were found in {folder}...'
)
return return
if find == '': if find == '':
return return
@ -179,13 +204,17 @@ def find_replace(
f.write(content) f.write(content)
f.close() f.close()
def color_aug_changed(color_aug): def color_aug_changed(color_aug):
if color_aug: if color_aug:
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...') msgbox(
'Disabling "Cache latent" because "Color augmentation" has been selected...'
)
return gr.Checkbox.update(value=False, interactive=False) return gr.Checkbox.update(value=False, interactive=False)
else: else:
return gr.Checkbox.update(value=True, interactive=True) return gr.Checkbox.update(value=True, interactive=True)
def save_inference_file(output_dir, v2, v_parameterization, output_name): def save_inference_file(output_dir, v2, v_parameterization, output_name):
# List all files in the directory # List all files in the directory
files = os.listdir(output_dir) files = os.listdir(output_dir)
@ -198,21 +227,26 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
if os.path.isfile(os.path.join(output_dir, file)): if os.path.isfile(os.path.join(output_dir, file)):
# Split the file name and extension # Split the file name and extension
file_name, ext = os.path.splitext(file) file_name, ext = os.path.splitext(file)
# Copy the v2-inference-v.yaml file to the current file, with a .yaml extension # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension
if v2 and v_parameterization: if v2 and v_parameterization:
print(f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml') print(
f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml'
)
shutil.copy( shutil.copy(
f'./v2_inference/v2-inference-v.yaml', f'./v2_inference/v2-inference-v.yaml',
f'{output_dir}/{file_name}.yaml', f'{output_dir}/{file_name}.yaml',
) )
elif v2: elif v2:
print(f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml') print(
f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml'
)
shutil.copy( shutil.copy(
f'./v2_inference/v2-inference.yaml', f'./v2_inference/v2-inference.yaml',
f'{output_dir}/{file_name}.yaml', f'{output_dir}/{file_name}.yaml',
) )
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to search for # define a list of substrings to search for
substrings_v2 = [ substrings_v2 = [
@ -262,30 +296,63 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
v_parameterization = False v_parameterization = False
return value, v2, v_parameterization return value, v2, v_parameterization
### ###
### Gradio common GUI section ### Gradio common GUI section
### ###
def gradio_advanced_training(): def gradio_advanced_training():
with gr.Row(): with gr.Row():
max_train_epochs = gr.Textbox( save_state = gr.Checkbox(label='Save training state', value=False)
label='Max train epoch', resume = gr.Textbox(
placeholder='(Optional) Override number of epoch', label='Resume from saved training state',
) placeholder='path to "last-state" state folder to resume from',
max_data_loader_n_workers = gr.Textbox( )
label='Max num workers for DataLoader', resume_button = gr.Button('📂', elem_id='open_folder_small')
placeholder='(Optional) Override number of epoch. Default: 8', resume_button.click(get_folder_path, outputs=resume)
) max_token_length = gr.Dropdown(
return max_train_epochs, max_data_loader_n_workers label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
with gr.Row():
max_train_epochs = gr.Textbox(
label='Max train epoch',
placeholder='(Optional) Override number of epoch',
)
max_data_loader_n_workers = gr.Textbox(
label='Max num workers for DataLoader',
placeholder='(Optional) Override number of epoch. Default: 8',
)
return (
save_state,
resume,
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
)
def run_cmd_advanced_training(**kwargs): def run_cmd_advanced_training(**kwargs):
run_cmd = '' options = [
max_train_epochs = kwargs.get('max_train_epochs', '') f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '') if kwargs.get('max_train_epochs')
if not max_train_epochs == '': else '',
run_cmd += f' --max_train_epochs="{max_train_epochs}"' f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
if not max_data_loader_n_workers == '': if kwargs.get('max_data_loader_n_workers')
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' else '',
f' --max_token_length={kwargs.get("max_token_length", "")}'
return run_cmd if int(kwargs.get('max_token_length', 0)) > 75
else '',
f' --resume="{kwargs.get("resume", "")}"'
if kwargs.get('resume')
else '',
' --save_state' if kwargs.get('save_state') else '',
]
run_cmd = ''.join(options)
return run_cmd

View File

@ -19,7 +19,9 @@ from library.common_gui import (
get_saveasfile_path, get_saveasfile_path,
color_aug_changed, color_aug_changed,
save_inference_file, save_inference_file,
set_pretrained_model_name_or_path_input, gradio_advanced_training,run_cmd_advanced_training, set_pretrained_model_name_or_path_input,
gradio_advanced_training,
run_cmd_advanced_training,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -172,7 +174,7 @@ def open_configuration(
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
original_file_path = file_path original_file_path = file_path
file_path = get_file_path(file_path) file_path = get_file_path(file_path)
@ -180,11 +182,11 @@ def open_configuration(
# load variables from JSON file # load variables from JSON file
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
my_data = json.load(f) my_data = json.load(f)
print("Loading config...") print('Loading config...')
else: else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
my_data = {} my_data = {}
values = [file_path] values = [file_path]
for key, value in parameters: for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
@ -235,7 +237,7 @@ def train_model(
gradient_accumulation_steps, gradient_accumulation_steps,
mem_eff_attn, mem_eff_attn,
output_name, output_name,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
@ -350,8 +352,8 @@ def train_model(
run_cmd += ' --xformers' run_cmd += ' --xformers'
if shuffle_caption: if shuffle_caption:
run_cmd += ' --shuffle_caption' run_cmd += ' --shuffle_caption'
if save_state: # if save_state:
run_cmd += ' --save_state' # run_cmd += ' --save_state'
if color_aug: if color_aug:
run_cmd += ' --color_aug' run_cmd += ' --color_aug'
if flip_aug: if flip_aug:
@ -386,8 +388,8 @@ def train_model(
) )
if not save_model_as == 'same as source model': if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}' run_cmd += f' --save_model_as={save_model_as}'
if not resume == '': # if not resume == '':
run_cmd += f' --resume="{resume}"' # run_cmd += f' --resume="{resume}"'
if not float(prior_loss_weight) == 1.0: if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora' run_cmd += f' --network_module=networks.lora'
@ -414,9 +416,15 @@ def train_model(
# run_cmd += f' --vae="{vae}"' # run_cmd += f' --vae="{vae}"'
if not output_name == '': if not output_name == '':
run_cmd += f' --output_name="{output_name}"' run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75): # if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}' # run_cmd += f' --max_token_length={max_token_length}'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers) run_cmd += run_cmd_advanced_training(
max_train_epochs=max_train_epochs,
max_data_loader_n_workers=max_data_loader_n_workers,
max_token_length=max_token_length,
resume=resume,
save_state=save_state,
)
print(run_cmd) print(run_cmd)
# Run the command # Run the command
@ -564,9 +572,7 @@ def lora_tab(
label='Image folder', label='Image folder',
placeholder='Folder where the training folders containing the images are located', placeholder='Folder where the training folders containing the images are located',
) )
train_data_dir_folder = gr.Button( train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
'📂', elem_id='open_folder_small'
)
train_data_dir_folder.click( train_data_dir_folder.click(
get_folder_path, outputs=train_data_dir get_folder_path, outputs=train_data_dir
) )
@ -574,33 +580,21 @@ def lora_tab(
label='Regularisation folder', label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located', placeholder='(Optional) Folder where where the regularization folders containing the images are located',
) )
reg_data_dir_folder = gr.Button( reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
'📂', elem_id='open_folder_small' reg_data_dir_folder.click(get_folder_path, outputs=reg_data_dir)
)
reg_data_dir_folder.click(
get_folder_path, outputs=reg_data_dir
)
with gr.Row(): with gr.Row():
output_dir = gr.Textbox( output_dir = gr.Textbox(
label='Output folder', label='Output folder',
placeholder='Folder to output trained model', placeholder='Folder to output trained model',
) )
output_dir_folder = gr.Button( output_dir_folder = gr.Button('📂', elem_id='open_folder_small')
'📂', elem_id='open_folder_small' output_dir_folder.click(get_folder_path, outputs=output_dir)
)
output_dir_folder.click(
get_folder_path, outputs=output_dir
)
logging_dir = gr.Textbox( logging_dir = gr.Textbox(
label='Logging folder', label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder', placeholder='Optional: enable logging and output TensorBoard log to this folder',
) )
logging_dir_folder = gr.Button( logging_dir_folder = gr.Button('📂', elem_id='open_folder_small')
'📂', elem_id='open_folder_small' logging_dir_folder.click(get_folder_path, outputs=logging_dir)
)
logging_dir_folder.click(
get_folder_path, outputs=logging_dir
)
with gr.Row(): with gr.Row():
output_name = gr.Textbox( output_name = gr.Textbox(
label='Model output name', label='Model output name',
@ -659,11 +653,13 @@ def lora_tab(
with gr.Row(): with gr.Row():
text_encoder_lr = gr.Textbox( text_encoder_lr = gr.Textbox(
label='Text Encoder learning rate', label='Text Encoder learning rate',
value="5e-5", value='5e-5',
placeholder='Optional', placeholder='Optional',
) )
unet_lr = gr.Textbox( unet_lr = gr.Textbox(
label='Unet learning rate', value="1e-3", placeholder='Optional' label='Unet learning rate',
value='1e-3',
placeholder='Optional',
) )
network_dim = gr.Slider( network_dim = gr.Slider(
minimum=1, minimum=1,
@ -731,13 +727,9 @@ def lora_tab(
label='Stop text encoder training', label='Stop text encoder training',
) )
with gr.Row(): with gr.Row():
enable_bucket = gr.Checkbox( enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
label='Enable buckets', value=True
)
cache_latent = gr.Checkbox(label='Cache latent', value=True) cache_latent = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam = gr.Checkbox( use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
label='Use 8bit adam', value=True
)
xformers = gr.Checkbox(label='Use xformers', value=True) xformers = gr.Checkbox(label='Use xformers', value=True)
with gr.Accordion('Advanced Configuration', open=False): with gr.Accordion('Advanced Configuration', open=False):
with gr.Row(): with gr.Row():
@ -777,33 +769,14 @@ def lora_tab(
mem_eff_attn = gr.Checkbox( mem_eff_attn = gr.Checkbox(
label='Memory efficient attention', value=False label='Memory efficient attention', value=False
) )
with gr.Row(): (
save_state = gr.Checkbox( save_state,
label='Save training state', value=False resume,
) max_token_length,
resume = gr.Textbox( max_train_epochs,
label='Resume from saved training state', max_data_loader_n_workers,
placeholder='path to "last-state" state folder to resume from', ) = gradio_advanced_training()
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
# vae = gr.Textbox(
# label='VAE',
# placeholder='(Optiona) path to checkpoint of vae to replace for training',
# )
# vae_button = gr.Button('📂', elem_id='open_folder_small')
# vae_button.click(get_any_file_path, outputs=vae)
max_token_length = gr.Dropdown(
label='Max Token Length',
choices=[
'75',
'150',
'225',
],
value='75',
)
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
with gr.Tab('Tools'): with gr.Tab('Tools'):
gr.Markdown( gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...' 'This section provide Dreambooth tools to help setup your dataset...'