Move save_state and resume to common gui
Format code
This commit is contained in:
parent
6aed2bb402
commit
abccecb093
@ -334,8 +334,8 @@ def train_model(
|
||||
run_cmd += ' --xformers'
|
||||
if shuffle_caption:
|
||||
run_cmd += ' --shuffle_caption'
|
||||
if save_state:
|
||||
run_cmd += ' --save_state'
|
||||
# if save_state:
|
||||
# run_cmd += ' --save_state'
|
||||
if color_aug:
|
||||
run_cmd += ' --color_aug'
|
||||
if flip_aug:
|
||||
@ -368,8 +368,8 @@ def train_model(
|
||||
)
|
||||
if not save_model_as == 'same as source model':
|
||||
run_cmd += f' --save_model_as={save_model_as}'
|
||||
if not resume == '':
|
||||
run_cmd += f' --resume={resume}'
|
||||
# if not resume == '':
|
||||
# run_cmd += f' --resume={resume}'
|
||||
if not float(prior_loss_weight) == 1.0:
|
||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||
if int(clip_skip) > 1:
|
||||
@ -384,7 +384,13 @@ def train_model(
|
||||
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
||||
if not max_data_loader_n_workers == '':
|
||||
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
||||
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
||||
run_cmd += run_cmd_advanced_training(
|
||||
max_train_epochs=max_train_epochs,
|
||||
max_data_loader_n_workers=max_data_loader_n_workers,
|
||||
max_token_length=max_token_length,
|
||||
resume=resume,
|
||||
save_state=save_state,
|
||||
)
|
||||
|
||||
print(run_cmd)
|
||||
# Run the command
|
||||
@ -681,9 +687,6 @@ def dreambooth_tab(
|
||||
label='Shuffle caption', value=False
|
||||
)
|
||||
with gr.Row():
|
||||
save_state = gr.Checkbox(
|
||||
label='Save training state', value=False
|
||||
)
|
||||
color_aug = gr.Checkbox(
|
||||
label='Color augmentation', value=False
|
||||
)
|
||||
@ -697,12 +700,6 @@ def dreambooth_tab(
|
||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||
)
|
||||
with gr.Row():
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
placeholder='path to "last-state" state folder to resume from',
|
||||
)
|
||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
resume_button.click(get_folder_path, outputs=resume)
|
||||
prior_loss_weight = gr.Number(
|
||||
label='Prior loss weight', value=1.0
|
||||
)
|
||||
@ -712,25 +709,7 @@ def dreambooth_tab(
|
||||
)
|
||||
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
vae_button.click(get_any_file_path, outputs=vae)
|
||||
max_token_length = gr.Dropdown(
|
||||
label='Max Token Length',
|
||||
choices=[
|
||||
'75',
|
||||
'150',
|
||||
'225',
|
||||
],
|
||||
value='75',
|
||||
)
|
||||
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||
# with gr.Row():
|
||||
# max_train_epochs = gr.Textbox(
|
||||
# label='Max train epoch',
|
||||
# placeholder='(Optional) Override number of epoch',
|
||||
# )
|
||||
# max_data_loader_n_workers = gr.Textbox(
|
||||
# label='Max num workers for DataLoader',
|
||||
# placeholder='(Optional) Override number of epoch. Default: 8',
|
||||
# )
|
||||
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||
with gr.Tab('Tools'):
|
||||
gr.Markdown(
|
||||
'This section provide Dreambooth tools to help setup your dataset...'
|
||||
|
@ -335,15 +335,21 @@ def train_model(
|
||||
run_cmd += f' --clip_skip={str(clip_skip)}'
|
||||
if int(gradient_accumulation_steps) > 1:
|
||||
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
||||
if save_state:
|
||||
run_cmd += ' --save_state'
|
||||
if not resume == '':
|
||||
run_cmd += f' --resume={resume}'
|
||||
# if save_state:
|
||||
# run_cmd += ' --save_state'
|
||||
# if not resume == '':
|
||||
# run_cmd += f' --resume={resume}'
|
||||
if not output_name == '':
|
||||
run_cmd += f' --output_name="{output_name}"'
|
||||
if (int(max_token_length) > 75):
|
||||
run_cmd += f' --max_token_length={max_token_length}'
|
||||
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
||||
run_cmd += run_cmd_advanced_training(
|
||||
max_train_epochs=max_train_epochs,
|
||||
max_data_loader_n_workers=max_data_loader_n_workers,
|
||||
max_token_length=max_token_length,
|
||||
resume=resume,
|
||||
save_state=save_state,
|
||||
)
|
||||
|
||||
print(run_cmd)
|
||||
# Run the command
|
||||
@ -640,31 +646,13 @@ def finetune_tab():
|
||||
label='Shuffle caption', value=False
|
||||
)
|
||||
with gr.Row():
|
||||
save_state = gr.Checkbox(
|
||||
label='Save training state', value=False
|
||||
)
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
placeholder='path to "last-state" state folder to resume from',
|
||||
)
|
||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
resume_button.click(get_folder_path, outputs=resume)
|
||||
gradient_checkpointing = gr.Checkbox(
|
||||
label='Gradient checkpointing', value=False
|
||||
)
|
||||
gradient_accumulation_steps = gr.Number(
|
||||
label='Gradient accumulate steps', value='1'
|
||||
)
|
||||
max_token_length = gr.Dropdown(
|
||||
label='Max Token Length',
|
||||
choices=[
|
||||
'75',
|
||||
'150',
|
||||
'225',
|
||||
],
|
||||
value='75',
|
||||
)
|
||||
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
create_caption = gr.Checkbox(
|
||||
|
@ -4,10 +4,12 @@ import gradio as gr
|
||||
from easygui import msgbox
|
||||
import shutil
|
||||
|
||||
|
||||
def get_dir_and_file(file_path):
|
||||
dir_path, file_name = os.path.split(file_path)
|
||||
return (dir_path, file_name)
|
||||
|
||||
|
||||
def has_ext_files(directory, extension):
|
||||
# Iterate through all the files in the directory
|
||||
for file in os.listdir(directory):
|
||||
@ -17,7 +19,10 @@ def has_ext_files(directory, extension):
|
||||
# If no extension files were found, return False
|
||||
return False
|
||||
|
||||
def get_file_path(file_path='', defaultextension='.json', extension_name='Config files'):
|
||||
|
||||
def get_file_path(
|
||||
file_path='', defaultextension='.json', extension_name='Config files'
|
||||
):
|
||||
current_file_path = file_path
|
||||
# print(f'current file path: {current_file_path}')
|
||||
|
||||
@ -27,8 +32,13 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config
|
||||
root.wm_attributes('-topmost', 1)
|
||||
root.withdraw()
|
||||
file_path = filedialog.askopenfilename(
|
||||
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
|
||||
defaultextension=defaultextension, initialfile=initial_file, initialdir=initial_dir
|
||||
filetypes=(
|
||||
(f'{extension_name}', f'{defaultextension}'),
|
||||
('All files', '*'),
|
||||
),
|
||||
defaultextension=defaultextension,
|
||||
initialfile=initial_file,
|
||||
initialdir=initial_dir,
|
||||
)
|
||||
root.destroy()
|
||||
|
||||
@ -37,6 +47,7 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
def get_any_file_path(file_path=''):
|
||||
current_file_path = file_path
|
||||
# print(f'current file path: {current_file_path}')
|
||||
@ -46,8 +57,10 @@ def get_any_file_path(file_path=''):
|
||||
root = Tk()
|
||||
root.wm_attributes('-topmost', 1)
|
||||
root.withdraw()
|
||||
file_path = filedialog.askopenfilename(initialdir=initial_dir,
|
||||
initialfile=initial_file,)
|
||||
file_path = filedialog.askopenfilename(
|
||||
initialdir=initial_dir,
|
||||
initialfile=initial_file,
|
||||
)
|
||||
root.destroy()
|
||||
|
||||
if file_path == '':
|
||||
@ -80,7 +93,9 @@ def get_folder_path(folder_path=''):
|
||||
return folder_path
|
||||
|
||||
|
||||
def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='Config files'):
|
||||
def get_saveasfile_path(
|
||||
file_path='', defaultextension='.json', extension_name='Config files'
|
||||
):
|
||||
current_file_path = file_path
|
||||
# print(f'current file path: {current_file_path}')
|
||||
|
||||
@ -90,7 +105,10 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='
|
||||
root.wm_attributes('-topmost', 1)
|
||||
root.withdraw()
|
||||
save_file_path = filedialog.asksaveasfile(
|
||||
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
|
||||
filetypes=(
|
||||
(f'{extension_name}', f'{defaultextension}'),
|
||||
('All files', '*'),
|
||||
),
|
||||
defaultextension=defaultextension,
|
||||
initialdir=initial_dir,
|
||||
initialfile=initial_file,
|
||||
@ -109,7 +127,10 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='
|
||||
|
||||
return file_path
|
||||
|
||||
def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config files'):
|
||||
|
||||
def get_saveasfilename_path(
|
||||
file_path='', extensions='*', extension_name='Config files'
|
||||
):
|
||||
current_file_path = file_path
|
||||
# print(f'current file path: {current_file_path}')
|
||||
|
||||
@ -118,7 +139,8 @@ def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config
|
||||
root = Tk()
|
||||
root.wm_attributes('-topmost', 1)
|
||||
root.withdraw()
|
||||
save_file_path = filedialog.asksaveasfilename(filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
|
||||
save_file_path = filedialog.asksaveasfilename(
|
||||
filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
|
||||
defaultextension=extensions,
|
||||
initialdir=initial_dir,
|
||||
initialfile=initial_file,
|
||||
@ -138,7 +160,9 @@ def add_pre_postfix(
|
||||
folder='', prefix='', postfix='', caption_file_ext='.caption'
|
||||
):
|
||||
if not has_ext_files(folder, caption_file_ext):
|
||||
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...')
|
||||
msgbox(
|
||||
f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||
)
|
||||
return
|
||||
|
||||
if prefix == '' and postfix == '':
|
||||
@ -158,12 +182,13 @@ def add_pre_postfix(
|
||||
f.write(f'{prefix}{content}{postfix}')
|
||||
f.close()
|
||||
|
||||
def find_replace(
|
||||
folder='', caption_file_ext='.caption', find='', replace=''
|
||||
):
|
||||
|
||||
def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
|
||||
print('Running caption find/replace')
|
||||
if not has_ext_files(folder, caption_file_ext):
|
||||
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...')
|
||||
msgbox(
|
||||
f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||
)
|
||||
return
|
||||
|
||||
if find == '':
|
||||
@ -179,13 +204,17 @@ def find_replace(
|
||||
f.write(content)
|
||||
f.close()
|
||||
|
||||
|
||||
def color_aug_changed(color_aug):
|
||||
if color_aug:
|
||||
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
|
||||
msgbox(
|
||||
'Disabling "Cache latent" because "Color augmentation" has been selected...'
|
||||
)
|
||||
return gr.Checkbox.update(value=False, interactive=False)
|
||||
else:
|
||||
return gr.Checkbox.update(value=True, interactive=True)
|
||||
|
||||
|
||||
def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
||||
# List all files in the directory
|
||||
files = os.listdir(output_dir)
|
||||
@ -201,18 +230,23 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
||||
|
||||
# Copy the v2-inference-v.yaml file to the current file, with a .yaml extension
|
||||
if v2 and v_parameterization:
|
||||
print(f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml')
|
||||
print(
|
||||
f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml'
|
||||
)
|
||||
shutil.copy(
|
||||
f'./v2_inference/v2-inference-v.yaml',
|
||||
f'{output_dir}/{file_name}.yaml',
|
||||
)
|
||||
elif v2:
|
||||
print(f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml')
|
||||
print(
|
||||
f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml'
|
||||
)
|
||||
shutil.copy(
|
||||
f'./v2_inference/v2-inference.yaml',
|
||||
f'{output_dir}/{file_name}.yaml',
|
||||
)
|
||||
|
||||
|
||||
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
||||
# define a list of substrings to search for
|
||||
substrings_v2 = [
|
||||
@ -267,25 +301,58 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
||||
### Gradio common GUI section
|
||||
###
|
||||
|
||||
|
||||
def gradio_advanced_training():
|
||||
with gr.Row():
|
||||
max_train_epochs = gr.Textbox(
|
||||
label='Max train epoch',
|
||||
placeholder='(Optional) Override number of epoch',
|
||||
)
|
||||
max_data_loader_n_workers = gr.Textbox(
|
||||
label='Max num workers for DataLoader',
|
||||
placeholder='(Optional) Override number of epoch. Default: 8',
|
||||
)
|
||||
return max_train_epochs, max_data_loader_n_workers
|
||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
placeholder='path to "last-state" state folder to resume from',
|
||||
)
|
||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
resume_button.click(get_folder_path, outputs=resume)
|
||||
max_token_length = gr.Dropdown(
|
||||
label='Max Token Length',
|
||||
choices=[
|
||||
'75',
|
||||
'150',
|
||||
'225',
|
||||
],
|
||||
value='75',
|
||||
)
|
||||
with gr.Row():
|
||||
max_train_epochs = gr.Textbox(
|
||||
label='Max train epoch',
|
||||
placeholder='(Optional) Override number of epoch',
|
||||
)
|
||||
max_data_loader_n_workers = gr.Textbox(
|
||||
label='Max num workers for DataLoader',
|
||||
placeholder='(Optional) Override number of epoch. Default: 8',
|
||||
)
|
||||
return (
|
||||
save_state,
|
||||
resume,
|
||||
max_token_length,
|
||||
max_train_epochs,
|
||||
max_data_loader_n_workers,
|
||||
)
|
||||
|
||||
|
||||
def run_cmd_advanced_training(**kwargs):
|
||||
run_cmd = ''
|
||||
max_train_epochs = kwargs.get('max_train_epochs', '')
|
||||
max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '')
|
||||
if not max_train_epochs == '':
|
||||
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
||||
if not max_data_loader_n_workers == '':
|
||||
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
||||
|
||||
options = [
|
||||
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
|
||||
if kwargs.get('max_train_epochs')
|
||||
else '',
|
||||
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
|
||||
if kwargs.get('max_data_loader_n_workers')
|
||||
else '',
|
||||
f' --max_token_length={kwargs.get("max_token_length", "")}'
|
||||
if int(kwargs.get('max_token_length', 0)) > 75
|
||||
else '',
|
||||
f' --resume="{kwargs.get("resume", "")}"'
|
||||
if kwargs.get('resume')
|
||||
else '',
|
||||
' --save_state' if kwargs.get('save_state') else '',
|
||||
]
|
||||
run_cmd = ''.join(options)
|
||||
return run_cmd
|
103
lora_gui.py
103
lora_gui.py
@ -19,7 +19,9 @@ from library.common_gui import (
|
||||
get_saveasfile_path,
|
||||
color_aug_changed,
|
||||
save_inference_file,
|
||||
set_pretrained_model_name_or_path_input, gradio_advanced_training,run_cmd_advanced_training,
|
||||
set_pretrained_model_name_or_path_input,
|
||||
gradio_advanced_training,
|
||||
run_cmd_advanced_training,
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
@ -180,7 +182,7 @@ def open_configuration(
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
my_data = json.load(f)
|
||||
print("Loading config...")
|
||||
print('Loading config...')
|
||||
else:
|
||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||
my_data = {}
|
||||
@ -235,7 +237,7 @@ def train_model(
|
||||
gradient_accumulation_steps,
|
||||
mem_eff_attn,
|
||||
output_name,
|
||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||
max_token_length,
|
||||
max_train_epochs,
|
||||
max_data_loader_n_workers,
|
||||
@ -350,8 +352,8 @@ def train_model(
|
||||
run_cmd += ' --xformers'
|
||||
if shuffle_caption:
|
||||
run_cmd += ' --shuffle_caption'
|
||||
if save_state:
|
||||
run_cmd += ' --save_state'
|
||||
# if save_state:
|
||||
# run_cmd += ' --save_state'
|
||||
if color_aug:
|
||||
run_cmd += ' --color_aug'
|
||||
if flip_aug:
|
||||
@ -386,8 +388,8 @@ def train_model(
|
||||
)
|
||||
if not save_model_as == 'same as source model':
|
||||
run_cmd += f' --save_model_as={save_model_as}'
|
||||
if not resume == '':
|
||||
run_cmd += f' --resume="{resume}"'
|
||||
# if not resume == '':
|
||||
# run_cmd += f' --resume="{resume}"'
|
||||
if not float(prior_loss_weight) == 1.0:
|
||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||
run_cmd += f' --network_module=networks.lora'
|
||||
@ -414,9 +416,15 @@ def train_model(
|
||||
# run_cmd += f' --vae="{vae}"'
|
||||
if not output_name == '':
|
||||
run_cmd += f' --output_name="{output_name}"'
|
||||
if (int(max_token_length) > 75):
|
||||
run_cmd += f' --max_token_length={max_token_length}'
|
||||
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
||||
# if (int(max_token_length) > 75):
|
||||
# run_cmd += f' --max_token_length={max_token_length}'
|
||||
run_cmd += run_cmd_advanced_training(
|
||||
max_train_epochs=max_train_epochs,
|
||||
max_data_loader_n_workers=max_data_loader_n_workers,
|
||||
max_token_length=max_token_length,
|
||||
resume=resume,
|
||||
save_state=save_state,
|
||||
)
|
||||
|
||||
print(run_cmd)
|
||||
# Run the command
|
||||
@ -564,9 +572,7 @@ def lora_tab(
|
||||
label='Image folder',
|
||||
placeholder='Folder where the training folders containing the images are located',
|
||||
)
|
||||
train_data_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
train_data_dir_folder.click(
|
||||
get_folder_path, outputs=train_data_dir
|
||||
)
|
||||
@ -574,33 +580,21 @@ def lora_tab(
|
||||
label='Regularisation folder',
|
||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||
)
|
||||
reg_data_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
reg_data_dir_folder.click(
|
||||
get_folder_path, outputs=reg_data_dir
|
||||
)
|
||||
reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
reg_data_dir_folder.click(get_folder_path, outputs=reg_data_dir)
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox(
|
||||
label='Output folder',
|
||||
placeholder='Folder to output trained model',
|
||||
)
|
||||
output_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
output_dir_folder.click(
|
||||
get_folder_path, outputs=output_dir
|
||||
)
|
||||
output_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
output_dir_folder.click(get_folder_path, outputs=output_dir)
|
||||
logging_dir = gr.Textbox(
|
||||
label='Logging folder',
|
||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
)
|
||||
logging_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
logging_dir_folder.click(
|
||||
get_folder_path, outputs=logging_dir
|
||||
)
|
||||
logging_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
logging_dir_folder.click(get_folder_path, outputs=logging_dir)
|
||||
with gr.Row():
|
||||
output_name = gr.Textbox(
|
||||
label='Model output name',
|
||||
@ -659,11 +653,13 @@ def lora_tab(
|
||||
with gr.Row():
|
||||
text_encoder_lr = gr.Textbox(
|
||||
label='Text Encoder learning rate',
|
||||
value="5e-5",
|
||||
value='5e-5',
|
||||
placeholder='Optional',
|
||||
)
|
||||
unet_lr = gr.Textbox(
|
||||
label='Unet learning rate', value="1e-3", placeholder='Optional'
|
||||
label='Unet learning rate',
|
||||
value='1e-3',
|
||||
placeholder='Optional',
|
||||
)
|
||||
network_dim = gr.Slider(
|
||||
minimum=1,
|
||||
@ -731,13 +727,9 @@ def lora_tab(
|
||||
label='Stop text encoder training',
|
||||
)
|
||||
with gr.Row():
|
||||
enable_bucket = gr.Checkbox(
|
||||
label='Enable buckets', value=True
|
||||
)
|
||||
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
||||
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||
use_8bit_adam = gr.Checkbox(
|
||||
label='Use 8bit adam', value=True
|
||||
)
|
||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||
with gr.Accordion('Advanced Configuration', open=False):
|
||||
with gr.Row():
|
||||
@ -777,32 +769,13 @@ def lora_tab(
|
||||
mem_eff_attn = gr.Checkbox(
|
||||
label='Memory efficient attention', value=False
|
||||
)
|
||||
with gr.Row():
|
||||
save_state = gr.Checkbox(
|
||||
label='Save training state', value=False
|
||||
)
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
placeholder='path to "last-state" state folder to resume from',
|
||||
)
|
||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
resume_button.click(get_folder_path, outputs=resume)
|
||||
# vae = gr.Textbox(
|
||||
# label='VAE',
|
||||
# placeholder='(Optiona) path to checkpoint of vae to replace for training',
|
||||
# )
|
||||
# vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
# vae_button.click(get_any_file_path, outputs=vae)
|
||||
max_token_length = gr.Dropdown(
|
||||
label='Max Token Length',
|
||||
choices=[
|
||||
'75',
|
||||
'150',
|
||||
'225',
|
||||
],
|
||||
value='75',
|
||||
)
|
||||
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||
(
|
||||
save_state,
|
||||
resume,
|
||||
max_token_length,
|
||||
max_train_epochs,
|
||||
max_data_loader_n_workers,
|
||||
) = gradio_advanced_training()
|
||||
|
||||
with gr.Tab('Tools'):
|
||||
gr.Markdown(
|
||||
|
Loading…
Reference in New Issue
Block a user