2023-01-09 12:47:07 +00:00
# DreamBooth training
# XXX dropped option: fine_tune
2022-11-07 23:40:34 +00:00
2022-12-05 15:49:02 +00:00
import gc
2022-11-19 13:49:42 +00:00
import time
2022-11-07 23:40:34 +00:00
import argparse
import itertools
import math
import os
from tqdm import tqdm
import torch
from accelerate . utils import set_seed
import diffusers
2023-01-09 12:47:07 +00:00
from diffusers import DDPMScheduler
2022-11-07 23:40:34 +00:00
2023-01-09 12:47:07 +00:00
import library . train_util as train_util
from library . train_util import DreamBoothDataset
2022-11-07 23:40:34 +00:00
2023-01-09 12:47:07 +00:00
def collate_fn(examples):
    """Return the first (and only expected) element of ``examples``.

    The DataLoader in ``train`` is created with ``batch_size=1`` and this
    function as its ``collate_fn``, so each call receives a one-element list.
    The dataset is constructed with the real training batch size, so the
    element is presumably an already assembled batch -- TODO confirm against
    ``DreamBoothDataset``.
    """
    only_item = examples[0]
    return only_item
2022-11-07 23:40:34 +00:00
2023-01-09 12:47:07 +00:00
def train(args):
    """Run DreamBooth training of the U-Net (and optionally the Text Encoder).

    Args:
        args: Parsed command-line namespace carrying the options registered
            in the ``__main__`` block of this script (model, dataset,
            training and saving options plus ``no_token_padding`` and
            ``stop_text_encoder_training``).

    Side effects:
        * ``args.max_train_steps`` is overwritten when ``args.max_train_epochs``
          is given.
        * ``args.stop_text_encoder_training`` is replaced by
          ``args.max_train_steps + 1`` when it was ``None`` ("never stop").
        * Models are saved through ``train_util`` at epoch ends (when
          ``save_every_n_epochs`` is set) and at the end of training.

    Returns:
        None. Returns early, without training, when ``args.debug_dataset``
        is set.

    Raises:
        ImportError: ``use_8bit_adam`` is requested but bitsandbytes cannot
            be imported.
        AssertionError: ``full_fp16`` is requested without
            ``mixed_precision == "fp16"``.
    """
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, False)

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    tokenizer = train_util.load_tokenizer(args)

    train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
                                      tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
                                      args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
                                      args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
    if args.no_token_padding:
        train_dataset.disable_token_padding()
    train_dataset.make_buckets()

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset)
        return

    # prepare the accelerator
    print("prepare accelerator")

    if args.gradient_accumulation_steps > 1:
        print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong")
        print(
            f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル( U-NetおよびText Encoder) の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です")

    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # prepare the dtypes that match the mixed precision setting; cast where needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the models
    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)

    # verify load/save model formats
    if load_stable_diffusion_format:
        src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
        src_diffusers_model_path = None
    else:
        src_stable_diffusion_ckpt = None
        src_diffusers_model_path = args.pretrained_model_name_or_path

    if args.save_model_as is None:
        save_stable_diffusion_format = load_stable_diffusion_format
        use_safetensors = args.use_safetensors
    else:
        save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())

    # build xformers / memory efficient attention into the model
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # prepare for training: pre-compute latents, then move the VAE off the device
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset.cache_latents(vae)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # prepare for training: put the models into the appropriate state
    # None means "never stop"; the concrete step number is resolved further
    # down, once max_train_steps is final (it may be overridden by
    # max_train_epochs).
    train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
    unet.requires_grad_(True)  # added just in case
    text_encoder.requires_grad_(train_text_encoder)
    if not train_text_encoder:
        print("Text Encoder is not trained.")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # prepare the classes needed for training
    print("prepare optimizer, data loader etc.")

    # use 8-bit Adam
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
        print("use 8-bit Adam optimizer")
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    if train_text_encoder:
        trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
    else:
        trainable_params = unet.parameters()

    # betas and weight decay appear to be left at their defaults in both the
    # diffusers DreamBooth and DreamBooth SD, so no options are exposed for now
    optimizer = optimizer_class(trainable_params, lr=args.learning_rate)

    # prepare the dataloader
    # number of DataLoader processes: 0 means loading happens in the main process
    # cpu_count - 1, capped at the requested maximum; os.cpu_count() may
    # return None when the count cannot be determined, so fall back to 1 CPU
    cpu_total = os.cpu_count() or 1
    n_workers = min(args.max_data_loader_n_workers, cpu_total - 1)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

    # compute the number of training steps
    if args.max_train_epochs is not None:
        # max_train_steps counts optimizer updates (global_step only advances
        # when gradients are synchronized), so convert epochs using updates
        # per epoch per process rather than the raw number of batches.
        updates_per_epoch = math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
        args.max_train_steps = args.max_train_epochs * updates_per_epoch
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # Resolve the "never stop" sentinel only now that max_train_steps is final;
    # doing it before the epoch override would freeze the text encoder early.
    if args.stop_text_encoder_training is None:
        args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

    # prepare the lr scheduler
    lr_scheduler = diffusers.optimization.get_scheduler(
        args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)

    # experimental feature: fp16 training including gradients; cast the whole model to fp16
    if args.full_fp16:
        assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)

    # the accelerator apparently takes care of the rest
    if train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    if not train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume from a saved state
    if args.resume is not None:
        print(f"resume training from state: {args.resume}")
        accelerator.load_state(args.resume)

    # compute the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # train
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
    print(f"num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
    print(f"num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"num epochs / epoch数: {num_train_epochs}")
    print(f"batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                    num_train_timesteps=1000, clip_sample=False)

    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth")

    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")

        # train the Text Encoder up to the specified step: state at the start of the epoch
        unet.train()
        # train==True is required to enable gradient_checkpointing
        if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
            text_encoder.train()

        loss_total = 0
        for step, batch in enumerate(train_dataloader):
            # stop training the Text Encoder at the specified step
            if global_step == args.stop_text_encoder_training:
                print(f"stop text encoder training at step {global_step}")
                if not args.gradient_checkpointing:
                    text_encoder.train(False)
                text_encoder.requires_grad_(False)

            with accelerator.accumulate(unet):
                with torch.no_grad():
                    # convert to latents
                    if cache_latents:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                b_size = latents.shape[0]

                # Get the text embedding for conditioning
                with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = train_util.get_hidden_states(
                        args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # weight for each sample
                loss = loss * loss_weights

                loss = loss.mean()  # already a mean, so no need to divide by batch_size

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    if train_text_encoder:
                        params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
                    else:
                        params_to_clip = unet.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
                accelerator.log(logs, step=global_step)

            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"epoch_loss": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        if args.save_every_n_epochs is not None:
            src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
            train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                                  save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)

    is_main_process = accelerator.is_main_process
    if is_main_process:
        unet = unwrap_model(unet)
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # memory is needed after this point, so drop it

    if is_main_process:
        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
        train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
                                              save_dtype, epoch, global_step, text_encoder, unet, vae)
        print("model saved.")
2022-11-07 23:40:34 +00:00
if __name__ == "__main__":
    # Command-line entry point: register every option group on one parser,
    # then hand the parsed namespace straight to train().
    arg_parser = argparse.ArgumentParser()

    # Shared option groups provided by the project's train_util module.
    train_util.add_sd_models_arguments(arg_parser)
    train_util.add_dataset_arguments(arg_parser, True, False)
    train_util.add_training_arguments(arg_parser, True)
    train_util.add_sd_saving_arguments(arg_parser)

    # Options specific to this DreamBooth script.
    arg_parser.add_argument(
        "--no_token_padding",
        action="store_true",
        help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする( Diffusers版DreamBoothと同じ動作) ",
    )
    arg_parser.add_argument(
        "--stop_text_encoder_training",
        type=int,
        default=None,
        help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
    )

    train(arg_parser.parse_args())