Improve GUI

parent 6b41f4f951
commit 2c069d1067
@@ -26,10 +26,10 @@ def UI(username, password):
         output_dir_input,
         logging_dir_input,
     ) = dreambooth_tab()
+    with gr.Tab('Dreambooth LoRA'):
+        lora_tab()
     with gr.Tab('Finetune'):
         finetune_tab()
-    with gr.Tab('LoRA'):
-        lora_tab()
     with gr.Tab('Utilities'):
         utilities_tab(
             train_data_dir_input=train_data_dir_input,
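Net effect of the hunk above: the standalone 'LoRA' tab is dropped and `lora_tab()` now renders inside a new 'Dreambooth LoRA' tab placed ahead of 'Finetune'. A minimal sketch of the resulting layout, assuming the repo's `*_tab()` helpers (stubbed here so the snippet runs on its own):

```python
import gradio as gr

# Stubs standing in for the repo's own tab builders.
def dreambooth_tab(): return None, None, None, None
def lora_tab(): pass
def finetune_tab(): pass
def utilities_tab(**kwargs): pass

with gr.Blocks() as interface:
    with gr.Tab('Dreambooth'):
        (train_data_dir_input, reg_data_dir_input,
         output_dir_input, logging_dir_input) = dreambooth_tab()
    with gr.Tab('Dreambooth LoRA'):   # new home for lora_tab()
        lora_tab()
    with gr.Tab('Finetune'):
        finetune_tab()
    # the old top-level 'LoRA' tab is removed
    with gr.Tab('Utilities'):
        utilities_tab(train_data_dir_input=train_data_dir_input)
```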
lora_gui.py (65 lines changed)
@@ -42,7 +42,6 @@ def save_configuration(
     reg_data_dir,
     output_dir,
     max_resolution,
-    learning_rate,
     lr_scheduler,
     lr_warmup,
     train_batch_size,
@@ -65,7 +64,7 @@ def save_configuration(
     shuffle_caption,
     save_state,
     resume,
-    prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
+    prior_loss_weight, text_encoder_lr, unet_lr, network_dim
 ):
     original_file_path = file_path
 
@@ -94,7 +93,6 @@ def save_configuration(
         'reg_data_dir': reg_data_dir,
         'output_dir': output_dir,
         'max_resolution': max_resolution,
-        'learning_rate': learning_rate,
         'lr_scheduler': lr_scheduler,
         'lr_warmup': lr_warmup,
         'train_batch_size': train_batch_size,
@@ -120,7 +118,6 @@ def save_configuration(
         'prior_loss_weight': prior_loss_weight,
         'text_encoder_lr': text_encoder_lr,
         'unet_lr': unet_lr,
-        'network_train': network_train,
         'network_dim': network_dim
     }
 
@@ -141,7 +138,6 @@ def open_configuration(
     reg_data_dir,
     output_dir,
     max_resolution,
-    learning_rate,
     lr_scheduler,
     lr_warmup,
     train_batch_size,
@@ -164,7 +160,7 @@ def open_configuration(
     shuffle_caption,
     save_state,
     resume,
-    prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
+    prior_loss_weight, text_encoder_lr, unet_lr, network_dim
 ):
 
     original_file_path = file_path
@@ -192,7 +188,6 @@ def open_configuration(
         my_data.get('reg_data_dir', reg_data_dir),
         my_data.get('output_dir', output_dir),
         my_data.get('max_resolution', max_resolution),
-        my_data.get('learning_rate', learning_rate),
         my_data.get('lr_scheduler', lr_scheduler),
         my_data.get('lr_warmup', lr_warmup),
         my_data.get('train_batch_size', train_batch_size),
@@ -220,7 +215,6 @@ def open_configuration(
         my_data.get('prior_loss_weight', prior_loss_weight),
         my_data.get('text_encoder_lr', text_encoder_lr),
         my_data.get('unet_lr', unet_lr),
-        my_data.get('network_train', network_train),
         my_data.get('network_dim', network_dim),
     )
 
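For context on the `open_configuration` hunks: every field is restored with `my_data.get(key, current_value)`, so keys missing from a config file, including 'learning_rate' and 'network_train' which this commit stops writing, silently fall back to the values already in the UI. A small illustration with made-up values:

```python
# dict.get(key, fallback): keys present in the file win, absent keys
# keep the caller-supplied current UI value. Values are hypothetical.
my_data = {'unet_lr': '1e-4', 'network_dim': 8}

unet_lr = my_data.get('unet_lr', '')          # '1e-4' (from file)
network_dim = my_data.get('network_dim', 4)   # 8      (from file)
network_train = my_data.get('network_train', 'Text encoder and Unet')
# -> 'Text encoder and Unet': key absent, fallback used, so old
#    config files still load cleanly after the key is removed.
```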
@@ -234,7 +228,6 @@ def train_model(
     reg_data_dir,
     output_dir,
     max_resolution,
-    learning_rate,
     lr_scheduler,
     lr_warmup,
     train_batch_size,
@@ -257,7 +250,7 @@ def train_model(
     shuffle_caption,
     save_state,
     resume,
-    prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
+    prior_loss_weight, text_encoder_lr, unet_lr, network_dim
 ):
     def save_inference_file(output_dir, v2, v_parameterization):
         # Copy inference model for v2 if required
@@ -294,6 +287,14 @@ def train_model(
     if output_dir == '':
         msgbox('Output folder path is missing')
         return
 
+    # If the string is empty, set it to 0.
+    if text_encoder_lr == '': text_encoder_lr = 0
+    if unet_lr == '': unet_lr = 0
+
+    if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
+        msgbox('At least one learning rate value for "Text encoder" or "Unet" needs to be provided')
+        return
+
     # Get a list of all subfolders in train_data_dir
     subfolders = [
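The guard added above normalizes blank learning-rate textboxes before any `float()` conversion, and refuses to start a run where both rates are zero. A standalone sketch of the same logic (the helper name is hypothetical; the GUI keeps this inline in `train_model()`):

```python
def normalize_lrs(text_encoder_lr, unet_lr):
    """Blank strings become 0 so float() cannot raise; training is
    refused when both rates are zero, since nothing would be trained."""
    if text_encoder_lr == '':
        text_encoder_lr = 0
    if unet_lr == '':
        unet_lr = 0
    if float(text_encoder_lr) == 0 and float(unet_lr) == 0:
        raise ValueError('At least one learning rate value for '
                         '"Text encoder" or "Unet" needs to be provided')
    return float(text_encoder_lr), float(unet_lr)

assert normalize_lrs('', '1e-4') == (0.0, 1e-4)
```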
@@ -394,7 +395,7 @@ def train_model(
     run_cmd += f' --resolution={max_resolution}'
     run_cmd += f' --output_dir={output_dir}'
     run_cmd += f' --train_batch_size={train_batch_size}'
-    run_cmd += f' --learning_rate={learning_rate}'
+    # run_cmd += f' --learning_rate={learning_rate}'
     run_cmd += f' --lr_scheduler={lr_scheduler}'
     run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
     run_cmd += f' --max_train_steps={max_train_steps}'
@@ -418,14 +419,18 @@ def train_model(
     if not float(prior_loss_weight) == 1.0:
         run_cmd += f' --prior_loss_weight={prior_loss_weight}'
     run_cmd += f' --network_module=networks.lora'
-    if not text_encoder_lr == '':
+    if not float(text_encoder_lr) == 0:
         run_cmd += f' --text_encoder_lr={text_encoder_lr}'
-    if not unet_lr == '':
-        run_cmd += f' --unet_lr={unet_lr}'
-    if network_train == 'Text encoder only':
-        run_cmd += f' --network_train_text_encoder_only'
-    elif network_train == 'Unet only':
+    else:
         run_cmd += f' --network_train_unet_only'
+    if not float(unet_lr) == 0:
+        run_cmd += f' --unet_lr={unet_lr}'
+    else:
+        run_cmd += f' --network_train_text_encoder_only'
+    # if network_train == 'Text encoder only':
+    #     run_cmd += f' --network_train_text_encoder_only'
+    # elif network_train == 'Unet only':
+    #     run_cmd += f' --network_train_unet_only'
     run_cmd += f' --network_dim={network_dim}'
 
 
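This hunk replaces the 'Network to train' dropdown with a convention: a zero learning rate for one network emits the corresponding `--network_train_*_only` flag instead of an `--*_lr` value, so only the other network is trained (the earlier guard ensures at most one rate is zero). A pure-function sketch of the flag derivation, with a hypothetical helper name:

```python
def lr_flags(text_encoder_lr, unet_lr):
    """Mirror of the new run_cmd logic: a zero rate disables that
    network via the matching --network_train_*_only flag."""
    cmd = ''
    if float(text_encoder_lr) != 0:
        cmd += f' --text_encoder_lr={text_encoder_lr}'
    else:
        cmd += ' --network_train_unet_only'
    if float(unet_lr) != 0:
        cmd += f' --unet_lr={unet_lr}'
    else:
        cmd += ' --network_train_text_encoder_only'
    return cmd

assert lr_flags(0, '1e-4') == ' --network_train_unet_only --unet_lr=1e-4'
```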
@@ -695,7 +700,7 @@ def lora_tab(
         )
         with gr.Tab('Training parameters'):
             with gr.Row():
-                learning_rate_input = gr.Textbox(label='Learning rate', value=1e-4)
+                # learning_rate_input = gr.Textbox(label='Learning rate', value=1e-4, visible=False)
                 lr_scheduler_input = gr.Dropdown(
                     label='LR Scheduler',
                     choices=[
@@ -712,16 +717,16 @@ def lora_tab(
             with gr.Row():
                 text_encoder_lr = gr.Textbox(label='Text Encoder learning rate', value=1e-6, placeholder='Optional')
                 unet_lr = gr.Textbox(label='Unet learning rate', value=1e-4, placeholder='Optional')
-                network_train =gr.Dropdown(
-                    label='Network to train',
-                    choices=[
-                        'Text encoder and Unet',
-                        'Text encoder only',
-                        'Unet only',
-                    ],
-                    value='Text encoder and Unet',
-                    interactive=True
-                )
+                # network_train = gr.Dropdown(
+                #     label='Network to train',
+                #     choices=[
+                #         'Text encoder and Unet',
+                #         'Text encoder only',
+                #         'Unet only',
+                #     ],
+                #     value='Text encoder and Unet',
+                #     interactive=True
+                # )
                 network_dim = gr.Slider(
                     minimum=1,
                     maximum=32,
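With the dropdown commented out, its three old choices map onto the two remaining learning-rate textboxes; the values below are illustrative only:

```python
# 'Text encoder and Unet' -> text_encoder_lr='1e-6', unet_lr='1e-4'
# 'Text encoder only'     -> text_encoder_lr='1e-6', unet_lr='0'
# 'Unet only'             -> text_encoder_lr='0',    unet_lr='1e-4'
```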
@@ -846,7 +851,7 @@ def lora_tab(
             reg_data_dir_input,
             output_dir_input,
             max_resolution_input,
-            learning_rate_input,
+            # learning_rate_input,
             lr_scheduler_input,
             lr_warmup_input,
             train_batch_size_input,
@@ -869,7 +874,7 @@ def lora_tab(
             shuffle_caption,
             save_state,
             resume,
-            prior_loss_weight, text_encoder_lr, unet_lr, network_train, network_dim
+            prior_loss_weight, text_encoder_lr, unet_lr, network_dim
         ]
 
         button_open_config.click(