add dadaptation to other trainers
commit f9863e3950 (parent 655f885cf4)

fine_tune.py (31 changed lines)
@@ -14,6 +14,9 @@ from diffusers import DDPMScheduler
 import library.train_util as train_util
 
+import torch.optim as optim
+import dadaptation
+
 
 def collate_fn(examples):
   return examples[0]
 
@@ -162,7 +165,9 @@ def train(args):
   optimizer_class = torch.optim.AdamW
 
   # beta and weight decay appear to be the defaults in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
-  optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+  # optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+  print('enable dadaptation.')
+  optimizer = dadaptation.DAdaptAdam(params_to_optimize, lr=1.0, decouple=True, weight_decay=0)
 
   # prepare the dataloader
   # number of DataLoader processes: 0 means the main process
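Note: the hunk above swaps the fixed-learning-rate AdamW optimizer for DAdaptAdam from the dadaptation package (installed separately, e.g. pip install dadaptation, assuming the PyPI name matches the import). D-Adaptation estimates the step size d on the fly, so lr is left at 1.0 and only acts as a multiplier. A minimal, self-contained sketch of the same pattern, with a toy model standing in for the fine-tuned networks:

import torch
import dadaptation

# toy stand-in for the parameters the trainer actually optimizes
model = torch.nn.Linear(16, 16)
params_to_optimize = model.parameters()

# lr=1.0 is a multiplier on the adapted step size d; decouple=True selects
# AdamW-style decoupled weight decay (disabled here, matching the commit)
optimizer = dadaptation.DAdaptAdam(params_to_optimize, lr=1.0,
                                   decouple=True, weight_decay=0)

x = torch.randn(4, 16)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()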
@@ -176,8 +181,20 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # prepare the lr scheduler
-  lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  # lr_scheduler = diffusers.optimization.get_scheduler(
+  #   args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+
+  # For Adam
+  # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+  #                                            lr_lambda=[lambda epoch: 1],
+  #                                            last_epoch=-1,
+  #                                            verbose=False)
+
+  # For SGD optim
+  lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+                                             lr_lambda=[lambda epoch: 1],
+                                             last_epoch=-1,
+                                             verbose=True)
 
   # experimental feature: fp16 training including gradients (the whole model is cast to fp16)
   if args.full_fp16:
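Note: since D-Adaptation controls the effective step size itself, the diffusers scheduler is commented out and replaced by a constant LambdaLR whose multiplier is always 1, so the existing lr_scheduler.step() calls in the loop keep working without changing the rate. A small sketch of that behaviour, assuming the DAdaptAdam setup above:

import torch
import torch.optim as optim
import dadaptation

model = torch.nn.Linear(16, 16)
optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0, decouple=True, weight_decay=0)

# constant schedule: the multiplier is always 1, so stepping the scheduler
# never changes lr; the adapted step size d is the only moving part
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                           lr_lambda=[lambda epoch: 1],
                                           last_epoch=-1)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()
lr_scheduler.step()
print(lr_scheduler.get_last_lr())  # stays [1.0]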
@@ -293,12 +310,16 @@ def train(args):
 
       current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        # logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        # accelerator.log(logs, step=global_step)
+        logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
-      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      # progress_bar.set_postfix(**logs)
+      logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}  # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
 
       if global_step >= args.max_train_steps:
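Note: the logging change replaces the scheduler's learning rate with the effective D-Adaptation rate, the product of the adapted estimate d and the base lr stored on each parameter group (the 'd' entry in param_groups is what the commit itself reads). A sketch using a small helper that is hypothetical, not part of the repository:

def dadapt_effective_lr(optimizer, group_index=0):
    # hypothetical helper: effective step size used by D-Adaptation for one param group
    group = optimizer.param_groups[group_index]
    return group['d'] * group['lr']

# illustrative usage inside the training loop:
# logs = {"loss": current_loss, "dlr": dadapt_effective_lr(optimizer)}
# accelerator.log(logs, step=global_step)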
(another changed file; name not shown in this view)

@@ -20,7 +20,7 @@ def sort_images_by_aspect_ratio(path):
     """Sort all images in a folder by aspect ratio"""
     images = []
     for filename in os.listdir(path):
-        if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png"):
+        if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png") or filename.endswith(".webp"):
             img_path = os.path.join(path, filename)
             images.append((img_path, aspect_ratio(img_path)))
     # sort the list of tuples based on the aspect ratio
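Note: the growing chain of endswith() calls could also be collapsed, since str.endswith accepts a tuple of suffixes. A small sketch (not part of the commit), with a lower() guard for mixed-case extensions:

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

def is_supported_image(filename: str) -> bool:
    # str.endswith accepts a tuple, so one call covers all extensions;
    # lower() also matches files named e.g. "PHOTO.JPG"
    return filename.lower().endswith(IMAGE_EXTENSIONS)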
train_db.py (23 changed lines)
@@ -17,6 +17,8 @@ from diffusers import DDPMScheduler
 import library.train_util as train_util
 from library.train_util import DreamBoothDataset
 
+import torch.optim as optim
+import dadaptation
 
 def collate_fn(examples):
   return examples[0]
@@ -133,13 +135,16 @@ def train(args):
   trainable_params = unet.parameters()
 
   # beta and weight decay appear to be the defaults in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  # optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
 
   # prepare the dataloader
   # number of DataLoader processes: 0 means the main process
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but at most the specified number
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
+  print('enable dadaptation.')
+  optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0, d0=0.00000001)
+
 
   # calculate the number of training steps
   if args.max_train_epochs is not None:
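Note: the DreamBooth trainer additionally passes d0=0.00000001 (1e-8), the initial estimate of the adapted step size from which D-Adaptation grows d, whereas the fine_tune.py hunk above leaves d0 at the library default. A hedged sketch of wiring in a small d0, assuming the same dadaptation API the commit uses:

import torch
import dadaptation

unet_like = torch.nn.Linear(32, 32)        # stand-in for the U-Net
trainable_params = unet_like.parameters()

# d0 is the initial guess for the adapted step size; a tiny value such as 1e-8
# keeps the very first updates conservative while d is still being estimated
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True,
                                   weight_decay=0, d0=1e-8)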
@@ -150,8 +155,14 @@ def train(args):
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
   # prepare the lr scheduler
-  lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+  # lr_scheduler = diffusers.optimization.get_scheduler(
+  #   args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+
+  # For Adam
+  lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+                                             lr_lambda=[lambda epoch: 1],
+                                             last_epoch=-1,
+                                             verbose=False)
 
   # experimental feature: fp16 training including gradients (the whole model is cast to fp16)
   if args.full_fp16:
@@ -288,12 +299,14 @@ def train(args):
 
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        # logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
-      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}  # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
 
       if global_step >= args.max_train_steps:
(another changed file; name not shown in this view)

@@ -222,7 +222,7 @@ def train(args):
   print('enable dadaptation.')
   optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
   # optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
-  # optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-8,)
+  # optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
 
   # prepare the dataloader
   # number of DataLoader processes: 0 means the main process
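Note: the commented-out lines are the other D-Adaptation variants already tried in this trainer; only the AdaGrad variant's d0 hint is retouched here (1e-8 to 1e-6). A hedged sketch of switching between the variants the commit mentions, with trainable_params as a placeholder iterable of parameters:

import torch
import dadaptation

net = torch.nn.Linear(32, 32)              # placeholder network
trainable_params = net.parameters()

variant = "adam"                           # "adam", "sgd", or "adagrad"

if variant == "adam":
    optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
elif variant == "sgd":
    optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
else:
    optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)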
@@ -242,10 +242,18 @@ def train(args):
   #                                                 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
   #                                                 num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
   # override lr_scheduler.
+
+  # For Adam
   lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
-                                             lr_lambda=[lambda epoch: 0.5, lambda epoch: 1],
+                                             lr_lambda=[lambda epoch: 0.25, lambda epoch: 1],
                                              last_epoch=-1,
                                              verbose=False)
+
+  # For SGD optim
+  # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+  #                                            lr_lambda=[lambda epoch: 1, lambda epoch: 0.5],
+  #                                            last_epoch=-1,
+  #                                            verbose=False)
 
   # experimental feature: fp16 training including gradients (the whole model is cast to fp16)
   if args.full_fp16:
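Note: unlike the single-group trainers above, this trainer passes two lambdas to LambdaLR, one per parameter group (presumably the text encoder and U-Net groups, though the hunk itself does not name them), so each group gets its own multiplier on top of D-Adaptation's shared base lr of 1.0. The effective step for group 0 then becomes d * 0.25 while group 1 uses d * 1.0, which is exactly the dlr0/dlr1 quantity printed in the progress-bar hunk below. A minimal sketch of per-group multipliers:

import torch
import torch.optim as optim
import dadaptation

enc = torch.nn.Linear(8, 8)                 # placeholder for param group 0
unet = torch.nn.Linear(8, 8)                # placeholder for param group 1

optimizer = dadaptation.DAdaptAdam(
    [{"params": enc.parameters()}, {"params": unet.parameters()}],
    lr=1.0, decouple=True, weight_decay=0)

# one lambda per param group: group 0 runs at 0.25x the adapted rate, group 1 at 1x
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                           lr_lambda=[lambda epoch: 0.25, lambda epoch: 1],
                                           last_epoch=-1)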
@@ -462,7 +470,7 @@ def train(args):
       avr_loss = loss_total / (step+1)
       # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
       # progress_bar.set_postfix(**logs)
-      logs_str = f"loss: {avr_loss:.3f}, dlr: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}"
+      logs_str = f"loss: {avr_loss:.3f}, dlr0: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}, dlr1: {optimizer.param_groups[1]['d']*optimizer.param_groups[1]['lr']:.2e}"
       progress_bar.set_postfix_str(logs_str)
 
       if args.logging_dir is not None:
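Note: with two parameter groups, the progress bar now shows one effective rate per group (dlr0 and dlr1). A sketch of building the same string for any number of groups, using a hypothetical helper name that is not part of the repository:

def format_dadapt_rates(avr_loss, optimizer):
    # hypothetical helper: one "dlrN" entry per param group
    parts = [f"loss: {avr_loss:.3f}"]
    for i, group in enumerate(optimizer.param_groups):
        parts.append(f"dlr{i}: {group['d'] * group['lr']:.2e}")
    return ", ".join(parts)

# illustrative usage:
# progress_bar.set_postfix_str(format_dadapt_rates(avr_loss, optimizer))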