add support for color and flip augmentation to "Dreambooth LoRA"
This commit is contained in:
parent
bfa590b313
commit
1d460a09fd
@ -30,6 +30,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
|
||||
|
||||
## Change history
|
||||
|
||||
* 2023/01/01 (v19.2) update:
|
||||
- add support for color and flip augmentation to "Dreambooth LoRA"
|
||||
* 2023/01/01 (v19.1) update:
|
||||
- merge kohys_ss upstream code updates
|
||||
- rework Dreambooth LoRA GUI
|
||||
|
@ -17,6 +17,7 @@ from library.common_gui import (
|
||||
get_file_path,
|
||||
get_any_file_path,
|
||||
get_saveasfile_path,
|
||||
color_aug_changed
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
@ -65,7 +66,7 @@ def save_configuration(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
):
|
||||
original_file_path = file_path
|
||||
|
||||
@ -118,6 +119,8 @@ def save_configuration(
|
||||
'save_state': save_state,
|
||||
'resume': resume,
|
||||
'prior_loss_weight': prior_loss_weight,
|
||||
'color_aug': color_aug,
|
||||
'flip_aug': flip_aug,
|
||||
}
|
||||
|
||||
# Save the data to the selected file
|
||||
@ -160,7 +163,7 @@ def open_configuration(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
):
|
||||
|
||||
original_file_path = file_path
|
||||
@ -214,6 +217,8 @@ def open_configuration(
|
||||
my_data.get('save_state', save_state),
|
||||
my_data.get('resume', resume),
|
||||
my_data.get('prior_loss_weight', prior_loss_weight),
|
||||
my_data.get('color_aug', color_aug),
|
||||
my_data.get('flip_aug', flip_aug),
|
||||
)
|
||||
|
||||
|
||||
@ -249,7 +254,7 @@ def train_model(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
):
|
||||
def save_inference_file(output_dir, v2, v_parameterization):
|
||||
# Copy inference model for v2 if required
|
||||
@ -377,6 +382,10 @@ def train_model(
|
||||
run_cmd += ' --shuffle_caption'
|
||||
if save_state:
|
||||
run_cmd += ' --save_state'
|
||||
if color_aug:
|
||||
run_cmd += ' --color_aug'
|
||||
if flip_aug:
|
||||
run_cmd += ' --flip_aug'
|
||||
run_cmd += (
|
||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||
)
|
||||
@ -505,7 +514,7 @@ def UI(username, password):
|
||||
interface.launch(auth=(username, password))
|
||||
else:
|
||||
interface.launch()
|
||||
|
||||
|
||||
|
||||
def dreambooth_tab(
|
||||
train_data_dir_input=gr.Textbox(),
|
||||
@ -762,6 +771,13 @@ def dreambooth_tab(
|
||||
save_state = gr.Checkbox(
|
||||
label='Save training state', value=False
|
||||
)
|
||||
color_aug = gr.Checkbox(
|
||||
label='Color augmentation', value=False
|
||||
)
|
||||
flip_aug = gr.Checkbox(
|
||||
label='Flip augmentation', value=False
|
||||
)
|
||||
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
|
||||
with gr.Row():
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
@ -819,7 +835,7 @@ def dreambooth_tab(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
],
|
||||
outputs=[
|
||||
config_file_name,
|
||||
@ -854,7 +870,7 @@ def dreambooth_tab(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
],
|
||||
)
|
||||
|
||||
@ -894,7 +910,7 @@ def dreambooth_tab(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
],
|
||||
outputs=[config_file_name],
|
||||
)
|
||||
@ -935,7 +951,7 @@ def dreambooth_tab(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
],
|
||||
outputs=[config_file_name],
|
||||
)
|
||||
@ -974,7 +990,7 @@ def dreambooth_tab(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight,
|
||||
prior_loss_weight, color_aug, flip_aug
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
from tkinter import filedialog, Tk
|
||||
import os
|
||||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
|
||||
|
||||
def get_file_path(file_path='', defaultextension='.json'):
|
||||
@ -107,3 +109,10 @@ def add_pre_postfix(
|
||||
f.seek(0, 0)
|
||||
f.write(f'{prefix}{content}{postfix}')
|
||||
f.close()
|
||||
|
||||
def color_aug_changed(color_aug):
|
||||
if color_aug:
|
||||
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
|
||||
return gr.Checkbox.update(value=False, interactive=False)
|
||||
else:
|
||||
return gr.Checkbox.update(value=True, interactive=True)
|
24
lora_gui.py
24
lora_gui.py
@ -17,6 +17,7 @@ from library.common_gui import (
|
||||
get_file_path,
|
||||
get_any_file_path,
|
||||
get_saveasfile_path,
|
||||
color_aug_changed
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
@ -64,7 +65,7 @@ def save_configuration(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||
):
|
||||
original_file_path = file_path
|
||||
|
||||
@ -120,6 +121,8 @@ def save_configuration(
|
||||
'unet_lr': unet_lr,
|
||||
'network_dim': network_dim,
|
||||
'lora_network_weights': lora_network_weights,
|
||||
'color_aug': color_aug,
|
||||
'flip_aug': flip_aug,
|
||||
}
|
||||
|
||||
# Save the data to the selected file
|
||||
@ -161,7 +164,7 @@ def open_configuration(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||
):
|
||||
|
||||
original_file_path = file_path
|
||||
@ -218,6 +221,8 @@ def open_configuration(
|
||||
my_data.get('unet_lr', unet_lr),
|
||||
my_data.get('network_dim', network_dim),
|
||||
my_data.get('lora_network_weights', lora_network_weights),
|
||||
my_data.get('color_aug', color_aug),
|
||||
my_data.get('flip_aug', flip_aug),
|
||||
)
|
||||
|
||||
|
||||
@ -252,7 +257,7 @@ def train_model(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||
):
|
||||
def save_inference_file(output_dir, v2, v_parameterization):
|
||||
# Copy inference model for v2 if required
|
||||
@ -388,6 +393,10 @@ def train_model(
|
||||
run_cmd += ' --shuffle_caption'
|
||||
if save_state:
|
||||
run_cmd += ' --save_state'
|
||||
if color_aug:
|
||||
run_cmd += ' --color_aug'
|
||||
if flip_aug:
|
||||
run_cmd += ' --flip_aug'
|
||||
run_cmd += (
|
||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||
)
|
||||
@ -825,6 +834,13 @@ def lora_tab(
|
||||
save_state = gr.Checkbox(
|
||||
label='Save training state', value=False
|
||||
)
|
||||
color_aug = gr.Checkbox(
|
||||
label='Color augmentation', value=False
|
||||
)
|
||||
flip_aug = gr.Checkbox(
|
||||
label='Flip augmentation', value=False
|
||||
)
|
||||
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
|
||||
with gr.Row():
|
||||
resume = gr.Textbox(
|
||||
label='Resume from saved training state',
|
||||
@ -879,7 +895,7 @@ def lora_tab(
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
resume,
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
|
Loading…
Reference in New Issue
Block a user