1st implementation
This commit is contained in:
parent
261b6790ee
commit
6129c7dd40
@ -22,5 +22,7 @@ fairscale==0.4.13
|
|||||||
tensorflow==2.10.1
|
tensorflow==2.10.1
|
||||||
huggingface-hub==0.12.0
|
huggingface-hub==0.12.0
|
||||||
xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||||
|
# for dadaptation
|
||||||
|
dadaptation
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
.
|
.
|
@ -19,6 +19,8 @@ from diffusers import DDPMScheduler
|
|||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.train_util import DreamBoothDataset, FineTuningDataset
|
from library.train_util import DreamBoothDataset, FineTuningDataset
|
||||||
|
|
||||||
|
import torch.optim as optim
|
||||||
|
import dadaptation
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
return examples[0]
|
return examples[0]
|
||||||
@ -212,10 +214,15 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
optimizer_class = torch.optim.AdamW
|
optimizer_class = torch.optim.AdamW
|
||||||
|
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
# trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||||
|
trainable_params = network.prepare_optimizer_params(None, None)
|
||||||
|
|
||||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
||||||
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
# optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
||||||
|
print('enable dadatation.')
|
||||||
|
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
|
||||||
|
# optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
|
||||||
|
# optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-8,)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
@ -230,10 +237,15 @@ def train(args):
|
|||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
# lr_scheduler = diffusers.optimization.get_scheduler(
|
# lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
lr_scheduler = get_scheduler_fix(
|
# lr_scheduler = get_scheduler_fix(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
# args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
# num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
# num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||||
|
# override lr_scheduler.
|
||||||
|
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
|
||||||
|
lr_lambda=[lambda epoch: 0.5, lambda epoch: 1],
|
||||||
|
last_epoch=-1,
|
||||||
|
verbose=False)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
@ -448,11 +460,14 @@ def train(args):
|
|||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
loss_total += current_loss
|
loss_total += current_loss
|
||||||
avr_loss = loss_total / (step+1)
|
avr_loss = loss_total / (step+1)
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
# logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
# progress_bar.set_postfix(**logs)
|
||||||
|
logs_str = f"loss: {avr_loss:.3f}, dlr: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}"
|
||||||
|
progress_bar.set_postfix_str(logs_str)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||||
|
logs['lr/d*lr'] = optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
@ -545,4 +560,4 @@ if __name__ == '__main__':
|
|||||||
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
Loading…
Reference in New Issue
Block a user