Add dadapation to other trainers

This commit is contained in:
bmaltais 2023-02-16 19:33:33 -05:00
parent 6129c7dd40
commit 655f885cf4

View File

@ -13,34 +13,41 @@ from diffusers import DDPMScheduler
import library.train_util as train_util import library.train_util as train_util
from library.train_util import DreamBoothDataset, FineTuningDataset from library.train_util import DreamBoothDataset, FineTuningDataset
import torch.optim as optim
import dadaptation
# imagenet_templates_small = [
# "a photo of a {}",
# "a rendering of a {}",
# "a cropped photo of the {}",
# "the photo of a {}",
# "a photo of a clean {}",
# "a photo of a dirty {}",
# "a dark photo of the {}",
# "a photo of my {}",
# "a photo of the cool {}",
# "a close-up photo of a {}",
# "a bright photo of the {}",
# "a cropped photo of a {}",
# "a photo of the {}",
# "a good photo of the {}",
# "a photo of one {}",
# "a close-up photo of the {}",
# "a rendition of the {}",
# "a photo of the clean {}",
# "a rendition of a {}",
# "a photo of a nice {}",
# "a good photo of a {}",
# "a photo of the nice {}",
# "a photo of the small {}",
# "a photo of the weird {}",
# "a photo of the large {}",
# "a photo of a cool {}",
# "a photo of a small {}",
# ]
imagenet_templates_small = [ imagenet_templates_small = [
"a photo of a {}", "{}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
] ]
imagenet_style_templates_small = [ imagenet_style_templates_small = [
@ -213,7 +220,12 @@ def train(args):
trainable_params = text_encoder.get_input_embeddings().parameters() trainable_params = text_encoder.get_input_embeddings().parameters()
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(trainable_params, lr=args.learning_rate) # optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
print('enable dadapation.')
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
# optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
# optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
# dataloaderを準備する # dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
@ -227,8 +239,20 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler( # lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) # args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
# For Adam
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
lr_lambda=[lambda epoch: 1],
last_epoch=-1,
verbose=False)
# For SGD optim
# lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
# lr_lambda=[lambda epoch: 1],
# last_epoch=-1,
# verbose=False)
# acceleratorがなんかよろしくやってくれるらしい # acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@ -366,12 +390,16 @@ def train(args):
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} #logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss, "dlr0": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
loss_total += current_loss loss_total += current_loss
avr_loss = loss_total / (step+1) avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} # logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
logs = {"loss": avr_loss, "dlr0": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps: