Move save_state and resume to common gui
Format code
This commit is contained in:
parent
6aed2bb402
commit
abccecb093
@ -334,8 +334,8 @@ def train_model(
|
|||||||
run_cmd += ' --xformers'
|
run_cmd += ' --xformers'
|
||||||
if shuffle_caption:
|
if shuffle_caption:
|
||||||
run_cmd += ' --shuffle_caption'
|
run_cmd += ' --shuffle_caption'
|
||||||
if save_state:
|
# if save_state:
|
||||||
run_cmd += ' --save_state'
|
# run_cmd += ' --save_state'
|
||||||
if color_aug:
|
if color_aug:
|
||||||
run_cmd += ' --color_aug'
|
run_cmd += ' --color_aug'
|
||||||
if flip_aug:
|
if flip_aug:
|
||||||
@ -368,8 +368,8 @@ def train_model(
|
|||||||
)
|
)
|
||||||
if not save_model_as == 'same as source model':
|
if not save_model_as == 'same as source model':
|
||||||
run_cmd += f' --save_model_as={save_model_as}'
|
run_cmd += f' --save_model_as={save_model_as}'
|
||||||
if not resume == '':
|
# if not resume == '':
|
||||||
run_cmd += f' --resume={resume}'
|
# run_cmd += f' --resume={resume}'
|
||||||
if not float(prior_loss_weight) == 1.0:
|
if not float(prior_loss_weight) == 1.0:
|
||||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||||
if int(clip_skip) > 1:
|
if int(clip_skip) > 1:
|
||||||
@ -384,7 +384,13 @@ def train_model(
|
|||||||
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
||||||
if not max_data_loader_n_workers == '':
|
if not max_data_loader_n_workers == '':
|
||||||
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
||||||
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
run_cmd += run_cmd_advanced_training(
|
||||||
|
max_train_epochs=max_train_epochs,
|
||||||
|
max_data_loader_n_workers=max_data_loader_n_workers,
|
||||||
|
max_token_length=max_token_length,
|
||||||
|
resume=resume,
|
||||||
|
save_state=save_state,
|
||||||
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -681,9 +687,6 @@ def dreambooth_tab(
|
|||||||
label='Shuffle caption', value=False
|
label='Shuffle caption', value=False
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(
|
|
||||||
label='Save training state', value=False
|
|
||||||
)
|
|
||||||
color_aug = gr.Checkbox(
|
color_aug = gr.Checkbox(
|
||||||
label='Color augmentation', value=False
|
label='Color augmentation', value=False
|
||||||
)
|
)
|
||||||
@ -697,12 +700,6 @@ def dreambooth_tab(
|
|||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
resume = gr.Textbox(
|
|
||||||
label='Resume from saved training state',
|
|
||||||
placeholder='path to "last-state" state folder to resume from',
|
|
||||||
)
|
|
||||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
|
||||||
resume_button.click(get_folder_path, outputs=resume)
|
|
||||||
prior_loss_weight = gr.Number(
|
prior_loss_weight = gr.Number(
|
||||||
label='Prior loss weight', value=1.0
|
label='Prior loss weight', value=1.0
|
||||||
)
|
)
|
||||||
@ -712,25 +709,7 @@ def dreambooth_tab(
|
|||||||
)
|
)
|
||||||
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||||
vae_button.click(get_any_file_path, outputs=vae)
|
vae_button.click(get_any_file_path, outputs=vae)
|
||||||
max_token_length = gr.Dropdown(
|
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||||
label='Max Token Length',
|
|
||||||
choices=[
|
|
||||||
'75',
|
|
||||||
'150',
|
|
||||||
'225',
|
|
||||||
],
|
|
||||||
value='75',
|
|
||||||
)
|
|
||||||
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
|
||||||
# with gr.Row():
|
|
||||||
# max_train_epochs = gr.Textbox(
|
|
||||||
# label='Max train epoch',
|
|
||||||
# placeholder='(Optional) Override number of epoch',
|
|
||||||
# )
|
|
||||||
# max_data_loader_n_workers = gr.Textbox(
|
|
||||||
# label='Max num workers for DataLoader',
|
|
||||||
# placeholder='(Optional) Override number of epoch. Default: 8',
|
|
||||||
# )
|
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This section provide Dreambooth tools to help setup your dataset...'
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
|
@ -335,15 +335,21 @@ def train_model(
|
|||||||
run_cmd += f' --clip_skip={str(clip_skip)}'
|
run_cmd += f' --clip_skip={str(clip_skip)}'
|
||||||
if int(gradient_accumulation_steps) > 1:
|
if int(gradient_accumulation_steps) > 1:
|
||||||
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
||||||
if save_state:
|
# if save_state:
|
||||||
run_cmd += ' --save_state'
|
# run_cmd += ' --save_state'
|
||||||
if not resume == '':
|
# if not resume == '':
|
||||||
run_cmd += f' --resume={resume}'
|
# run_cmd += f' --resume={resume}'
|
||||||
if not output_name == '':
|
if not output_name == '':
|
||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
if (int(max_token_length) > 75):
|
if (int(max_token_length) > 75):
|
||||||
run_cmd += f' --max_token_length={max_token_length}'
|
run_cmd += f' --max_token_length={max_token_length}'
|
||||||
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
run_cmd += run_cmd_advanced_training(
|
||||||
|
max_train_epochs=max_train_epochs,
|
||||||
|
max_data_loader_n_workers=max_data_loader_n_workers,
|
||||||
|
max_token_length=max_token_length,
|
||||||
|
resume=resume,
|
||||||
|
save_state=save_state,
|
||||||
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -640,31 +646,13 @@ def finetune_tab():
|
|||||||
label='Shuffle caption', value=False
|
label='Shuffle caption', value=False
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(
|
|
||||||
label='Save training state', value=False
|
|
||||||
)
|
|
||||||
resume = gr.Textbox(
|
|
||||||
label='Resume from saved training state',
|
|
||||||
placeholder='path to "last-state" state folder to resume from',
|
|
||||||
)
|
|
||||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
|
||||||
resume_button.click(get_folder_path, outputs=resume)
|
|
||||||
gradient_checkpointing = gr.Checkbox(
|
gradient_checkpointing = gr.Checkbox(
|
||||||
label='Gradient checkpointing', value=False
|
label='Gradient checkpointing', value=False
|
||||||
)
|
)
|
||||||
gradient_accumulation_steps = gr.Number(
|
gradient_accumulation_steps = gr.Number(
|
||||||
label='Gradient accumulate steps', value='1'
|
label='Gradient accumulate steps', value='1'
|
||||||
)
|
)
|
||||||
max_token_length = gr.Dropdown(
|
save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||||
label='Max Token Length',
|
|
||||||
choices=[
|
|
||||||
'75',
|
|
||||||
'150',
|
|
||||||
'225',
|
|
||||||
],
|
|
||||||
value='75',
|
|
||||||
)
|
|
||||||
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
create_caption = gr.Checkbox(
|
create_caption = gr.Checkbox(
|
||||||
|
@ -4,10 +4,12 @@ import gradio as gr
|
|||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
def get_dir_and_file(file_path):
|
def get_dir_and_file(file_path):
|
||||||
dir_path, file_name = os.path.split(file_path)
|
dir_path, file_name = os.path.split(file_path)
|
||||||
return (dir_path, file_name)
|
return (dir_path, file_name)
|
||||||
|
|
||||||
|
|
||||||
def has_ext_files(directory, extension):
|
def has_ext_files(directory, extension):
|
||||||
# Iterate through all the files in the directory
|
# Iterate through all the files in the directory
|
||||||
for file in os.listdir(directory):
|
for file in os.listdir(directory):
|
||||||
@ -17,7 +19,10 @@ def has_ext_files(directory, extension):
|
|||||||
# If no extension files were found, return False
|
# If no extension files were found, return False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_file_path(file_path='', defaultextension='.json', extension_name='Config files'):
|
|
||||||
|
def get_file_path(
|
||||||
|
file_path='', defaultextension='.json', extension_name='Config files'
|
||||||
|
):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
# print(f'current file path: {current_file_path}')
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
@ -27,8 +32,13 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config
|
|||||||
root.wm_attributes('-topmost', 1)
|
root.wm_attributes('-topmost', 1)
|
||||||
root.withdraw()
|
root.withdraw()
|
||||||
file_path = filedialog.askopenfilename(
|
file_path = filedialog.askopenfilename(
|
||||||
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
|
filetypes=(
|
||||||
defaultextension=defaultextension, initialfile=initial_file, initialdir=initial_dir
|
(f'{extension_name}', f'{defaultextension}'),
|
||||||
|
('All files', '*'),
|
||||||
|
),
|
||||||
|
defaultextension=defaultextension,
|
||||||
|
initialfile=initial_file,
|
||||||
|
initialdir=initial_dir,
|
||||||
)
|
)
|
||||||
root.destroy()
|
root.destroy()
|
||||||
|
|
||||||
@ -37,6 +47,7 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config
|
|||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
def get_any_file_path(file_path=''):
|
def get_any_file_path(file_path=''):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
# print(f'current file path: {current_file_path}')
|
# print(f'current file path: {current_file_path}')
|
||||||
@ -46,8 +57,10 @@ def get_any_file_path(file_path=''):
|
|||||||
root = Tk()
|
root = Tk()
|
||||||
root.wm_attributes('-topmost', 1)
|
root.wm_attributes('-topmost', 1)
|
||||||
root.withdraw()
|
root.withdraw()
|
||||||
file_path = filedialog.askopenfilename(initialdir=initial_dir,
|
file_path = filedialog.askopenfilename(
|
||||||
initialfile=initial_file,)
|
initialdir=initial_dir,
|
||||||
|
initialfile=initial_file,
|
||||||
|
)
|
||||||
root.destroy()
|
root.destroy()
|
||||||
|
|
||||||
if file_path == '':
|
if file_path == '':
|
||||||
@ -80,7 +93,9 @@ def get_folder_path(folder_path=''):
|
|||||||
return folder_path
|
return folder_path
|
||||||
|
|
||||||
|
|
||||||
def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='Config files'):
|
def get_saveasfile_path(
|
||||||
|
file_path='', defaultextension='.json', extension_name='Config files'
|
||||||
|
):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
# print(f'current file path: {current_file_path}')
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
@ -90,7 +105,10 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='
|
|||||||
root.wm_attributes('-topmost', 1)
|
root.wm_attributes('-topmost', 1)
|
||||||
root.withdraw()
|
root.withdraw()
|
||||||
save_file_path = filedialog.asksaveasfile(
|
save_file_path = filedialog.asksaveasfile(
|
||||||
filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')),
|
filetypes=(
|
||||||
|
(f'{extension_name}', f'{defaultextension}'),
|
||||||
|
('All files', '*'),
|
||||||
|
),
|
||||||
defaultextension=defaultextension,
|
defaultextension=defaultextension,
|
||||||
initialdir=initial_dir,
|
initialdir=initial_dir,
|
||||||
initialfile=initial_file,
|
initialfile=initial_file,
|
||||||
@ -109,7 +127,10 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='
|
|||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config files'):
|
|
||||||
|
def get_saveasfilename_path(
|
||||||
|
file_path='', extensions='*', extension_name='Config files'
|
||||||
|
):
|
||||||
current_file_path = file_path
|
current_file_path = file_path
|
||||||
# print(f'current file path: {current_file_path}')
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
@ -118,7 +139,8 @@ def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config
|
|||||||
root = Tk()
|
root = Tk()
|
||||||
root.wm_attributes('-topmost', 1)
|
root.wm_attributes('-topmost', 1)
|
||||||
root.withdraw()
|
root.withdraw()
|
||||||
save_file_path = filedialog.asksaveasfilename(filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
|
save_file_path = filedialog.asksaveasfilename(
|
||||||
|
filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
|
||||||
defaultextension=extensions,
|
defaultextension=extensions,
|
||||||
initialdir=initial_dir,
|
initialdir=initial_dir,
|
||||||
initialfile=initial_file,
|
initialfile=initial_file,
|
||||||
@ -138,7 +160,9 @@ def add_pre_postfix(
|
|||||||
folder='', prefix='', postfix='', caption_file_ext='.caption'
|
folder='', prefix='', postfix='', caption_file_ext='.caption'
|
||||||
):
|
):
|
||||||
if not has_ext_files(folder, caption_file_ext):
|
if not has_ext_files(folder, caption_file_ext):
|
||||||
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...')
|
msgbox(
|
||||||
|
f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if prefix == '' and postfix == '':
|
if prefix == '' and postfix == '':
|
||||||
@ -158,12 +182,13 @@ def add_pre_postfix(
|
|||||||
f.write(f'{prefix}{content}{postfix}')
|
f.write(f'{prefix}{content}{postfix}')
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
def find_replace(
|
|
||||||
folder='', caption_file_ext='.caption', find='', replace=''
|
def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
|
||||||
):
|
|
||||||
print('Running caption find/replace')
|
print('Running caption find/replace')
|
||||||
if not has_ext_files(folder, caption_file_ext):
|
if not has_ext_files(folder, caption_file_ext):
|
||||||
msgbox(f'No files with extension {caption_file_ext} were found in {folder}...')
|
msgbox(
|
||||||
|
f'No files with extension {caption_file_ext} were found in {folder}...'
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if find == '':
|
if find == '':
|
||||||
@ -179,13 +204,17 @@ def find_replace(
|
|||||||
f.write(content)
|
f.write(content)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
def color_aug_changed(color_aug):
|
def color_aug_changed(color_aug):
|
||||||
if color_aug:
|
if color_aug:
|
||||||
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
|
msgbox(
|
||||||
|
'Disabling "Cache latent" because "Color augmentation" has been selected...'
|
||||||
|
)
|
||||||
return gr.Checkbox.update(value=False, interactive=False)
|
return gr.Checkbox.update(value=False, interactive=False)
|
||||||
else:
|
else:
|
||||||
return gr.Checkbox.update(value=True, interactive=True)
|
return gr.Checkbox.update(value=True, interactive=True)
|
||||||
|
|
||||||
|
|
||||||
def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
||||||
# List all files in the directory
|
# List all files in the directory
|
||||||
files = os.listdir(output_dir)
|
files = os.listdir(output_dir)
|
||||||
@ -201,18 +230,23 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
|||||||
|
|
||||||
# Copy the v2-inference-v.yaml file to the current file, with a .yaml extension
|
# Copy the v2-inference-v.yaml file to the current file, with a .yaml extension
|
||||||
if v2 and v_parameterization:
|
if v2 and v_parameterization:
|
||||||
print(f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml')
|
print(
|
||||||
|
f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml'
|
||||||
|
)
|
||||||
shutil.copy(
|
shutil.copy(
|
||||||
f'./v2_inference/v2-inference-v.yaml',
|
f'./v2_inference/v2-inference-v.yaml',
|
||||||
f'{output_dir}/{file_name}.yaml',
|
f'{output_dir}/{file_name}.yaml',
|
||||||
)
|
)
|
||||||
elif v2:
|
elif v2:
|
||||||
print(f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml')
|
print(
|
||||||
|
f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml'
|
||||||
|
)
|
||||||
shutil.copy(
|
shutil.copy(
|
||||||
f'./v2_inference/v2-inference.yaml',
|
f'./v2_inference/v2-inference.yaml',
|
||||||
f'{output_dir}/{file_name}.yaml',
|
f'{output_dir}/{file_name}.yaml',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
||||||
# define a list of substrings to search for
|
# define a list of substrings to search for
|
||||||
substrings_v2 = [
|
substrings_v2 = [
|
||||||
@ -267,7 +301,25 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
|||||||
### Gradio common GUI section
|
### Gradio common GUI section
|
||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
def gradio_advanced_training():
|
def gradio_advanced_training():
|
||||||
|
with gr.Row():
|
||||||
|
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||||
|
resume = gr.Textbox(
|
||||||
|
label='Resume from saved training state',
|
||||||
|
placeholder='path to "last-state" state folder to resume from',
|
||||||
|
)
|
||||||
|
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||||
|
resume_button.click(get_folder_path, outputs=resume)
|
||||||
|
max_token_length = gr.Dropdown(
|
||||||
|
label='Max Token Length',
|
||||||
|
choices=[
|
||||||
|
'75',
|
||||||
|
'150',
|
||||||
|
'225',
|
||||||
|
],
|
||||||
|
value='75',
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
max_train_epochs = gr.Textbox(
|
max_train_epochs = gr.Textbox(
|
||||||
label='Max train epoch',
|
label='Max train epoch',
|
||||||
@ -277,15 +329,30 @@ def gradio_advanced_training():
|
|||||||
label='Max num workers for DataLoader',
|
label='Max num workers for DataLoader',
|
||||||
placeholder='(Optional) Override number of epoch. Default: 8',
|
placeholder='(Optional) Override number of epoch. Default: 8',
|
||||||
)
|
)
|
||||||
return max_train_epochs, max_data_loader_n_workers
|
return (
|
||||||
|
save_state,
|
||||||
|
resume,
|
||||||
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_cmd_advanced_training(**kwargs):
|
def run_cmd_advanced_training(**kwargs):
|
||||||
run_cmd = ''
|
options = [
|
||||||
max_train_epochs = kwargs.get('max_train_epochs', '')
|
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
|
||||||
max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '')
|
if kwargs.get('max_train_epochs')
|
||||||
if not max_train_epochs == '':
|
else '',
|
||||||
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
|
||||||
if not max_data_loader_n_workers == '':
|
if kwargs.get('max_data_loader_n_workers')
|
||||||
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
else '',
|
||||||
|
f' --max_token_length={kwargs.get("max_token_length", "")}'
|
||||||
|
if int(kwargs.get('max_token_length', 0)) > 75
|
||||||
|
else '',
|
||||||
|
f' --resume="{kwargs.get("resume", "")}"'
|
||||||
|
if kwargs.get('resume')
|
||||||
|
else '',
|
||||||
|
' --save_state' if kwargs.get('save_state') else '',
|
||||||
|
]
|
||||||
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
101
lora_gui.py
101
lora_gui.py
@ -19,7 +19,9 @@ from library.common_gui import (
|
|||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
save_inference_file,
|
save_inference_file,
|
||||||
set_pretrained_model_name_or_path_input, gradio_advanced_training,run_cmd_advanced_training,
|
set_pretrained_model_name_or_path_input,
|
||||||
|
gradio_advanced_training,
|
||||||
|
run_cmd_advanced_training,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -180,7 +182,7 @@ def open_configuration(
|
|||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data = json.load(f)
|
my_data = json.load(f)
|
||||||
print("Loading config...")
|
print('Loading config...')
|
||||||
else:
|
else:
|
||||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||||
my_data = {}
|
my_data = {}
|
||||||
@ -350,8 +352,8 @@ def train_model(
|
|||||||
run_cmd += ' --xformers'
|
run_cmd += ' --xformers'
|
||||||
if shuffle_caption:
|
if shuffle_caption:
|
||||||
run_cmd += ' --shuffle_caption'
|
run_cmd += ' --shuffle_caption'
|
||||||
if save_state:
|
# if save_state:
|
||||||
run_cmd += ' --save_state'
|
# run_cmd += ' --save_state'
|
||||||
if color_aug:
|
if color_aug:
|
||||||
run_cmd += ' --color_aug'
|
run_cmd += ' --color_aug'
|
||||||
if flip_aug:
|
if flip_aug:
|
||||||
@ -386,8 +388,8 @@ def train_model(
|
|||||||
)
|
)
|
||||||
if not save_model_as == 'same as source model':
|
if not save_model_as == 'same as source model':
|
||||||
run_cmd += f' --save_model_as={save_model_as}'
|
run_cmd += f' --save_model_as={save_model_as}'
|
||||||
if not resume == '':
|
# if not resume == '':
|
||||||
run_cmd += f' --resume="{resume}"'
|
# run_cmd += f' --resume="{resume}"'
|
||||||
if not float(prior_loss_weight) == 1.0:
|
if not float(prior_loss_weight) == 1.0:
|
||||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||||
run_cmd += f' --network_module=networks.lora'
|
run_cmd += f' --network_module=networks.lora'
|
||||||
@ -414,9 +416,15 @@ def train_model(
|
|||||||
# run_cmd += f' --vae="{vae}"'
|
# run_cmd += f' --vae="{vae}"'
|
||||||
if not output_name == '':
|
if not output_name == '':
|
||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
if (int(max_token_length) > 75):
|
# if (int(max_token_length) > 75):
|
||||||
run_cmd += f' --max_token_length={max_token_length}'
|
# run_cmd += f' --max_token_length={max_token_length}'
|
||||||
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
run_cmd += run_cmd_advanced_training(
|
||||||
|
max_train_epochs=max_train_epochs,
|
||||||
|
max_data_loader_n_workers=max_data_loader_n_workers,
|
||||||
|
max_token_length=max_token_length,
|
||||||
|
resume=resume,
|
||||||
|
save_state=save_state,
|
||||||
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -564,9 +572,7 @@ def lora_tab(
|
|||||||
label='Image folder',
|
label='Image folder',
|
||||||
placeholder='Folder where the training folders containing the images are located',
|
placeholder='Folder where the training folders containing the images are located',
|
||||||
)
|
)
|
||||||
train_data_dir_folder = gr.Button(
|
train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||||
'📂', elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
train_data_dir_folder.click(
|
train_data_dir_folder.click(
|
||||||
get_folder_path, outputs=train_data_dir
|
get_folder_path, outputs=train_data_dir
|
||||||
)
|
)
|
||||||
@ -574,33 +580,21 @@ def lora_tab(
|
|||||||
label='Regularisation folder',
|
label='Regularisation folder',
|
||||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||||
)
|
)
|
||||||
reg_data_dir_folder = gr.Button(
|
reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||||
'📂', elem_id='open_folder_small'
|
reg_data_dir_folder.click(get_folder_path, outputs=reg_data_dir)
|
||||||
)
|
|
||||||
reg_data_dir_folder.click(
|
|
||||||
get_folder_path, outputs=reg_data_dir
|
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_dir = gr.Textbox(
|
output_dir = gr.Textbox(
|
||||||
label='Output folder',
|
label='Output folder',
|
||||||
placeholder='Folder to output trained model',
|
placeholder='Folder to output trained model',
|
||||||
)
|
)
|
||||||
output_dir_folder = gr.Button(
|
output_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||||
'📂', elem_id='open_folder_small'
|
output_dir_folder.click(get_folder_path, outputs=output_dir)
|
||||||
)
|
|
||||||
output_dir_folder.click(
|
|
||||||
get_folder_path, outputs=output_dir
|
|
||||||
)
|
|
||||||
logging_dir = gr.Textbox(
|
logging_dir = gr.Textbox(
|
||||||
label='Logging folder',
|
label='Logging folder',
|
||||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||||
)
|
)
|
||||||
logging_dir_folder = gr.Button(
|
logging_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||||
'📂', elem_id='open_folder_small'
|
logging_dir_folder.click(get_folder_path, outputs=logging_dir)
|
||||||
)
|
|
||||||
logging_dir_folder.click(
|
|
||||||
get_folder_path, outputs=logging_dir
|
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_name = gr.Textbox(
|
output_name = gr.Textbox(
|
||||||
label='Model output name',
|
label='Model output name',
|
||||||
@ -659,11 +653,13 @@ def lora_tab(
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_encoder_lr = gr.Textbox(
|
text_encoder_lr = gr.Textbox(
|
||||||
label='Text Encoder learning rate',
|
label='Text Encoder learning rate',
|
||||||
value="5e-5",
|
value='5e-5',
|
||||||
placeholder='Optional',
|
placeholder='Optional',
|
||||||
)
|
)
|
||||||
unet_lr = gr.Textbox(
|
unet_lr = gr.Textbox(
|
||||||
label='Unet learning rate', value="1e-3", placeholder='Optional'
|
label='Unet learning rate',
|
||||||
|
value='1e-3',
|
||||||
|
placeholder='Optional',
|
||||||
)
|
)
|
||||||
network_dim = gr.Slider(
|
network_dim = gr.Slider(
|
||||||
minimum=1,
|
minimum=1,
|
||||||
@ -731,13 +727,9 @@ def lora_tab(
|
|||||||
label='Stop text encoder training',
|
label='Stop text encoder training',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_bucket = gr.Checkbox(
|
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
||||||
label='Enable buckets', value=True
|
|
||||||
)
|
|
||||||
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||||
use_8bit_adam = gr.Checkbox(
|
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||||
label='Use 8bit adam', value=True
|
|
||||||
)
|
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
with gr.Accordion('Advanced Configuration', open=False):
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -777,32 +769,13 @@ def lora_tab(
|
|||||||
mem_eff_attn = gr.Checkbox(
|
mem_eff_attn = gr.Checkbox(
|
||||||
label='Memory efficient attention', value=False
|
label='Memory efficient attention', value=False
|
||||||
)
|
)
|
||||||
with gr.Row():
|
(
|
||||||
save_state = gr.Checkbox(
|
save_state,
|
||||||
label='Save training state', value=False
|
resume,
|
||||||
)
|
max_token_length,
|
||||||
resume = gr.Textbox(
|
max_train_epochs,
|
||||||
label='Resume from saved training state',
|
max_data_loader_n_workers,
|
||||||
placeholder='path to "last-state" state folder to resume from',
|
) = gradio_advanced_training()
|
||||||
)
|
|
||||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
|
||||||
resume_button.click(get_folder_path, outputs=resume)
|
|
||||||
# vae = gr.Textbox(
|
|
||||||
# label='VAE',
|
|
||||||
# placeholder='(Optiona) path to checkpoint of vae to replace for training',
|
|
||||||
# )
|
|
||||||
# vae_button = gr.Button('📂', elem_id='open_folder_small')
|
|
||||||
# vae_button.click(get_any_file_path, outputs=vae)
|
|
||||||
max_token_length = gr.Dropdown(
|
|
||||||
label='Max Token Length',
|
|
||||||
choices=[
|
|
||||||
'75',
|
|
||||||
'150',
|
|
||||||
'225',
|
|
||||||
],
|
|
||||||
value='75',
|
|
||||||
)
|
|
||||||
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
|
||||||
|
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
|
Loading…
Reference in New Issue
Block a user