refactor Dreambooth gui code
This commit is contained in:
parent
442eb7a292
commit
402cb51ec0
@ -536,10 +536,10 @@ def UI(username, password):
|
||||
|
||||
|
||||
def dreambooth_tab(
|
||||
train_data_dir_input=gr.Textbox(),
|
||||
reg_data_dir_input=gr.Textbox(),
|
||||
output_dir_input=gr.Textbox(),
|
||||
logging_dir_input=gr.Textbox(),
|
||||
train_data_dir=gr.Textbox(),
|
||||
reg_data_dir=gr.Textbox(),
|
||||
output_dir=gr.Textbox(),
|
||||
logging_dir=gr.Textbox(),
|
||||
):
|
||||
dummy_db_true = gr.Label(value=True, visible=False)
|
||||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
@ -564,24 +564,24 @@ def dreambooth_tab(
|
||||
with gr.Tab('Source model'):
|
||||
# Define the input elements
|
||||
with gr.Row():
|
||||
pretrained_model_name_or_path_input = gr.Textbox(
|
||||
pretrained_model_name_or_path = gr.Textbox(
|
||||
label='Pretrained model name or path',
|
||||
placeholder='enter the path to custom model or name of pretrained model',
|
||||
)
|
||||
pretrained_model_name_or_path_fille = gr.Button(
|
||||
pretrained_model_name_or_path_file = gr.Button(
|
||||
document_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
pretrained_model_name_or_path_fille.click(
|
||||
pretrained_model_name_or_path_file.click(
|
||||
get_any_file_path,
|
||||
inputs=[pretrained_model_name_or_path_input],
|
||||
outputs=pretrained_model_name_or_path_input,
|
||||
inputs=[pretrained_model_name_or_path],
|
||||
outputs=pretrained_model_name_or_path,
|
||||
)
|
||||
pretrained_model_name_or_path_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
pretrained_model_name_or_path_folder.click(
|
||||
get_folder_path,
|
||||
outputs=pretrained_model_name_or_path_input,
|
||||
outputs=pretrained_model_name_or_path,
|
||||
)
|
||||
model_list = gr.Dropdown(
|
||||
label='(Optional) Model Quick Pick',
|
||||
@ -595,7 +595,7 @@ def dreambooth_tab(
|
||||
'CompVis/stable-diffusion-v1-4',
|
||||
],
|
||||
)
|
||||
save_model_as_dropdown = gr.Dropdown(
|
||||
save_model_as = gr.Dropdown(
|
||||
label='Save trained model as',
|
||||
choices=[
|
||||
'same as source model',
|
||||
@ -607,28 +607,28 @@ def dreambooth_tab(
|
||||
value='same as source model',
|
||||
)
|
||||
with gr.Row():
|
||||
v2_input = gr.Checkbox(label='v2', value=True)
|
||||
v_parameterization_input = gr.Checkbox(
|
||||
v2 = gr.Checkbox(label='v2', value=True)
|
||||
v_parameterization = gr.Checkbox(
|
||||
label='v_parameterization', value=False
|
||||
)
|
||||
pretrained_model_name_or_path_input.change(
|
||||
pretrained_model_name_or_path.change(
|
||||
remove_doublequote,
|
||||
inputs=[pretrained_model_name_or_path_input],
|
||||
outputs=[pretrained_model_name_or_path_input],
|
||||
inputs=[pretrained_model_name_or_path],
|
||||
outputs=[pretrained_model_name_or_path],
|
||||
)
|
||||
model_list.change(
|
||||
set_pretrained_model_name_or_path_input,
|
||||
inputs=[model_list, v2_input, v_parameterization_input],
|
||||
inputs=[model_list, v2, v_parameterization],
|
||||
outputs=[
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
],
|
||||
)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
with gr.Row():
|
||||
train_data_dir_input = gr.Textbox(
|
||||
train_data_dir = gr.Textbox(
|
||||
label='Image folder',
|
||||
placeholder='Folder where the training folders containing the images are located',
|
||||
)
|
||||
@ -636,9 +636,9 @@ def dreambooth_tab(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
train_data_dir_input_folder.click(
|
||||
get_folder_path, outputs=train_data_dir_input
|
||||
get_folder_path, outputs=train_data_dir
|
||||
)
|
||||
reg_data_dir_input = gr.Textbox(
|
||||
reg_data_dir = gr.Textbox(
|
||||
label='Regularisation folder',
|
||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||
)
|
||||
@ -646,20 +646,20 @@ def dreambooth_tab(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
reg_data_dir_input_folder.click(
|
||||
get_folder_path, outputs=reg_data_dir_input
|
||||
get_folder_path, outputs=reg_data_dir
|
||||
)
|
||||
with gr.Row():
|
||||
output_dir_input = gr.Textbox(
|
||||
label='Output folder',
|
||||
output_dir = gr.Textbox(
|
||||
label='Model output folder',
|
||||
placeholder='Folder to output trained model',
|
||||
)
|
||||
output_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
output_dir_input_folder.click(
|
||||
get_folder_path, outputs=output_dir_input
|
||||
get_folder_path, outputs=output_dir
|
||||
)
|
||||
logging_dir_input = gr.Textbox(
|
||||
logging_dir = gr.Textbox(
|
||||
label='Logging folder',
|
||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
)
|
||||
@ -667,32 +667,32 @@ def dreambooth_tab(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
logging_dir_input_folder.click(
|
||||
get_folder_path, outputs=logging_dir_input
|
||||
get_folder_path, outputs=logging_dir
|
||||
)
|
||||
train_data_dir_input.change(
|
||||
train_data_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[train_data_dir_input],
|
||||
outputs=[train_data_dir_input],
|
||||
inputs=[train_data_dir],
|
||||
outputs=[train_data_dir],
|
||||
)
|
||||
reg_data_dir_input.change(
|
||||
reg_data_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[reg_data_dir_input],
|
||||
outputs=[reg_data_dir_input],
|
||||
inputs=[reg_data_dir],
|
||||
outputs=[reg_data_dir],
|
||||
)
|
||||
output_dir_input.change(
|
||||
output_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[output_dir_input],
|
||||
outputs=[output_dir_input],
|
||||
inputs=[output_dir],
|
||||
outputs=[output_dir],
|
||||
)
|
||||
logging_dir_input.change(
|
||||
logging_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[logging_dir_input],
|
||||
outputs=[logging_dir_input],
|
||||
inputs=[logging_dir],
|
||||
outputs=[logging_dir],
|
||||
)
|
||||
with gr.Tab('Training parameters'):
|
||||
with gr.Row():
|
||||
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
||||
lr_scheduler_input = gr.Dropdown(
|
||||
learning_rate = gr.Textbox(label='Learning rate', value=1e-6)
|
||||
lr_scheduler = gr.Dropdown(
|
||||
label='LR Scheduler',
|
||||
choices=[
|
||||
'constant',
|
||||
@ -704,21 +704,21 @@ def dreambooth_tab(
|
||||
],
|
||||
value='constant',
|
||||
)
|
||||
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
||||
lr_warmup = gr.Textbox(label='LR warmup', value=0)
|
||||
with gr.Row():
|
||||
train_batch_size_input = gr.Slider(
|
||||
train_batch_size = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=32,
|
||||
label='Train batch size',
|
||||
value=1,
|
||||
step=1,
|
||||
)
|
||||
epoch_input = gr.Textbox(label='Epoch', value=1)
|
||||
save_every_n_epochs_input = gr.Textbox(
|
||||
epoch = gr.Textbox(label='Epoch', value=1)
|
||||
save_every_n_epochs = gr.Textbox(
|
||||
label='Save every N epochs', value=1
|
||||
)
|
||||
with gr.Row():
|
||||
mixed_precision_input = gr.Dropdown(
|
||||
mixed_precision = gr.Dropdown(
|
||||
label='Mixed precision',
|
||||
choices=[
|
||||
'no',
|
||||
@ -727,7 +727,7 @@ def dreambooth_tab(
|
||||
],
|
||||
value='fp16',
|
||||
)
|
||||
save_precision_input = gr.Dropdown(
|
||||
save_precision = gr.Dropdown(
|
||||
label='Save precision',
|
||||
choices=[
|
||||
'float',
|
||||
@ -736,7 +736,7 @@ def dreambooth_tab(
|
||||
],
|
||||
value='fp16',
|
||||
)
|
||||
num_cpu_threads_per_process_input = gr.Slider(
|
||||
num_cpu_threads_per_process = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=os.cpu_count(),
|
||||
step=1,
|
||||
@ -744,18 +744,18 @@ def dreambooth_tab(
|
||||
value=os.cpu_count(),
|
||||
)
|
||||
with gr.Row():
|
||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
||||
max_resolution_input = gr.Textbox(
|
||||
seed = gr.Textbox(label='Seed', value=1234)
|
||||
max_resolution = gr.Textbox(
|
||||
label='Max resolution',
|
||||
value='512,512',
|
||||
placeholder='512,512',
|
||||
)
|
||||
with gr.Row():
|
||||
caption_extention_input = gr.Textbox(
|
||||
caption_extention = gr.Textbox(
|
||||
label='Caption Extension',
|
||||
placeholder='(Optional) Extension for caption files. default: .caption',
|
||||
)
|
||||
stop_text_encoder_training_input = gr.Slider(
|
||||
stop_text_encoder_training = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=100,
|
||||
value=0,
|
||||
@ -763,24 +763,24 @@ def dreambooth_tab(
|
||||
label='Stop text encoder training',
|
||||
)
|
||||
with gr.Row():
|
||||
enable_bucket_input = gr.Checkbox(
|
||||
enable_bucket = gr.Checkbox(
|
||||
label='Enable buckets', value=True
|
||||
)
|
||||
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
|
||||
use_8bit_adam_input = gr.Checkbox(
|
||||
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||
use_8bit_adam = gr.Checkbox(
|
||||
label='Use 8bit adam', value=True
|
||||
)
|
||||
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||
with gr.Accordion('Advanced Configuration', open=False):
|
||||
with gr.Row():
|
||||
full_fp16_input = gr.Checkbox(
|
||||
full_fp16 = gr.Checkbox(
|
||||
label='Full fp16 training (experimental)', value=False
|
||||
)
|
||||
no_token_padding_input = gr.Checkbox(
|
||||
no_token_padding = gr.Checkbox(
|
||||
label='No token padding', value=False
|
||||
)
|
||||
|
||||
gradient_checkpointing_input = gr.Checkbox(
|
||||
gradient_checkpointing = gr.Checkbox(
|
||||
label='Gradient checkpointing', value=False
|
||||
)
|
||||
|
||||
@ -798,7 +798,7 @@ def dreambooth_tab(
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
outputs=[cache_latent_input],
|
||||
outputs=[cache_latent],
|
||||
)
|
||||
clip_skip = gr.Slider(
|
||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||
@ -824,43 +824,43 @@ def dreambooth_tab(
|
||||
'This section provide Dreambooth tools to help setup your dataset...'
|
||||
)
|
||||
gradio_dreambooth_folder_creation_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
train_data_dir_input=train_data_dir,
|
||||
reg_data_dir_input=reg_data_dir,
|
||||
output_dir_input=output_dir,
|
||||
logging_dir_input=logging_dir,
|
||||
)
|
||||
|
||||
button_run = gr.Button('Train model')
|
||||
|
||||
settings_list = [
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
logging_dir_input,
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
max_resolution_input,
|
||||
learning_rate_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
cache_latent_input,
|
||||
caption_extention_input,
|
||||
enable_bucket_input,
|
||||
gradient_checkpointing_input,
|
||||
full_fp16_input,
|
||||
no_token_padding_input,
|
||||
stop_text_encoder_training_input,
|
||||
use_8bit_adam_input,
|
||||
xformers_input,
|
||||
save_model_as_dropdown,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
max_resolution,
|
||||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
mixed_precision,
|
||||
save_precision,
|
||||
seed,
|
||||
num_cpu_threads_per_process,
|
||||
cache_latent,
|
||||
caption_extention,
|
||||
enable_bucket,
|
||||
gradient_checkpointing,
|
||||
full_fp16,
|
||||
no_token_padding,
|
||||
stop_text_encoder_training,
|
||||
use_8bit_adam,
|
||||
xformers,
|
||||
save_model_as,
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
@ -895,10 +895,10 @@ def dreambooth_tab(
|
||||
)
|
||||
|
||||
return (
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
logging_dir,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user