diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
new file mode 100644
index 0000000..7d42c5d
--- /dev/null
+++ b/library/custom_train_functions.py
@@ -0,0 +1,14 @@
+import torch
+
+def apply_snr_weight(loss, noisy_latents, latents, gamma):
+    gamma = gamma
+    if gamma:
+        sigma = torch.sub(noisy_latents, latents)
+        zeros = torch.zeros_like(sigma)
+        alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
+        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+        loss = loss * snr_weight
+    return loss
diff --git a/library/train_util.py b/library/train_util.py
index 7d31182..85a059d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1927,6 +1927,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
     )
+
+    parser.add_argument("--min_snr_gamma", type=float, default=5, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")
+
 
 
 def verify_training_args(args: argparse.Namespace):
diff --git a/train_db.py b/train_db.py
index 81aeda1..51fce88 100644
--- a/train_db.py
+++ b/train_db.py
@@ -10,6 +10,7 @@ import os
 import toml
 
 from tqdm import tqdm
+from library.custom_train_functions import apply_snr_weight
 import torch
 from accelerate.utils import set_seed
 import diffusers
@@ -290,6 +291,8 @@ def train(args):
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
+
+                loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/train_network.py b/train_network.py
index 7f910df..5127796 100644
--- a/train_network.py
+++ b/train_network.py
@@ -15,6 +15,7 @@ from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
+from library.custom_train_functions import apply_snr_weight
 from library.train_util import (
     DreamBoothDataset,
 )
@@ -548,6 +549,8 @@ def train(args):
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
+
+                loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
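
For reference, `apply_snr_weight` scales each per-sample loss by `min(gamma / SNR, 1)`, with the SNR estimated as `mean(latents**2) / mean((noisy_latents - latents)**2)` over the non-batch dimensions. The sketch below reproduces that weighting on dummy tensors so the effect of `--min_snr_gamma` can be inspected outside the training loop. It is a minimal sketch, not part of the patch: the batch shape, noise scales, and helper name `min_snr_weight` are illustrative assumptions, the `gamma = gamma` no-op is dropped, and plain tensor operators replace the `torch.sub` / `mse_loss` calls while computing the same values.

```python
# Minimal, standalone sketch of the Min-SNR weighting added in this patch.
# The batch shape, noise scales, and gamma value below are illustrative assumptions.
import torch


def min_snr_weight(loss, noisy_latents, latents, gamma):
    # Same math as apply_snr_weight in library/custom_train_functions.py:
    # per-sample SNR ~= mean(latents**2) / mean((noisy_latents - latents)**2),
    # weight = min(gamma / SNR, 1).
    if gamma:
        sigma = noisy_latents - latents  # the noise that was added to the latents
        alpha_mean_sq = latents.float().pow(2).mean([1, 2, 3])
        sigma_mean_sq = sigma.float().pow(2).mean([1, 2, 3])
        snr = alpha_mean_sq / sigma_mean_sq
        snr_weight = torch.clamp(gamma / snr, max=1.0).float()
        loss = loss * snr_weight
    return loss


# Per-sample losses, shaped like the loss right before loss.mean() in train_db.py / train_network.py.
latents = torch.randn(4, 4, 64, 64)
noise_scale = torch.tensor([0.1, 0.5, 1.0, 3.0]).view(-1, 1, 1, 1)  # low to high noise per sample
noisy_latents = latents + noise_scale * torch.randn_like(latents)
per_sample_loss = torch.ones(4)

print(min_snr_weight(per_sample_loss, noisy_latents, latents, gamma=5.0))
# High-SNR (low-noise) samples are scaled down toward gamma / SNR; low-SNR samples keep weight 1.0.
```

With the default `--min_snr_gamma=5`, samples whose estimated SNR exceeds 5 are down-weighted by `gamma / SNR`, while samples at or below that SNR keep their original loss; a falsy gamma (e.g. `0`) skips the weighting entirely because of the `if gamma:` guard.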