Add support for lr_scheduler

bmaltais 2022-11-22 07:51:52 -05:00
parent 2629617de7
commit dd241f2142
2 changed files with 7 additions and 4 deletions

File 1 of 2 (file name not shown in this view)

@@ -341,7 +341,7 @@ def train(args):

   # prepare the lr scheduler
   lr_scheduler = diffusers.optimization.get_scheduler(
-      "constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+      args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps)

   # accelerator is apparently supposed to take care of the rest for us
   if fine_tuning:
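The hunk above replaces the hard-coded "constant" schedule with a user-selectable one. The following standalone sketch (not part of the commit; the dummy optimizer, learning rate, and step counts are illustrative) shows how the two new arguments drive diffusers.optimization.get_scheduler:

# Sketch of the patched call shape, with args.* replaced by literals.
# Assumes torch and diffusers are installed.
import torch
from diffusers.optimization import get_scheduler

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

# e.g. --lr_scheduler constant_with_warmup --lr_warmup_steps 10
lr_scheduler = get_scheduler("constant_with_warmup", optimizer,
                             num_training_steps=100, num_warmup_steps=10)

for step in range(12):
    optimizer.step()
    lr_scheduler.step()
    # lr ramps from 0 up to 1e-4 over the first 10 steps, then holds constant
    print(step, lr_scheduler.get_last_lr()[0])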
@@ -843,6 +843,7 @@ if __name__ == '__main__':
                       help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
   parser.add_argument("--debug_dataset", action="store_true",
                       help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
+  parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup")
+  parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")

   args = parser.parse_args()
   train(args)
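Because both new flags default to the old hard-coded values, existing command lines keep their behavior. A quick sketch (hypothetical, mirroring just the two add_argument lines from the patch):

# Sketch: the new defaults reproduce the pre-patch behavior.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr_scheduler", type=str, default="constant")
parser.add_argument("--lr_warmup_steps", type=int, default=0)

args = parser.parse_args([])  # no flags given, as with an old command line
# The constant schedule ignores warmup, so this matches the old hard-coded call
assert args.lr_scheduler == "constant"
assert args.lr_warmup_steps == 0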

File 2 of 2 (file name not shown in this view)

@@ -1344,7 +1344,7 @@ def train(args):
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

   # prepare the lr scheduler
-  lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
+  lr_scheduler = diffusers.optimization.get_scheduler(args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps, num_warmup_steps=args.lr_warmup_steps)

   # accelerator is apparently supposed to take care of the rest for us
   unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
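For context on the accelerator comment above: accelerator.prepare() also accepts the scheduler and wraps it so its step() stays in sync with the prepared optimizer. A runnable toy version (dummy model and data, not taken from this repository):

# Toy sketch of accelerator.prepare() with the scheduler included.
# Assumes torch, accelerate, and diffusers are installed.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_dataloader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=2)
lr_scheduler = get_scheduler("constant", optimizer, num_training_steps=100)

accelerator = Accelerator()
# prepare() moves the model, optimizer, and data to the right device and
# wraps the scheduler so step() only advances when the optimizer stepped
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler)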
@@ -1825,6 +1825,8 @@ if __name__ == '__main__':
                       help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
   parser.add_argument("--logging_dir", type=str, default=None,
                       help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
+  parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup")
+  parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")

   args = parser.parse_args()
   train(args)
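After this commit, any of the six schedulers named in the help string can be selected from the command line, e.g. --lr_scheduler cosine --lr_warmup_steps 500 (the script name is omitted here since the commit view does not show the file names). One quick way to compare the resulting curves, with illustrative step counts not taken from the commit:

# Sketch: sample the lr curve for each scheduler name the new flag accepts.
# Assumes torch and diffusers; 100 training / 10 warmup steps are made up.
import torch
from diffusers.optimization import get_scheduler

for name in ["linear", "cosine", "cosine_with_restarts",
             "polynomial", "constant", "constant_with_warmup"]:
    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    scheduler = get_scheduler(name, optimizer,
                              num_warmup_steps=10, num_training_steps=100)
    lrs = []
    for _ in range(100):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # mid-run and final lr reveal the shape: decaying vs. constant schedules
    print(f"{name:22s} lr@step50={lrs[49]:.2e} lr@step100={lrs[99]:.2e}")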