diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 4aa91ee..118c99f 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -13,34 +13,41 @@ from diffusers import DDPMScheduler
 import library.train_util as train_util
 from library.train_util import DreamBoothDataset, FineTuningDataset
 
+import torch.optim as optim
+import dadaptation
+
+# imagenet_templates_small = [
+#     "a photo of a {}",
+#     "a rendering of a {}",
+#     "a cropped photo of the {}",
+#     "the photo of a {}",
+#     "a photo of a clean {}",
+#     "a photo of a dirty {}",
+#     "a dark photo of the {}",
+#     "a photo of my {}",
+#     "a photo of the cool {}",
+#     "a close-up photo of a {}",
+#     "a bright photo of the {}",
+#     "a cropped photo of a {}",
+#     "a photo of the {}",
+#     "a good photo of the {}",
+#     "a photo of one {}",
+#     "a close-up photo of the {}",
+#     "a rendition of the {}",
+#     "a photo of the clean {}",
+#     "a rendition of a {}",
+#     "a photo of a nice {}",
+#     "a good photo of a {}",
+#     "a photo of the nice {}",
+#     "a photo of the small {}",
+#     "a photo of the weird {}",
+#     "a photo of the large {}",
+#     "a photo of a cool {}",
+#     "a photo of a small {}",
+# ]
+
 imagenet_templates_small = [
-    "a photo of a {}",
-    "a rendering of a {}",
-    "a cropped photo of the {}",
-    "the photo of a {}",
-    "a photo of a clean {}",
-    "a photo of a dirty {}",
-    "a dark photo of the {}",
-    "a photo of my {}",
-    "a photo of the cool {}",
-    "a close-up photo of a {}",
-    "a bright photo of the {}",
-    "a cropped photo of a {}",
-    "a photo of the {}",
-    "a good photo of the {}",
-    "a photo of one {}",
-    "a close-up photo of the {}",
-    "a rendition of the {}",
-    "a photo of the clean {}",
-    "a rendition of a {}",
-    "a photo of a nice {}",
-    "a good photo of a {}",
-    "a photo of the nice {}",
-    "a photo of the small {}",
-    "a photo of the weird {}",
-    "a photo of the large {}",
-    "a photo of a cool {}",
-    "a photo of a small {}",
+    "{}",
 ]
 
 imagenet_style_templates_small = [
@@ -213,7 +220,12 @@ def train(args):
   trainable_params = text_encoder.get_input_embeddings().parameters()
 
   # beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  # optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  print('enable dadaptation.')
+  optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
+  # optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
+  # optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
+
 
   # prepare the dataloader
   # number of DataLoader worker processes: 0 means everything runs in the main process
@@ -227,8 +239,20 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # prepare the lr scheduler
-  lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  # lr_scheduler = diffusers.optimization.get_scheduler(
+  #     args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+
+  # For Adam
+  lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+                                             lr_lambda=[lambda epoch: 1],
+                                             last_epoch=-1,
+                                             verbose=False)
+
+  # For SGD optim
+  # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
+  #                                            lr_lambda=[lambda epoch: 1],
+  #                                            last_epoch=-1,
+  #                                            verbose=False)
 
   # the accelerator apparently takes care of wrapping these for us
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -366,12 +390,16 @@ def train(args):
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        # logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+
+        avr_loss = loss_total / (step+1)
+        logs = {"loss": avr_loss, "dlr0": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
-      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      # logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+      logs = {"loss": avr_loss, "dlr0": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
       progress_bar.set_postfix(**logs)
 
       if global_step >= args.max_train_steps:
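For reference, the patch amounts to three changes: swap the optimizer for dadaptation.DAdaptAdam with lr=1.0 and decoupled weight decay, flatten the LR schedule with a constant LambdaLR so D-Adaptation's own step-size estimate is not modulated, and log d * lr as the effective learning rate ("dlr0"). Below is a minimal standalone sketch of that wiring, assuming only the dadaptation calls shown in the diff; the nn.Embedding is a hypothetical stand-in for text_encoder.get_input_embeddings().

import torch
import torch.optim as optim
import dadaptation

# Hypothetical stand-in for the CLIP token embedding table that the script trains.
embeddings = torch.nn.Embedding(49408, 768)
trainable_params = embeddings.parameters()

# D-Adaptation estimates the step size itself, so lr stays at 1.0 and
# weight decay is decoupled, matching the patched optimizer call.
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)

# A constant LambdaLR keeps the external schedule flat (multiplier of 1 every epoch).
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda epoch: 1])

# The quantity logged as "dlr0" in the patch: the estimated d times the base lr.
dlr0 = optimizer.param_groups[0]['d'] * optimizer.param_groups[0]['lr']
print(f"effective lr (d * lr): {dlr0}")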