From f9863e3950d6e217e33d007c40c599d47c826f7c Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Thu, 16 Feb 2023 19:33:46 -0500
Subject: [PATCH] add dadaptation to other trainers

---
 fine_tune.py                      | 31 ++++++++++++++++++++++++++-----
 tools/crop_images_to_n_buckets.py |  2 +-
 train_db.py                       | 23 ++++++++++++++++++-----
 train_network.py                  | 14 +++++++++++---
 4 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 5292153..12fb91f 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -14,6 +14,9 @@ from diffusers import DDPMScheduler
 
 import library.train_util as train_util
 
+import torch.optim as optim
+import dadaptation
+
 def collate_fn(examples):
   return examples[0]
 
@@ -162,7 +165,9 @@ def train(args):
     optimizer_class = torch.optim.AdamW
 
   # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+  # optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+  print('enable dadaptation.')
+  optimizer = dadaptation.DAdaptAdam(params_to_optimize, lr=1.0, decouple=True, weight_decay=0)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -176,8 +181,20 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  # lr_scheduler = diffusers.optimization.get_scheduler(
+  #   args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+
+  # For Adam
+  # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+  #                                            lr_lambda=[lambda epoch: 1],
+  #                                            last_epoch=-1,
+  #                                            verbose=False)
+
+  # For SGD optim
+  lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+                                             lr_lambda=[lambda epoch: 1],
+                                             last_epoch=-1,
+                                             verbose=True)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -293,12 +310,16 @@ def train(args):
 
       current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        # logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        # accelerator.log(logs, step=global_step)
+        logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
-      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      # progress_bar.set_postfix(**logs)
+      logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}  # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
 
       if global_step >= args.max_train_steps:
diff --git a/tools/crop_images_to_n_buckets.py b/tools/crop_images_to_n_buckets.py
index dff7825..688b42b 100644
--- a/tools/crop_images_to_n_buckets.py
+++ b/tools/crop_images_to_n_buckets.py
@@ -20,7 +20,7 @@ def sort_images_by_aspect_ratio(path):
     """Sort all images in a folder by aspect ratio"""
     images = []
     for filename in os.listdir(path):
-        if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png"):
+        if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png") or filename.endswith(".webp"):
             img_path = os.path.join(path, filename)
             images.append((img_path, aspect_ratio(img_path)))
     # sort the list of tuples based on the aspect ratio
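
The tools/crop_images_to_n_buckets.py hunk above adds ".webp" to the accepted extensions. A minimal sketch of an equivalent filter, using only the standard library: str.endswith accepts a tuple of suffixes, and lowering the name first also admits upper-case extensions. The IMAGE_EXTENSIONS constant and list_images helper are illustrative names, not identifiers from the patch, which keeps the chained endswith() calls.

import os

# Hypothetical names for illustration only.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

def list_images(path):
    """Return the full paths of supported image files directly inside `path`."""
    return [
        os.path.join(path, filename)
        for filename in os.listdir(path)
        if filename.lower().endswith(IMAGE_EXTENSIONS)
    ]
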
diff --git a/train_db.py b/train_db.py
index c210767..aacfcc8 100644
--- a/train_db.py
+++ b/train_db.py
@@ -17,6 +17,8 @@ from diffusers import DDPMScheduler
 import library.train_util as train_util
 from library.train_util import DreamBoothDataset
 
+import torch.optim as optim
+import dadaptation
 
 def collate_fn(examples):
   return examples[0]
@@ -133,13 +135,16 @@ def train(args):
     trainable_params = unet.parameters()
 
   # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  # optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
+  print('enable dadaptation.')
+  optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0, d0=0.00000001)
+
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
@@ -150,8 +155,14 @@ def train(args):
       args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
   # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+  # lr_scheduler = diffusers.optimization.get_scheduler(
+  #     args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+
+  # For Adam
+  lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+                                             lr_lambda=[lambda epoch: 1],
+                                             last_epoch=-1,
+                                             verbose=False)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -288,12 +299,14 @@ def train(args):
 
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        # logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
-      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}  # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
 
       if global_step >= args.max_train_steps:
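
Both fine_tune.py and train_db.py above follow the same pattern: the base learning rate is pinned at 1.0 so that D-Adaptation's distance estimate d supplies the step size, the diffusers scheduler is replaced with a constant LambdaLR, and the value logged as "dlr" is d * lr read from the optimizer's param group (train_db.py additionally passes d0=0.00000001 as the initial estimate). A self-contained sketch of that wiring follows; the toy model and loop are illustrative only and not part of the patch.

import torch
import torch.optim as optim
import dadaptation

model = torch.nn.Linear(4, 1)  # stand-in for the actual trained parameters

# lr stays at 1.0; D-Adaptation scales the effective step by its estimate d.
optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0, decouple=True, weight_decay=0)

# Constant schedule: one lambda per param group, each returning a fixed multiplier of 1.
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                           lr_lambda=[lambda epoch: 1],
                                           last_epoch=-1)

for step in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    # Same quantity the patch logs as "dlr": the adapted estimate d times the base lr.
    dlr = optimizer.param_groups[0]['d'] * optimizer.param_groups[0]['lr']
    print(f"step {step}: loss {loss.item():.4f}, dlr {dlr:.2e}")

With this arrangement the scheduler contributes only a constant multiplier; it mainly keeps the existing scheduler plumbing satisfied while d does the adapting.
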
diff --git a/train_network.py b/train_network.py
index e729793..824c298 100644
--- a/train_network.py
+++ b/train_network.py
@@ -222,7 +222,7 @@ def train(args):
     print('enable dadatation.')
     optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
     # optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
-    # optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-8,)
+    # optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -242,10 +242,18 @@ def train(args):
   #                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
   #                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
   # override lr_scheduler.
+
+  # For Adam
   lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
-                                             lr_lambda=[lambda epoch: 0.5, lambda epoch: 1],
+                                             lr_lambda=[lambda epoch: 0.25, lambda epoch: 1],
                                              last_epoch=-1,
                                              verbose=False)
+
+  # For SGD optim
+  # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+  #                                            lr_lambda=[lambda epoch: 1, lambda epoch: 0.5],
+  #                                            last_epoch=-1,
+  #                                            verbose=False)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -462,7 +470,7 @@ def train(args):
       avr_loss = loss_total / (step+1)
       # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
       # progress_bar.set_postfix(**logs)
-      logs_str = f"loss: {avr_loss:.3f}, dlr: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}"
+      logs_str = f"loss: {avr_loss:.3f}, dlr0: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}, dlr1: {optimizer.param_groups[1]['d']*optimizer.param_groups[1]['lr']:.2e}"
       progress_bar.set_postfix_str(logs_str)
 
       if args.logging_dir is not None:
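
train_network.py above drives two parameter groups from one D-Adaptation optimizer and reports dlr0 and dlr1 separately; the per-group lr_lambda list lets one group run at a reduced multiple (0.25) of the adapted rate while the other runs at the full rate. A sketch of that arrangement, assuming the two groups correspond to the text encoder and U-Net network parameters; the stand-in modules and single training step are illustrative, not code from the repository.

import torch
import torch.optim as optim
import dadaptation

text_encoder = torch.nn.Linear(4, 4)  # stand-ins for the two trained modules
unet = torch.nn.Linear(4, 4)

optimizer = dadaptation.DAdaptAdam(
    [{'params': text_encoder.parameters()}, {'params': unet.parameters()}],
    lr=1.0, decouple=True, weight_decay=0)

# One lambda per param group: group 0 runs at 0.25x the adapted rate, group 1 at 1.0x.
lr_scheduler = optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=[lambda epoch: 0.25, lambda epoch: 1],
    last_epoch=-1)

x = torch.randn(8, 4)
loss = unet(text_encoder(x)).pow(2).mean()
loss.backward()
optimizer.step()
lr_scheduler.step()

# The same quantities the patch shows in the progress bar as dlr0 / dlr1.
dlr0 = optimizer.param_groups[0]['d'] * optimizer.param_groups[0]['lr']
dlr1 = optimizer.param_groups[1]['d'] * optimizer.param_groups[1]['lr']
print(f"dlr0: {dlr0:.2e}, dlr1: {dlr1:.2e}")
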