1st implementation

This commit is contained in:
bmaltais 2023-02-13 21:20:09 -05:00
parent 261b6790ee
commit 6129c7dd40
2 changed files with 26 additions and 9 deletions

View File

@ -22,5 +22,7 @@ fairscale==0.4.13
tensorflow==2.10.1 tensorflow==2.10.1
huggingface-hub==0.12.0 huggingface-hub==0.12.0
xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
# for dadaptation
dadaptation
# for kohya_ss library # for kohya_ss library
. .

View File

@ -19,6 +19,8 @@ from diffusers import DDPMScheduler
import library.train_util as train_util import library.train_util as train_util
from library.train_util import DreamBoothDataset, FineTuningDataset from library.train_util import DreamBoothDataset, FineTuningDataset
import torch.optim as optim
import dadaptation
def collate_fn(examples): def collate_fn(examples):
return examples[0] return examples[0]
@ -212,10 +214,15 @@ def train(args):
else: else:
optimizer_class = torch.optim.AdamW optimizer_class = torch.optim.AdamW
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) # trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
trainable_params = network.prepare_optimizer_params(None, None)
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(trainable_params, lr=args.learning_rate) # optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
print('enable dadatation.')
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
# optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
# optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-8,)
# dataloaderを準備する # dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
@ -230,10 +237,15 @@ def train(args):
# lr schedulerを用意する # lr schedulerを用意する
# lr_scheduler = diffusers.optimization.get_scheduler( # lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix( # lr_scheduler = get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, # args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, # num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) # num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# override lr_scheduler.
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
lr_lambda=[lambda epoch: 0.5, lambda epoch: 1],
last_epoch=-1,
verbose=False)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする # 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16: if args.full_fp16:
@ -448,11 +460,14 @@ def train(args):
current_loss = loss.detach().item() current_loss = loss.detach().item()
loss_total += current_loss loss_total += current_loss
avr_loss = loss_total / (step+1) avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} # logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) # progress_bar.set_postfix(**logs)
logs_str = f"loss: {avr_loss:.3f}, dlr: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}"
progress_bar.set_postfix_str(logs_str)
if args.logging_dir is not None: if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
logs['lr/d*lr'] = optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
@ -545,4 +560,4 @@ if __name__ == '__main__':
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)