Improve GUI

This commit is contained in:
bmaltais 2023-01-01 00:33:29 -05:00
parent 6b41f4f951
commit 2c069d1067
2 changed files with 37 additions and 32 deletions

View File

@ -26,10 +26,10 @@ def UI(username, password):
output_dir_input, output_dir_input,
logging_dir_input, logging_dir_input,
) = dreambooth_tab() ) = dreambooth_tab()
with gr.Tab('Dreambooth LoRA'):
lora_tab()
with gr.Tab('Finetune'): with gr.Tab('Finetune'):
finetune_tab() finetune_tab()
with gr.Tab('LoRA'):
lora_tab()
with gr.Tab('Utilities'): with gr.Tab('Utilities'):
utilities_tab( utilities_tab(
train_data_dir_input=train_data_dir_input, train_data_dir_input=train_data_dir_input,

View File

@ -42,7 +42,6 @@ def save_configuration(
reg_data_dir, reg_data_dir,
output_dir, output_dir,
max_resolution, max_resolution,
learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
train_batch_size, train_batch_size,
@ -65,7 +64,7 @@ def save_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim prior_loss_weight, text_encoder_lr, unet_lr, network_dim
): ):
original_file_path = file_path original_file_path = file_path
@ -94,7 +93,6 @@ def save_configuration(
'reg_data_dir': reg_data_dir, 'reg_data_dir': reg_data_dir,
'output_dir': output_dir, 'output_dir': output_dir,
'max_resolution': max_resolution, 'max_resolution': max_resolution,
'learning_rate': learning_rate,
'lr_scheduler': lr_scheduler, 'lr_scheduler': lr_scheduler,
'lr_warmup': lr_warmup, 'lr_warmup': lr_warmup,
'train_batch_size': train_batch_size, 'train_batch_size': train_batch_size,
@ -120,7 +118,6 @@ def save_configuration(
'prior_loss_weight': prior_loss_weight, 'prior_loss_weight': prior_loss_weight,
'text_encoder_lr': text_encoder_lr, 'text_encoder_lr': text_encoder_lr,
'unet_lr': unet_lr, 'unet_lr': unet_lr,
'network_train': network_train,
'network_dim': network_dim 'network_dim': network_dim
} }
@ -141,7 +138,6 @@ def open_configuration(
reg_data_dir, reg_data_dir,
output_dir, output_dir,
max_resolution, max_resolution,
learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
train_batch_size, train_batch_size,
@ -164,7 +160,7 @@ def open_configuration(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim prior_loss_weight, text_encoder_lr, unet_lr, network_dim
): ):
original_file_path = file_path original_file_path = file_path
@ -192,7 +188,6 @@ def open_configuration(
my_data.get('reg_data_dir', reg_data_dir), my_data.get('reg_data_dir', reg_data_dir),
my_data.get('output_dir', output_dir), my_data.get('output_dir', output_dir),
my_data.get('max_resolution', max_resolution), my_data.get('max_resolution', max_resolution),
my_data.get('learning_rate', learning_rate),
my_data.get('lr_scheduler', lr_scheduler), my_data.get('lr_scheduler', lr_scheduler),
my_data.get('lr_warmup', lr_warmup), my_data.get('lr_warmup', lr_warmup),
my_data.get('train_batch_size', train_batch_size), my_data.get('train_batch_size', train_batch_size),
@ -220,7 +215,6 @@ def open_configuration(
my_data.get('prior_loss_weight', prior_loss_weight), my_data.get('prior_loss_weight', prior_loss_weight),
my_data.get('text_encoder_lr', text_encoder_lr), my_data.get('text_encoder_lr', text_encoder_lr),
my_data.get('unet_lr', unet_lr), my_data.get('unet_lr', unet_lr),
my_data.get('network_train', network_train),
my_data.get('network_dim', network_dim), my_data.get('network_dim', network_dim),
) )
@ -234,7 +228,6 @@ def train_model(
reg_data_dir, reg_data_dir,
output_dir, output_dir,
max_resolution, max_resolution,
learning_rate,
lr_scheduler, lr_scheduler,
lr_warmup, lr_warmup,
train_batch_size, train_batch_size,
@ -257,7 +250,7 @@ def train_model(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim prior_loss_weight, text_encoder_lr, unet_lr, network_dim
): ):
def save_inference_file(output_dir, v2, v_parameterization): def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required # Copy inference model for v2 if required
@ -294,6 +287,14 @@ def train_model(
if output_dir == '': if output_dir == '':
msgbox('Output folder path is missing') msgbox('Output folder path is missing')
return return
# If string is empty set string to 0.
if text_encoder_lr == '': text_encoder_lr = 0
if unet_lr == '': unet_lr = 0
if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
msgbox('At least one Learning Rate value for "Text encoder" or "Unet" need to be provided')
return
# Get a list of all subfolders in train_data_dir # Get a list of all subfolders in train_data_dir
subfolders = [ subfolders = [
@ -394,7 +395,7 @@ def train_model(
run_cmd += f' --resolution={max_resolution}' run_cmd += f' --resolution={max_resolution}'
run_cmd += f' --output_dir={output_dir}' run_cmd += f' --output_dir={output_dir}'
run_cmd += f' --train_batch_size={train_batch_size}' run_cmd += f' --train_batch_size={train_batch_size}'
run_cmd += f' --learning_rate={learning_rate}' # run_cmd += f' --learning_rate={learning_rate}'
run_cmd += f' --lr_scheduler={lr_scheduler}' run_cmd += f' --lr_scheduler={lr_scheduler}'
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}' run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
run_cmd += f' --max_train_steps={max_train_steps}' run_cmd += f' --max_train_steps={max_train_steps}'
@ -418,14 +419,18 @@ def train_model(
if not float(prior_loss_weight) == 1.0: if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora' run_cmd += f' --network_module=networks.lora'
if not text_encoder_lr == '': if not float(text_encoder_lr) == 0:
run_cmd += f' --text_encoder_lr={text_encoder_lr}' run_cmd += f' --text_encoder_lr={text_encoder_lr}'
if not unet_lr == '': else:
run_cmd += f' --unet_lr={unet_lr}'
if network_train == 'Text encoder only':
run_cmd += f' --network_train_text_encoder_only'
elif network_train == 'Unet only':
run_cmd += f' --network_train_unet_only' run_cmd += f' --network_train_unet_only'
if not float(unet_lr) == 0:
run_cmd += f' --unet_lr={unet_lr}'
else:
run_cmd += f' --network_train_text_encoder_only'
# if network_train == 'Text encoder only':
# run_cmd += f' --network_train_text_encoder_only'
# elif network_train == 'Unet only':
# run_cmd += f' --network_train_unet_only'
run_cmd += f' --network_dim={network_dim}' run_cmd += f' --network_dim={network_dim}'
@ -695,7 +700,7 @@ def lora_tab(
) )
with gr.Tab('Training parameters'): with gr.Tab('Training parameters'):
with gr.Row(): with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-4) # learning_rate_input = gr.Textbox(label='Learning rate', value=1e-4, visible=False)
lr_scheduler_input = gr.Dropdown( lr_scheduler_input = gr.Dropdown(
label='LR Scheduler', label='LR Scheduler',
choices=[ choices=[
@ -712,16 +717,16 @@ def lora_tab(
with gr.Row(): with gr.Row():
text_encoder_lr = gr.Textbox(label='Text Encoder learning rate', value=1e-6, placeholder='Optional') text_encoder_lr = gr.Textbox(label='Text Encoder learning rate', value=1e-6, placeholder='Optional')
unet_lr = gr.Textbox(label='Unet learning rate', value=1e-4, placeholder='Optional') unet_lr = gr.Textbox(label='Unet learning rate', value=1e-4, placeholder='Optional')
network_train =gr.Dropdown( # network_train = gr.Dropdown(
label='Network to train', # label='Network to train',
choices=[ # choices=[
'Text encoder and Unet', # 'Text encoder and Unet',
'Text encoder only', # 'Text encoder only',
'Unet only', # 'Unet only',
], # ],
value='Text encoder and Unet', # value='Text encoder and Unet',
interactive=True # interactive=True
) # )
network_dim = gr.Slider( network_dim = gr.Slider(
minimum=1, minimum=1,
maximum=32, maximum=32,
@ -846,7 +851,7 @@ def lora_tab(
reg_data_dir_input, reg_data_dir_input,
output_dir_input, output_dir_input,
max_resolution_input, max_resolution_input,
learning_rate_input, # learning_rate_input,
lr_scheduler_input, lr_scheduler_input,
lr_warmup_input, lr_warmup_input,
train_batch_size_input, train_batch_size_input,
@ -869,7 +874,7 @@ def lora_tab(
shuffle_caption, shuffle_caption,
save_state, save_state,
resume, resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim prior_loss_weight, text_encoder_lr, unet_lr, network_dim
] ]
button_open_config.click( button_open_config.click(