This commit is contained in:
bmaltais 2023-03-21 20:03:53 -04:00
parent ccae80186a
commit 77ccc53046
4 changed files with 23 additions and 0 deletions

View File

@ -0,0 +1,14 @@
import torch
def apply_snr_weight(loss, noisy_latents, latents, gamma):
gamma = gamma
if gamma:
sigma = torch.sub(noisy_latents, latents)
zeros = torch.zeros_like(sigma)
alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
snr = torch.div(alpha_mean_sq, sigma_mean_sq)
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
loss = loss * snr_weight
return loss

View File

@ -1927,6 +1927,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)
parser.add_argument("--min_snr_gamma", type=float, default=5, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")
def verify_training_args(args: argparse.Namespace):

View File

@ -10,6 +10,7 @@ import os
import toml
from tqdm import tqdm
from library.custom_train_functions import apply_snr_weight
import torch
from accelerate.utils import set_seed
import diffusers
@ -290,6 +291,8 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

View File

@ -15,6 +15,7 @@ from accelerate.utils import set_seed
from diffusers import DDPMScheduler
import library.train_util as train_util
from library.custom_train_functions import apply_snr_weight
from library.train_util import (
DreamBoothDataset,
)
@ -548,6 +549,8 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし