* 2023/02/16 (v20.7.3)
- Noise offset is recorded to the metadata. Thanks to space-nuko! - Show the moving average loss to prevent loss jumping in `train_network.py` and `train_db.py`. Thanks to shirayu!
This commit is contained in:
parent
f9863e3950
commit
674ed88d13
@ -143,6 +143,9 @@ Then redo the installation instruction within the kohya_ss venv.
|
||||
|
||||
## Change history
|
||||
|
||||
* 2023/02/16 (v20.7.3)
|
||||
- Noise offset is recorded to the metadata. Thanks to space-nuko!
|
||||
- Show the moving average loss to prevent loss jumping in `train_network.py` and `train_db.py`. Thanks to shirayu!
|
||||
* 2023/02/11 (v20.7.2):
|
||||
- `lora_interrogator.py` is added in `networks` folder. See `python networks\lora_interrogator.py -h` for usage.
|
||||
- For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate.
|
||||
|
40
train_db.py
40
train_db.py
@ -17,8 +17,6 @@ from diffusers import DDPMScheduler
|
||||
import library.train_util as train_util
|
||||
from library.train_util import DreamBoothDataset
|
||||
|
||||
import torch.optim as optim
|
||||
import dadaptation
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
@ -135,16 +133,13 @@ def train(args):
|
||||
trainable_params = unet.parameters()
|
||||
|
||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
||||
# optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
||||
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||
print('enable dadatation.')
|
||||
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0, d0=0.00000001)
|
||||
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
@ -155,14 +150,8 @@ def train(args):
|
||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||
|
||||
# lr schedulerを用意する
|
||||
# lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
# args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
||||
|
||||
# For Adam
|
||||
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
|
||||
lr_lambda=[lambda epoch: 1],
|
||||
last_epoch=-1,
|
||||
verbose=False)
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
@ -217,6 +206,8 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth")
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset.set_current_epoch(epoch + 1)
|
||||
@ -227,7 +218,6 @@ def train(args):
|
||||
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
|
||||
text_encoder.train()
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
@ -244,10 +234,13 @@ def train(args):
|
||||
else:
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
b_size = latents.shape[0]
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
@ -299,21 +292,24 @@ def train(args):
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
# logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||
logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
|
||||
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step+1)
|
||||
# logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
accelerator.log(logs, step=epoch+1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
@ -1,5 +1,7 @@
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from torch.optim import Optimizer
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from typing import Optional, Union
|
||||
import importlib
|
||||
import argparse
|
||||
@ -19,8 +21,6 @@ from diffusers import DDPMScheduler
|
||||
import library.train_util as train_util
|
||||
from library.train_util import DreamBoothDataset, FineTuningDataset
|
||||
|
||||
import torch.optim as optim
|
||||
import dadaptation
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
@ -156,7 +156,9 @@ def train(args):
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
||||
|
||||
# unnecessary, but work on low-ram device
|
||||
text_encoder.to("cuda")
|
||||
unet.to("cuda")
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
@ -214,15 +216,10 @@ def train(args):
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
# trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||
trainable_params = network.prepare_optimizer_params(None, None)
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||
|
||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
||||
# optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
||||
print('enable dadatation.')
|
||||
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
|
||||
# optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
|
||||
# optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
|
||||
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
@ -237,23 +234,10 @@ def train(args):
|
||||
|
||||
# lr schedulerを用意する
|
||||
# lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
# lr_scheduler = get_scheduler_fix(
|
||||
# args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
# num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
# num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||
# override lr_scheduler.
|
||||
|
||||
# For Adam
|
||||
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
|
||||
lr_lambda=[lambda epoch: 0.25, lambda epoch: 1],
|
||||
last_epoch=-1,
|
||||
verbose=False)
|
||||
|
||||
# For SGD optim
|
||||
# lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
|
||||
# lr_lambda=[lambda epoch: 1, lambda epoch: 0.5],
|
||||
# last_epoch=-1,
|
||||
# verbose=False)
|
||||
lr_scheduler = get_scheduler_fix(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
@ -278,17 +262,26 @@ def train(args):
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device)
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
text_encoder.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder.module.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
unet.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
# support DistributedDataParallel
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder = text_encoder.module
|
||||
unet = unet.module
|
||||
network = network.module
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if not cache_latents:
|
||||
@ -360,11 +353,13 @@ def train(args):
|
||||
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
||||
"ss_seed": args.seed,
|
||||
"ss_keep_tokens": args.keep_tokens,
|
||||
"ss_noise_offset": args.noise_offset,
|
||||
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
||||
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
||||
"ss_training_comment": args.training_comment # will not be updated after training
|
||||
"ss_training_comment": args.training_comment, # will not be updated after training
|
||||
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
||||
}
|
||||
|
||||
# uncomment if another network is added
|
||||
@ -398,6 +393,8 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("network_train")
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset.set_current_epoch(epoch + 1)
|
||||
@ -406,7 +403,6 @@ def train(args):
|
||||
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(network):
|
||||
with torch.no_grad():
|
||||
@ -425,6 +421,9 @@ def train(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
@ -435,7 +434,8 @@ def train(args):
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
with autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
@ -466,23 +466,25 @@ def train(args):
|
||||
global_step += 1
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step+1)
|
||||
# logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
# progress_bar.set_postfix(**logs)
|
||||
logs_str = f"loss: {avr_loss:.3f}, dlr0: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}, dlr1: {optimizer.param_groups[1]['d']*optimizer.param_groups[1]['lr']:.2e}"
|
||||
progress_bar.set_postfix_str(logs_str)
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||
logs['lr/d*lr'] = optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
accelerator.log(logs, step=epoch+1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
@ -568,4 +570,4 @@ if __name__ == '__main__':
|
||||
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
train(args)
|
||||
|
Loading…
Reference in New Issue
Block a user