min-snr
This commit is contained in:
parent ccae80186a
commit 77ccc53046
library/custom_train_functions.py · Normal file · 14 additions
@@ -0,0 +1,14 @@
+import torch
+
+def apply_snr_weight(loss, noisy_latents, latents, gamma):
+    gamma = gamma
+    if gamma:
+        sigma = torch.sub(noisy_latents, latents)
+        zeros = torch.zeros_like(sigma)
+        alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
+        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+        loss = loss * snr_weight
+    return loss
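For reference, the new helper implements Min-SNR loss weighting: it estimates a per-sample signal-to-noise ratio from the clean latents and the added noise (noisy_latents - latents) and rescales the per-sample loss by w = min(gamma / SNR, 1), equivalently min(SNR, gamma) / SNR, so high-SNR (low-noise) samples are down-weighted. Below is a minimal usage sketch with dummy tensors; the shapes and the simplified noise model are assumptions for illustration, not excerpts from the repository.

# Illustrative only: dummy tensors standing in for the VAE latents and the
# scheduler's noisy latents; shapes are assumptions made for this example.
import torch

from library.custom_train_functions import apply_snr_weight

batch_size = 4
latents = torch.randn(batch_size, 4, 64, 64)   # clean latents
noise = torch.randn_like(latents)
noisy_latents = latents + noise                # simplified stand-in for the scheduler's add_noise output
per_sample_loss = torch.rand(batch_size)       # e.g. mse_loss(..., reduction="none").mean([1, 2, 3])

weighted = apply_snr_weight(per_sample_loss, noisy_latents, latents, gamma=5.0)
print(weighted.shape)  # torch.Size([4]); each sample's loss is scaled by min(gamma / SNR, 1)

In the training scripts the loss passed in is already reduced to one value per sample, so the weight and the loss broadcast element-wise over the batch dimension.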
@@ -1927,6 +1927,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
     )

+    parser.add_argument("--min_snr_gamma", type=float, default=5, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")
+
+

 def verify_training_args(args: argparse.Namespace):
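A note on the new flag: with default=5 and apply_snr_weight gating on `if gamma:`, the re-weighting is active by default and can be disabled by passing 0 (any falsy value skips it). The snippet below is a hypothetical, self-contained parsing example; the parser object and the example value are illustrative, not the repository's actual CLI wiring.

# Hypothetical example of parsing the new flag and handing it to the helper.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--min_snr_gamma",
    type=float,
    default=5,
    help="gamma for reducing the weight of high loss timesteps",
)
args = parser.parse_args(["--min_snr_gamma", "1.0"])

# Inside the training step (see the train() hunks below) the value is used as:
#   loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
# Passing 0 disables the re-weighting because apply_snr_weight checks `if gamma:`.
print(args.min_snr_gamma)  # 1.0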
@@ -10,6 +10,7 @@ import os
 import toml

 from tqdm import tqdm
+from library.custom_train_functions import apply_snr_weight
 import torch
 from accelerate.utils import set_seed
 import diffusers
@@ -290,6 +291,8 @@ def train(args):

                 loss_weights = batch["loss_weights"]  # per-sample weight
                 loss = loss * loss_weights

+                loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
+
                 loss = loss.mean()  # already a per-sample mean, so no need to divide by batch_size

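This hunk and the matching one further down splice the SNR weighting between the dataset's per-sample loss weights and the final batch mean. Below is a condensed, reconstructed sketch of that order of operations; the prediction/target names and the way the latents are passed in are stand-ins for the surrounding loop's variables, not exact excerpts from the scripts.

# Reconstructed sketch of the loss pipeline around these hunks; argument names
# are assumptions standing in for variables already defined in the training loop.
import torch

from library.custom_train_functions import apply_snr_weight


def compute_weighted_loss(noise_pred, target, loss_weights, noisy_latents, latents, min_snr_gamma):
    # per-element MSE, reduced to one value per sample -> shape [B]
    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
    loss = loss.mean([1, 2, 3])
    loss = loss * loss_weights                                            # dataset-provided per-sample weights
    loss = apply_snr_weight(loss, noisy_latents, latents, min_snr_gamma)  # w = min(gamma / SNR, 1)
    return loss.mean()                                                    # per-sample means, so no division by batch_size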
@@ -15,6 +15,7 @@ from accelerate.utils import set_seed
 from diffusers import DDPMScheduler

 import library.train_util as train_util
+from library.custom_train_functions import apply_snr_weight
 from library.train_util import (
     DreamBoothDataset,
 )
@@ -548,6 +549,8 @@ def train(args):

                 loss_weights = batch["loss_weights"]  # per-sample weight
                 loss = loss * loss_weights

+                loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
+
                 loss = loss.mean()  # already a per-sample mean, so no need to divide by batch_size
