min-snr
This commit is contained in:
parent
ccae80186a
commit
77ccc53046
14
library/custom_train_functions.py
Normal file
14
library/custom_train_functions.py
Normal file
@ -0,0 +1,14 @@
|
||||
import torch
|
||||
|
||||
def apply_snr_weight(loss, noisy_latents, latents, gamma):
|
||||
gamma = gamma
|
||||
if gamma:
|
||||
sigma = torch.sub(noisy_latents, latents)
|
||||
zeros = torch.zeros_like(sigma)
|
||||
alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
|
||||
sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
|
||||
snr = torch.div(alpha_mean_sq, sigma_mean_sq)
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
|
||||
loss = loss * snr_weight
|
||||
return loss
|
@ -1927,6 +1927,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
|
||||
)
|
||||
|
||||
parser.add_argument("--min_snr_gamma", type=float, default=5, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")
|
||||
|
||||
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
|
@ -10,6 +10,7 @@ import os
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
@ -290,6 +291,8 @@ def train(args):
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
@ -15,6 +15,7 @@ from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.train_util import (
|
||||
DreamBoothDataset,
|
||||
)
|
||||
@ -548,6 +549,8 @@ def train(args):
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user