diff --git a/diffusers_fine_tuning/fine_tune.py b/diffusers_fine_tuning/fine_tune.py index 013e50d..1cc35f4 100644 --- a/diffusers_fine_tuning/fine_tune.py +++ b/diffusers_fine_tuning/fine_tune.py @@ -341,7 +341,7 @@ def train(args): # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( - "constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) + args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps) # acceleratorがなんかよろしくやってくれるらしい if fine_tuning: @@ -843,6 +843,7 @@ if __name__ == '__main__': help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") parser.add_argument("--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - + parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.") args = parser.parse_args() train(args) diff --git a/train_db_fixed.py b/train_db_fixed.py index 2037e78..488bfc5 100644 --- a/train_db_fixed.py +++ b/train_db_fixed.py @@ -1344,7 +1344,7 @@ def train(args): train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) # lr schedulerを用意する - lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps) + lr_scheduler = diffusers.optimization.get_scheduler(args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps, num_warmup_steps=args.lr_warmup_steps) # acceleratorがなんかよろしくやってくれるらしい unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -1825,6 +1825,8 @@ if __name__ == '__main__': help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") parser.add_argument("--logging_dir", type=str, default=None, help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - + parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.") + args = parser.parse_args() train(args)