This commit is contained in:
bmaltais 2023-03-21 20:03:53 -04:00
parent ccae80186a
commit 77ccc53046
4 changed files with 23 additions and 0 deletions

View File

@ -0,0 +1,14 @@
import torch
def apply_snr_weight(loss, noisy_latents, latents, gamma):
gamma = gamma
if gamma:
sigma = torch.sub(noisy_latents, latents)
zeros = torch.zeros_like(sigma)
alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
snr = torch.div(alpha_mean_sq, sigma_mean_sq)
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
loss = loss * snr_weight
return loss

View File

@ -1928,6 +1928,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
) )
parser.add_argument("--min_snr_gamma", type=float, default=5, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")
def verify_training_args(args: argparse.Namespace): def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2: if args.v_parameterization and not args.v2:

View File

@ -10,6 +10,7 @@ import os
import toml import toml
from tqdm import tqdm from tqdm import tqdm
from library.custom_train_functions import apply_snr_weight
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
import diffusers import diffusers
@ -291,6 +292,8 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss) accelerator.backward(loss)

View File

@ -15,6 +15,7 @@ from accelerate.utils import set_seed
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
import library.train_util as train_util import library.train_util as train_util
from library.custom_train_functions import apply_snr_weight
from library.train_util import ( from library.train_util import (
DreamBoothDataset, DreamBoothDataset,
) )
@ -549,6 +550,8 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss) accelerator.backward(loss)