Add support for lr_scheduler
This commit is contained in:
parent
2629617de7
commit
dd241f2142
@ -341,7 +341,7 @@ def train(args):
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
"constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||
args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if fine_tuning:
|
||||
@ -843,6 +843,7 @@ if __name__ == '__main__':
|
||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||
parser.add_argument("--debug_dataset", action="store_true",
|
||||
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
||||
|
||||
parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
@ -1344,7 +1344,7 @@ def train(args):
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps, num_warmup_steps=args.lr_warmup_steps)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
@ -1825,6 +1825,8 @@ if __name__ == '__main__':
|
||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||
parser.add_argument("--logging_dir", type=str, default=None,
|
||||
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
||||
|
||||
parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
Loading…
Reference in New Issue
Block a user