add support for color and flip augmentation to "Dreambooth LoRA"
This commit is contained in:
parent
bfa590b313
commit
1d460a09fd
@ -30,6 +30,8 @@ Once you have created the LoRA network you can generate images via auto1111 by i
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 2023/01/01 (v19.2) update:
|
||||||
|
- add support for color and flip augmentation to "Dreambooth LoRA"
|
||||||
* 2023/01/01 (v19.1) update:
|
* 2023/01/01 (v19.1) update:
|
||||||
- merge kohys_ss upstream code updates
|
- merge kohys_ss upstream code updates
|
||||||
- rework Dreambooth LoRA GUI
|
- rework Dreambooth LoRA GUI
|
||||||
|
@ -17,6 +17,7 @@ from library.common_gui import (
|
|||||||
get_file_path,
|
get_file_path,
|
||||||
get_any_file_path,
|
get_any_file_path,
|
||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
|
color_aug_changed
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -65,7 +66,7 @@ def save_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -118,6 +119,8 @@ def save_configuration(
|
|||||||
'save_state': save_state,
|
'save_state': save_state,
|
||||||
'resume': resume,
|
'resume': resume,
|
||||||
'prior_loss_weight': prior_loss_weight,
|
'prior_loss_weight': prior_loss_weight,
|
||||||
|
'color_aug': color_aug,
|
||||||
|
'flip_aug': flip_aug,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the data to the selected file
|
# Save the data to the selected file
|
||||||
@ -160,7 +163,7 @@ def open_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
):
|
):
|
||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
@ -214,6 +217,8 @@ def open_configuration(
|
|||||||
my_data.get('save_state', save_state),
|
my_data.get('save_state', save_state),
|
||||||
my_data.get('resume', resume),
|
my_data.get('resume', resume),
|
||||||
my_data.get('prior_loss_weight', prior_loss_weight),
|
my_data.get('prior_loss_weight', prior_loss_weight),
|
||||||
|
my_data.get('color_aug', color_aug),
|
||||||
|
my_data.get('flip_aug', flip_aug),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -249,7 +254,7 @@ def train_model(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
def save_inference_file(output_dir, v2, v_parameterization):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -377,6 +382,10 @@ def train_model(
|
|||||||
run_cmd += ' --shuffle_caption'
|
run_cmd += ' --shuffle_caption'
|
||||||
if save_state:
|
if save_state:
|
||||||
run_cmd += ' --save_state'
|
run_cmd += ' --save_state'
|
||||||
|
if color_aug:
|
||||||
|
run_cmd += ' --color_aug'
|
||||||
|
if flip_aug:
|
||||||
|
run_cmd += ' --flip_aug'
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||||
)
|
)
|
||||||
@ -762,6 +771,13 @@ def dreambooth_tab(
|
|||||||
save_state = gr.Checkbox(
|
save_state = gr.Checkbox(
|
||||||
label='Save training state', value=False
|
label='Save training state', value=False
|
||||||
)
|
)
|
||||||
|
color_aug = gr.Checkbox(
|
||||||
|
label='Color augmentation', value=False
|
||||||
|
)
|
||||||
|
flip_aug = gr.Checkbox(
|
||||||
|
label='Flip augmentation', value=False
|
||||||
|
)
|
||||||
|
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
resume = gr.Textbox(
|
resume = gr.Textbox(
|
||||||
label='Resume from saved training state',
|
label='Resume from saved training state',
|
||||||
@ -819,7 +835,7 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
config_file_name,
|
config_file_name,
|
||||||
@ -854,7 +870,7 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -894,7 +910,7 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -935,7 +951,7 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -974,7 +990,7 @@ def dreambooth_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight,
|
prior_loss_weight, color_aug, flip_aug
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from tkinter import filedialog, Tk
|
from tkinter import filedialog, Tk
|
||||||
import os
|
import os
|
||||||
|
import gradio as gr
|
||||||
|
from easygui import msgbox
|
||||||
|
|
||||||
|
|
||||||
def get_file_path(file_path='', defaultextension='.json'):
|
def get_file_path(file_path='', defaultextension='.json'):
|
||||||
@ -107,3 +109,10 @@ def add_pre_postfix(
|
|||||||
f.seek(0, 0)
|
f.seek(0, 0)
|
||||||
f.write(f'{prefix}{content}{postfix}')
|
f.write(f'{prefix}{content}{postfix}')
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
def color_aug_changed(color_aug):
|
||||||
|
if color_aug:
|
||||||
|
msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...')
|
||||||
|
return gr.Checkbox.update(value=False, interactive=False)
|
||||||
|
else:
|
||||||
|
return gr.Checkbox.update(value=True, interactive=True)
|
24
lora_gui.py
24
lora_gui.py
@ -17,6 +17,7 @@ from library.common_gui import (
|
|||||||
get_file_path,
|
get_file_path,
|
||||||
get_any_file_path,
|
get_any_file_path,
|
||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
|
color_aug_changed
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -64,7 +65,7 @@ def save_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -120,6 +121,8 @@ def save_configuration(
|
|||||||
'unet_lr': unet_lr,
|
'unet_lr': unet_lr,
|
||||||
'network_dim': network_dim,
|
'network_dim': network_dim,
|
||||||
'lora_network_weights': lora_network_weights,
|
'lora_network_weights': lora_network_weights,
|
||||||
|
'color_aug': color_aug,
|
||||||
|
'flip_aug': flip_aug,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the data to the selected file
|
# Save the data to the selected file
|
||||||
@ -161,7 +164,7 @@ def open_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||||
):
|
):
|
||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
@ -218,6 +221,8 @@ def open_configuration(
|
|||||||
my_data.get('unet_lr', unet_lr),
|
my_data.get('unet_lr', unet_lr),
|
||||||
my_data.get('network_dim', network_dim),
|
my_data.get('network_dim', network_dim),
|
||||||
my_data.get('lora_network_weights', lora_network_weights),
|
my_data.get('lora_network_weights', lora_network_weights),
|
||||||
|
my_data.get('color_aug', color_aug),
|
||||||
|
my_data.get('flip_aug', flip_aug),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -252,7 +257,7 @@ def train_model(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
def save_inference_file(output_dir, v2, v_parameterization):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -388,6 +393,10 @@ def train_model(
|
|||||||
run_cmd += ' --shuffle_caption'
|
run_cmd += ' --shuffle_caption'
|
||||||
if save_state:
|
if save_state:
|
||||||
run_cmd += ' --save_state'
|
run_cmd += ' --save_state'
|
||||||
|
if color_aug:
|
||||||
|
run_cmd += ' --color_aug'
|
||||||
|
if flip_aug:
|
||||||
|
run_cmd += ' --flip_aug'
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
|
||||||
)
|
)
|
||||||
@ -825,6 +834,13 @@ def lora_tab(
|
|||||||
save_state = gr.Checkbox(
|
save_state = gr.Checkbox(
|
||||||
label='Save training state', value=False
|
label='Save training state', value=False
|
||||||
)
|
)
|
||||||
|
color_aug = gr.Checkbox(
|
||||||
|
label='Color augmentation', value=False
|
||||||
|
)
|
||||||
|
flip_aug = gr.Checkbox(
|
||||||
|
label='Flip augmentation', value=False
|
||||||
|
)
|
||||||
|
color_aug.change(color_aug_changed, inputs=[color_aug], outputs=[cache_latent_input])
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
resume = gr.Textbox(
|
resume = gr.Textbox(
|
||||||
label='Resume from saved training state',
|
label='Resume from saved training state',
|
||||||
@ -879,7 +895,7 @@ def lora_tab(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
resume,
|
resume,
|
||||||
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
|
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights, color_aug, flip_aug
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
Loading…
Reference in New Issue
Block a user