diff --git a/README.md b/README.md index d5cb7b8..4d36042 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,9 @@ Then redo the installation instruction within the kohya_ss venv. ## Change history +* 2023/02/16 (v20.7.3) + - Noise offset is recorded to the metadata. Thanks to space-nuko! + - Show the moving average loss to prevent loss jumping in `train_network.py` and `train_db.py`. Thanks to shirayu! * 2023/02/11 (v20.7.2): - `lora_interrogator.py` is added in `networks` folder. See `python networks\lora_interrogator.py -h` for usage. - For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate. diff --git a/train_db.py b/train_db.py index aacfcc8..e4f1e54 100644 --- a/train_db.py +++ b/train_db.py @@ -17,8 +17,6 @@ from diffusers import DDPMScheduler import library.train_util as train_util from library.train_util import DreamBoothDataset -import torch.optim as optim -import dadaptation def collate_fn(examples): return examples[0] @@ -135,16 +133,13 @@ def train(args): trainable_params = unet.parameters() # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - # optimizer = optimizer_class(trainable_params, lr=args.learning_rate) + optimizer = optimizer_class(trainable_params, lr=args.learning_rate) # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) - print('enable dadatation.') - optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0, d0=0.00000001) - # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -155,14 +150,8 @@ def train(args): args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end # lr schedulerを用意する - # lr_scheduler = diffusers.optimization.get_scheduler( - # args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) - - # For Adam - lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, - lr_lambda=[lambda epoch: 1], - last_epoch=-1, - verbose=False) + lr_scheduler = diffusers.optimization.get_scheduler( + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: @@ -217,6 +206,8 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("dreambooth") + loss_list = [] + loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset.set_current_epoch(epoch + 1) @@ -227,7 +218,6 @@ def train(args): if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: text_encoder.train() - loss_total = 0 for step, batch in enumerate(train_dataloader): # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: @@ -244,10 +234,13 @@ def train(args): else: latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 + b_size = latents.shape[0] # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): @@ -299,21 +292,24 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: - # logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} - logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']} + logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} accelerator.log(logs, step=global_step) + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss loss_total += current_loss - avr_loss = loss_total / (step+1) - # logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']} # , "lr": lr_scheduler.get_last_lr()[0]} + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(train_dataloader)} + logs = {"loss/epoch": loss_total / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index 824c298..5983a7e 100644 --- a/train_network.py +++ b/train_network.py @@ -1,5 +1,7 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer +from torch.cuda.amp import autocast +from torch.nn.parallel import DistributedDataParallel as DDP from typing import Optional, Union import importlib import argparse @@ -19,8 +21,6 @@ from diffusers import DDPMScheduler import library.train_util as train_util from library.train_util import DreamBoothDataset, FineTuningDataset -import torch.optim as optim -import dadaptation def collate_fn(examples): return examples[0] @@ -156,7 +156,9 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - + # unnecessary, but work on low-ram device + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -214,15 +216,10 @@ def train(args): else: optimizer_class = torch.optim.AdamW - # trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - trainable_params = network.prepare_optimizer_params(None, None) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - # optimizer = optimizer_class(trainable_params, lr=args.learning_rate) - print('enable dadatation.') - optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0) - # optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6) - # optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6) + optimizer = optimizer_class(trainable_params, lr=args.learning_rate) # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる @@ -237,23 +234,10 @@ def train(args): # lr schedulerを用意する # lr_scheduler = diffusers.optimization.get_scheduler( - # lr_scheduler = get_scheduler_fix( - # args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - # num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - # num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) - # override lr_scheduler. - - # For Adam - lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, - lr_lambda=[lambda epoch: 0.25, lambda epoch: 1], - last_epoch=-1, - verbose=False) - - # For SGD optim - # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, - # lr_lambda=[lambda epoch: 1, lambda epoch: 0.5], - # last_epoch=-1, - # verbose=False) + lr_scheduler = get_scheduler_fix( + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: @@ -278,17 +262,26 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() text_encoder.train() # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) else: unet.eval() text_encoder.eval() + # support DistributedDataParallel + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + network.prepare_grad_etc(text_encoder, unet) if not cache_latents: @@ -360,11 +353,13 @@ def train(args): "ss_max_bucket_reso": train_dataset.max_bucket_reso, "ss_seed": args.seed, "ss_keep_tokens": args.keep_tokens, + "ss_noise_offset": args.noise_offset, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), "ss_tag_frequency": json.dumps(train_dataset.tag_frequency), "ss_bucket_info": json.dumps(train_dataset.bucket_info), - "ss_training_comment": args.training_comment # will not be updated after training + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash() } # uncomment if another network is added @@ -398,6 +393,8 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("network_train") + loss_list = [] + loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset.set_current_epoch(epoch + 1) @@ -406,7 +403,6 @@ def train(args): network.on_epoch_start(text_encoder, unet) - loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): with torch.no_grad(): @@ -425,6 +421,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) @@ -435,7 +434,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training @@ -466,23 +466,25 @@ def train(args): global_step += 1 current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss loss_total += current_loss - avr_loss = loss_total / (step+1) - # logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - # progress_bar.set_postfix(**logs) - logs_str = f"loss: {avr_loss:.3f}, dlr0: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}, dlr1: {optimizer.param_groups[1]['d']*optimizer.param_groups[1]['lr']:.2e}" - progress_bar.set_postfix_str(logs_str) + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) if args.logging_dir is not None: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - logs['lr/d*lr'] = optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr'] accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} + logs = {"loss/epoch": loss_total / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() @@ -568,4 +570,4 @@ if __name__ == '__main__': help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() - train(args) \ No newline at end of file + train(args)