Fix issue where dadaptation code was pushed by mistake

parent 1807c548b5
commit 34ab8448fb

fine_tune.py  (31 changed lines)
@@ -14,9 +14,6 @@ from diffusers import DDPMScheduler
 import library.train_util as train_util
 
-import torch.optim as optim
-import dadaptation
-
 
 def collate_fn(examples):
   return examples[0]
 
@@ -172,9 +169,7 @@ def train(args):
   optimizer_class = torch.optim.AdamW
 
   # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  # optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
-  print('enable dadatation.')
-  optimizer = dadaptation.DAdaptAdam(params_to_optimize, lr=1.0, decouple=True, weight_decay=0)
+  optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
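
A note on the hunk above: D-Adaptation optimizers are conventionally constructed with lr=1.0 because they estimate the step size themselves (the d factor inside the optimizer), whereas AdamW needs the user-supplied learning rate. The following is only a rough sketch of how the two setups being swapped here could sit behind a flag; build_optimizer and use_dadaptation are illustrative names, not anything from the repository.

    # Illustrative sketch only; build_optimizer / use_dadaptation are made-up names.
    import torch

    try:
        import dadaptation  # optional dependency used by the removed experiment
    except ImportError:
        dadaptation = None


    def build_optimizer(params, learning_rate, use_dadaptation=False):
        if use_dadaptation:
            if dadaptation is None:
                raise RuntimeError("the dadaptation package is not installed")
            # D-Adaptation estimates its own step size d; lr stays at 1.0 and only
            # scales the adapted rate, matching the removed lines above.
            return dadaptation.DAdaptAdam(params, lr=1.0, decouple=True, weight_decay=0)
        # Restored path: plain AdamW with the configured learning rate.
        return torch.optim.AdamW(params, lr=learning_rate)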
@@ -188,20 +183,8 @@ def train(args):
   print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  # lr_scheduler = diffusers.optimization.get_scheduler(
-  #   args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
-
-  # For Adam
-  # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
-  #                                            lr_lambda=[lambda epoch: 1],
-  #                                            last_epoch=-1,
-  #                                            verbose=False)
-
-  # For SGD optim
-  lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
-                                             lr_lambda=[lambda epoch: 1],
-                                             last_epoch=-1,
-                                             verbose=True)
+  lr_scheduler = diffusers.optimization.get_scheduler(
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
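
On the scheduler change above: the restored call is the standard diffusers helper driven by args.lr_scheduler and args.lr_warmup_steps, while the removed experiment pinned the schedule to a constant factor of 1 via LambdaLR, since D-Adaptation is meant to manage the effective step size on its own. Below is a minimal sketch of that constant schedule, assuming a single param group; constant_schedule is an illustrative name.

    # Sketch of the constant schedule the removed code relied on (illustrative only).
    import torch


    def constant_schedule(optimizer):
        # lr_lambda always returns 1, so the base lr is never rescaled; with a
        # D-Adaptation optimizer the effective rate is d * lr and changes internally.
        return torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                 lr_lambda=[lambda epoch: 1],
                                                 last_epoch=-1)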
@@ -320,16 +303,12 @@ def train(args):
 
       current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
       if args.logging_dir is not None:
-        # logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-        # accelerator.log(logs, step=global_step)
-        logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
+        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
-      # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
-      # progress_bar.set_postfix(**logs)
-      logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
 
       if global_step >= args.max_train_steps:
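
On the logging change above: a dadaptation optimizer exposes its adapted step size through param_groups, so the removed lines logged d * lr under the "dlr" key instead of the scheduler's last learning rate. A small helper as a sketch; effective_dadaptation_lr is an illustrative name, not repository code.

    # Sketch: read the effective learning rate from a dadaptation optimizer.
    def effective_dadaptation_lr(optimizer):
        group = optimizer.param_groups[0]
        # 'd' is the adapted distance estimate; multiplied by 'lr' (usually 1.0)
        # it gives the step size actually applied, which is what was being logged.
        return group["d"] * group["lr"]

With AdamW restored, the reverted code goes back to logging lr_scheduler.get_last_lr()[0].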

train_textual_inversion.py
@@ -13,41 +13,34 @@ from diffusers import DDPMScheduler
 import library.train_util as train_util
 from library.train_util import DreamBoothDataset, FineTuningDataset
 
-import torch.optim as optim
-import dadaptation
-
-# imagenet_templates_small = [
-# "a photo of a {}",
-# "a rendering of a {}",
-# "a cropped photo of the {}",
-# "the photo of a {}",
-# "a photo of a clean {}",
-# "a photo of a dirty {}",
-# "a dark photo of the {}",
-# "a photo of my {}",
-# "a photo of the cool {}",
-# "a close-up photo of a {}",
-# "a bright photo of the {}",
-# "a cropped photo of a {}",
-# "a photo of the {}",
-# "a good photo of the {}",
-# "a photo of one {}",
-# "a close-up photo of the {}",
-# "a rendition of the {}",
-# "a photo of the clean {}",
-# "a rendition of a {}",
-# "a photo of a nice {}",
-# "a good photo of a {}",
-# "a photo of the nice {}",
-# "a photo of the small {}",
-# "a photo of the weird {}",
-# "a photo of the large {}",
-# "a photo of a cool {}",
-# "a photo of a small {}",
-# ]
-
 imagenet_templates_small = [
-    "{}",
+    "a photo of a {}",
+    "a rendering of a {}",
+    "a cropped photo of the {}",
+    "the photo of a {}",
+    "a photo of a clean {}",
+    "a photo of a dirty {}",
+    "a dark photo of the {}",
+    "a photo of my {}",
+    "a photo of the cool {}",
+    "a close-up photo of a {}",
+    "a bright photo of the {}",
+    "a cropped photo of a {}",
+    "a photo of the {}",
+    "a good photo of the {}",
+    "a photo of one {}",
+    "a close-up photo of the {}",
+    "a rendition of the {}",
+    "a photo of the clean {}",
+    "a rendition of a {}",
+    "a photo of a nice {}",
+    "a good photo of a {}",
+    "a photo of the nice {}",
+    "a photo of the small {}",
+    "a photo of the weird {}",
+    "a photo of the large {}",
+    "a photo of a cool {}",
+    "a photo of a small {}",
 ]
 
 imagenet_style_templates_small = [
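
The hunk above restores the full caption template list that the experiment had collapsed to a single bare "{}". In textual-inversion training such templates are typically filled with the placeholder token string to build captions; the snippet below is only an illustration of that pattern, and make_caption is a made-up name.

    # Illustration only: how a template list like the one above is typically used.
    import random


    def make_caption(templates, token_string):
        # Pick a template at random and substitute the learned token's string.
        return random.choice(templates).format(token_string)

    # e.g. make_caption(imagenet_templates_small, "my-token") -> "a photo of a my-token"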
@@ -227,12 +220,7 @@ def train(args):
   trainable_params = text_encoder.get_input_embeddings().parameters()
 
   # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  # optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
-  print('enable dadapation.')
-  optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
-  # optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
-  # optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
-
+  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -246,20 +234,8 @@ def train(args):
   print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  # lr_scheduler = diffusers.optimization.get_scheduler(
-  #   args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
-
-  # For Adam
-  lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
-                                             lr_lambda=[lambda epoch: 1],
-                                             last_epoch=-1,
-                                             verbose=False)
-
-  # For SGD optim
-  # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
-  #                                            lr_lambda=[lambda epoch: 1],
-  #                                            last_epoch=-1,
-  #                                            verbose=False)
+  lr_scheduler = diffusers.optimization.get_scheduler(
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
   # acceleratorがなんかよろしくやってくれるらしい
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -400,16 +376,12 @@ def train(args):
 
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        #logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-
-        avr_loss = loss_total / (step+1)
-        logs = {"loss": avr_loss, "dlr0": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
+        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
-      # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
-      logs = {"loss": avr_loss, "dlr0": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
+      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
 
       if global_step >= args.max_train_steps: