refactor Dreambooth gui code
This commit is contained in:
parent
442eb7a292
commit
402cb51ec0
@ -536,10 +536,10 @@ def UI(username, password):
|
|||||||
|
|
||||||
|
|
||||||
def dreambooth_tab(
|
def dreambooth_tab(
|
||||||
train_data_dir_input=gr.Textbox(),
|
train_data_dir=gr.Textbox(),
|
||||||
reg_data_dir_input=gr.Textbox(),
|
reg_data_dir=gr.Textbox(),
|
||||||
output_dir_input=gr.Textbox(),
|
output_dir=gr.Textbox(),
|
||||||
logging_dir_input=gr.Textbox(),
|
logging_dir=gr.Textbox(),
|
||||||
):
|
):
|
||||||
dummy_db_true = gr.Label(value=True, visible=False)
|
dummy_db_true = gr.Label(value=True, visible=False)
|
||||||
dummy_db_false = gr.Label(value=False, visible=False)
|
dummy_db_false = gr.Label(value=False, visible=False)
|
||||||
@ -564,24 +564,24 @@ def dreambooth_tab(
|
|||||||
with gr.Tab('Source model'):
|
with gr.Tab('Source model'):
|
||||||
# Define the input elements
|
# Define the input elements
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
pretrained_model_name_or_path_input = gr.Textbox(
|
pretrained_model_name_or_path = gr.Textbox(
|
||||||
label='Pretrained model name or path',
|
label='Pretrained model name or path',
|
||||||
placeholder='enter the path to custom model or name of pretrained model',
|
placeholder='enter the path to custom model or name of pretrained model',
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_fille = gr.Button(
|
pretrained_model_name_or_path_file = gr.Button(
|
||||||
document_symbol, elem_id='open_folder_small'
|
document_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_fille.click(
|
pretrained_model_name_or_path_file.click(
|
||||||
get_any_file_path,
|
get_any_file_path,
|
||||||
inputs=[pretrained_model_name_or_path_input],
|
inputs=[pretrained_model_name_or_path],
|
||||||
outputs=pretrained_model_name_or_path_input,
|
outputs=pretrained_model_name_or_path,
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_folder = gr.Button(
|
pretrained_model_name_or_path_folder = gr.Button(
|
||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_folder.click(
|
pretrained_model_name_or_path_folder.click(
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
outputs=pretrained_model_name_or_path_input,
|
outputs=pretrained_model_name_or_path,
|
||||||
)
|
)
|
||||||
model_list = gr.Dropdown(
|
model_list = gr.Dropdown(
|
||||||
label='(Optional) Model Quick Pick',
|
label='(Optional) Model Quick Pick',
|
||||||
@ -595,7 +595,7 @@ def dreambooth_tab(
|
|||||||
'CompVis/stable-diffusion-v1-4',
|
'CompVis/stable-diffusion-v1-4',
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
save_model_as_dropdown = gr.Dropdown(
|
save_model_as = gr.Dropdown(
|
||||||
label='Save trained model as',
|
label='Save trained model as',
|
||||||
choices=[
|
choices=[
|
||||||
'same as source model',
|
'same as source model',
|
||||||
@ -607,28 +607,28 @@ def dreambooth_tab(
|
|||||||
value='same as source model',
|
value='same as source model',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v2_input = gr.Checkbox(label='v2', value=True)
|
v2 = gr.Checkbox(label='v2', value=True)
|
||||||
v_parameterization_input = gr.Checkbox(
|
v_parameterization = gr.Checkbox(
|
||||||
label='v_parameterization', value=False
|
label='v_parameterization', value=False
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_input.change(
|
pretrained_model_name_or_path.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[pretrained_model_name_or_path_input],
|
inputs=[pretrained_model_name_or_path],
|
||||||
outputs=[pretrained_model_name_or_path_input],
|
outputs=[pretrained_model_name_or_path],
|
||||||
)
|
)
|
||||||
model_list.change(
|
model_list.change(
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input,
|
||||||
inputs=[model_list, v2_input, v_parameterization_input],
|
inputs=[model_list, v2, v_parameterization],
|
||||||
outputs=[
|
outputs=[
|
||||||
pretrained_model_name_or_path_input,
|
pretrained_model_name_or_path,
|
||||||
v2_input,
|
v2,
|
||||||
v_parameterization_input,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Tab('Folders'):
|
with gr.Tab('Folders'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_data_dir_input = gr.Textbox(
|
train_data_dir = gr.Textbox(
|
||||||
label='Image folder',
|
label='Image folder',
|
||||||
placeholder='Folder where the training folders containing the images are located',
|
placeholder='Folder where the training folders containing the images are located',
|
||||||
)
|
)
|
||||||
@ -636,9 +636,9 @@ def dreambooth_tab(
|
|||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
train_data_dir_input_folder.click(
|
train_data_dir_input_folder.click(
|
||||||
get_folder_path, outputs=train_data_dir_input
|
get_folder_path, outputs=train_data_dir
|
||||||
)
|
)
|
||||||
reg_data_dir_input = gr.Textbox(
|
reg_data_dir = gr.Textbox(
|
||||||
label='Regularisation folder',
|
label='Regularisation folder',
|
||||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||||
)
|
)
|
||||||
@ -646,20 +646,20 @@ def dreambooth_tab(
|
|||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
reg_data_dir_input_folder.click(
|
reg_data_dir_input_folder.click(
|
||||||
get_folder_path, outputs=reg_data_dir_input
|
get_folder_path, outputs=reg_data_dir
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_dir_input = gr.Textbox(
|
output_dir = gr.Textbox(
|
||||||
label='Output folder',
|
label='Model output folder',
|
||||||
placeholder='Folder to output trained model',
|
placeholder='Folder to output trained model',
|
||||||
)
|
)
|
||||||
output_dir_input_folder = gr.Button(
|
output_dir_input_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
output_dir_input_folder.click(
|
output_dir_input_folder.click(
|
||||||
get_folder_path, outputs=output_dir_input
|
get_folder_path, outputs=output_dir
|
||||||
)
|
)
|
||||||
logging_dir_input = gr.Textbox(
|
logging_dir = gr.Textbox(
|
||||||
label='Logging folder',
|
label='Logging folder',
|
||||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||||
)
|
)
|
||||||
@ -667,32 +667,32 @@ def dreambooth_tab(
|
|||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
logging_dir_input_folder.click(
|
logging_dir_input_folder.click(
|
||||||
get_folder_path, outputs=logging_dir_input
|
get_folder_path, outputs=logging_dir
|
||||||
)
|
)
|
||||||
train_data_dir_input.change(
|
train_data_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[train_data_dir_input],
|
inputs=[train_data_dir],
|
||||||
outputs=[train_data_dir_input],
|
outputs=[train_data_dir],
|
||||||
)
|
)
|
||||||
reg_data_dir_input.change(
|
reg_data_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[reg_data_dir_input],
|
inputs=[reg_data_dir],
|
||||||
outputs=[reg_data_dir_input],
|
outputs=[reg_data_dir],
|
||||||
)
|
)
|
||||||
output_dir_input.change(
|
output_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[output_dir_input],
|
inputs=[output_dir],
|
||||||
outputs=[output_dir_input],
|
outputs=[output_dir],
|
||||||
)
|
)
|
||||||
logging_dir_input.change(
|
logging_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[logging_dir_input],
|
inputs=[logging_dir],
|
||||||
outputs=[logging_dir_input],
|
outputs=[logging_dir],
|
||||||
)
|
)
|
||||||
with gr.Tab('Training parameters'):
|
with gr.Tab('Training parameters'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
learning_rate = gr.Textbox(label='Learning rate', value=1e-6)
|
||||||
lr_scheduler_input = gr.Dropdown(
|
lr_scheduler = gr.Dropdown(
|
||||||
label='LR Scheduler',
|
label='LR Scheduler',
|
||||||
choices=[
|
choices=[
|
||||||
'constant',
|
'constant',
|
||||||
@ -704,21 +704,21 @@ def dreambooth_tab(
|
|||||||
],
|
],
|
||||||
value='constant',
|
value='constant',
|
||||||
)
|
)
|
||||||
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
lr_warmup = gr.Textbox(label='LR warmup', value=0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_batch_size_input = gr.Slider(
|
train_batch_size = gr.Slider(
|
||||||
minimum=1,
|
minimum=1,
|
||||||
maximum=32,
|
maximum=32,
|
||||||
label='Train batch size',
|
label='Train batch size',
|
||||||
value=1,
|
value=1,
|
||||||
step=1,
|
step=1,
|
||||||
)
|
)
|
||||||
epoch_input = gr.Textbox(label='Epoch', value=1)
|
epoch = gr.Textbox(label='Epoch', value=1)
|
||||||
save_every_n_epochs_input = gr.Textbox(
|
save_every_n_epochs = gr.Textbox(
|
||||||
label='Save every N epochs', value=1
|
label='Save every N epochs', value=1
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
mixed_precision_input = gr.Dropdown(
|
mixed_precision = gr.Dropdown(
|
||||||
label='Mixed precision',
|
label='Mixed precision',
|
||||||
choices=[
|
choices=[
|
||||||
'no',
|
'no',
|
||||||
@ -727,7 +727,7 @@ def dreambooth_tab(
|
|||||||
],
|
],
|
||||||
value='fp16',
|
value='fp16',
|
||||||
)
|
)
|
||||||
save_precision_input = gr.Dropdown(
|
save_precision = gr.Dropdown(
|
||||||
label='Save precision',
|
label='Save precision',
|
||||||
choices=[
|
choices=[
|
||||||
'float',
|
'float',
|
||||||
@ -736,7 +736,7 @@ def dreambooth_tab(
|
|||||||
],
|
],
|
||||||
value='fp16',
|
value='fp16',
|
||||||
)
|
)
|
||||||
num_cpu_threads_per_process_input = gr.Slider(
|
num_cpu_threads_per_process = gr.Slider(
|
||||||
minimum=1,
|
minimum=1,
|
||||||
maximum=os.cpu_count(),
|
maximum=os.cpu_count(),
|
||||||
step=1,
|
step=1,
|
||||||
@ -744,18 +744,18 @@ def dreambooth_tab(
|
|||||||
value=os.cpu_count(),
|
value=os.cpu_count(),
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
seed = gr.Textbox(label='Seed', value=1234)
|
||||||
max_resolution_input = gr.Textbox(
|
max_resolution = gr.Textbox(
|
||||||
label='Max resolution',
|
label='Max resolution',
|
||||||
value='512,512',
|
value='512,512',
|
||||||
placeholder='512,512',
|
placeholder='512,512',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
caption_extention_input = gr.Textbox(
|
caption_extention = gr.Textbox(
|
||||||
label='Caption Extension',
|
label='Caption Extension',
|
||||||
placeholder='(Optional) Extension for caption files. default: .caption',
|
placeholder='(Optional) Extension for caption files. default: .caption',
|
||||||
)
|
)
|
||||||
stop_text_encoder_training_input = gr.Slider(
|
stop_text_encoder_training = gr.Slider(
|
||||||
minimum=0,
|
minimum=0,
|
||||||
maximum=100,
|
maximum=100,
|
||||||
value=0,
|
value=0,
|
||||||
@ -763,24 +763,24 @@ def dreambooth_tab(
|
|||||||
label='Stop text encoder training',
|
label='Stop text encoder training',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_bucket_input = gr.Checkbox(
|
enable_bucket = gr.Checkbox(
|
||||||
label='Enable buckets', value=True
|
label='Enable buckets', value=True
|
||||||
)
|
)
|
||||||
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
|
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||||
use_8bit_adam_input = gr.Checkbox(
|
use_8bit_adam = gr.Checkbox(
|
||||||
label='Use 8bit adam', value=True
|
label='Use 8bit adam', value=True
|
||||||
)
|
)
|
||||||
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
with gr.Accordion('Advanced Configuration', open=False):
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
full_fp16_input = gr.Checkbox(
|
full_fp16 = gr.Checkbox(
|
||||||
label='Full fp16 training (experimental)', value=False
|
label='Full fp16 training (experimental)', value=False
|
||||||
)
|
)
|
||||||
no_token_padding_input = gr.Checkbox(
|
no_token_padding = gr.Checkbox(
|
||||||
label='No token padding', value=False
|
label='No token padding', value=False
|
||||||
)
|
)
|
||||||
|
|
||||||
gradient_checkpointing_input = gr.Checkbox(
|
gradient_checkpointing = gr.Checkbox(
|
||||||
label='Gradient checkpointing', value=False
|
label='Gradient checkpointing', value=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -798,7 +798,7 @@ def dreambooth_tab(
|
|||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
inputs=[color_aug],
|
inputs=[color_aug],
|
||||||
outputs=[cache_latent_input],
|
outputs=[cache_latent],
|
||||||
)
|
)
|
||||||
clip_skip = gr.Slider(
|
clip_skip = gr.Slider(
|
||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||||
@ -824,43 +824,43 @@ def dreambooth_tab(
|
|||||||
'This section provide Dreambooth tools to help setup your dataset...'
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
)
|
)
|
||||||
gradio_dreambooth_folder_creation_tab(
|
gradio_dreambooth_folder_creation_tab(
|
||||||
train_data_dir_input=train_data_dir_input,
|
train_data_dir_input=train_data_dir,
|
||||||
reg_data_dir_input=reg_data_dir_input,
|
reg_data_dir_input=reg_data_dir,
|
||||||
output_dir_input=output_dir_input,
|
output_dir_input=output_dir,
|
||||||
logging_dir_input=logging_dir_input,
|
logging_dir_input=logging_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
settings_list = [
|
settings_list = [
|
||||||
pretrained_model_name_or_path_input,
|
pretrained_model_name_or_path,
|
||||||
v2_input,
|
v2,
|
||||||
v_parameterization_input,
|
v_parameterization,
|
||||||
logging_dir_input,
|
logging_dir,
|
||||||
train_data_dir_input,
|
train_data_dir,
|
||||||
reg_data_dir_input,
|
reg_data_dir,
|
||||||
output_dir_input,
|
output_dir,
|
||||||
max_resolution_input,
|
max_resolution,
|
||||||
learning_rate_input,
|
learning_rate,
|
||||||
lr_scheduler_input,
|
lr_scheduler,
|
||||||
lr_warmup_input,
|
lr_warmup,
|
||||||
train_batch_size_input,
|
train_batch_size,
|
||||||
epoch_input,
|
epoch,
|
||||||
save_every_n_epochs_input,
|
save_every_n_epochs,
|
||||||
mixed_precision_input,
|
mixed_precision,
|
||||||
save_precision_input,
|
save_precision,
|
||||||
seed_input,
|
seed,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process,
|
||||||
cache_latent_input,
|
cache_latent,
|
||||||
caption_extention_input,
|
caption_extention,
|
||||||
enable_bucket_input,
|
enable_bucket,
|
||||||
gradient_checkpointing_input,
|
gradient_checkpointing,
|
||||||
full_fp16_input,
|
full_fp16,
|
||||||
no_token_padding_input,
|
no_token_padding,
|
||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam,
|
||||||
xformers_input,
|
xformers,
|
||||||
save_model_as_dropdown,
|
save_model_as,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
@ -895,10 +895,10 @@ def dreambooth_tab(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
train_data_dir_input,
|
train_data_dir,
|
||||||
reg_data_dir_input,
|
reg_data_dir,
|
||||||
output_dir_input,
|
output_dir,
|
||||||
logging_dir_input,
|
logging_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user