from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Optional, Union
import importlib
import argparse
import gc
import math
import os
import random
import time
import json

from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
from library.train_util import DreamBoothDataset, FineTuningDataset


def collate_fn(examples):
  return examples[0]
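# NOTE: each item yielded by the dataset is already a fully collated batch (the dataset
# performs bucketing and batching internally using args.train_batch_size), so the
# DataLoader below is created with batch_size=1 and collate_fn() simply unwraps that
# single pre-built batch.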


def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
  logs = {"loss/current": current_loss, "loss/average": avr_loss}

  if args.network_train_unet_only:
    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
  elif args.network_train_text_encoder_only:
    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
  else:
    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be the same as the text encoder LR

  return logs
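# Example of the dict returned above when both the U-Net and the text encoder are trained
# (loss and learning-rate values are hypothetical):
#   {"loss/current": 0.123, "loss/average": 0.118, "lr/textencoder": 5e-5, "lr/unet": 1e-4}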


# Monkeypatch: newer get_scheduler() function overriding the current version of diffusers.optimization.get_scheduler.
# Code is taken from https://github.com/huggingface/diffusers diffusers.optimization, commit d87cc15977b87160c30abaace3894e802ad9e1e6,
# which is a newer release of diffusers than currently packaged with sd-scripts.
# This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and adopted in sd-scripts.
def get_scheduler_fix(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
):
  """
  Unified API to get any scheduler from its name.

  Args:
      name (`str` or `SchedulerType`):
          The name of the scheduler to use.
      optimizer (`torch.optim.Optimizer`):
          The optimizer that will be used during training.
      num_warmup_steps (`int`, *optional*):
          The number of warmup steps to do. This is not required by all schedulers (hence the argument being
          optional), the function will raise an error if it's unset and the scheduler type requires it.
      num_training_steps (`int`, *optional*):
          The number of training steps to do. This is not required by all schedulers (hence the argument being
          optional), the function will raise an error if it's unset and the scheduler type requires it.
      num_cycles (`int`, *optional*):
          The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
      power (`float`, *optional*, defaults to 1.0):
          Power factor. See the `POLYNOMIAL` scheduler.
      last_epoch (`int`, *optional*, defaults to -1):
          The index of the last epoch when resuming training.
  """
  name = SchedulerType(name)
  schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
  if name == SchedulerType.CONSTANT:
    return schedule_func(optimizer)

  # All other schedulers require `num_warmup_steps`
  if num_warmup_steps is None:
    raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

  if name == SchedulerType.CONSTANT_WITH_WARMUP:
    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

  # All other schedulers require `num_training_steps`
  if num_training_steps is None:
    raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

  if name == SchedulerType.COSINE_WITH_RESTARTS:
    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
    )

  if name == SchedulerType.POLYNOMIAL:
    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
    )

  return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
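# Example usage (hypothetical values), mirroring the call made in train() below:
#   lr_scheduler = get_scheduler_fix("cosine_with_restarts", optimizer,
#                                    num_warmup_steps=100, num_training_steps=10000, num_cycles=4)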


def train(args):
  session_id = random.randint(0, 2**32)
  training_started_at = time.time()

  train_util.verify_training_args(args)
  train_util.prepare_dataset_args(args, True)

  cache_latents = args.cache_latents
  use_dreambooth_method = args.in_json is None

  if args.seed is not None:
    set_seed(args.seed)

  tokenizer = train_util.load_tokenizer(args)

  # prepare the dataset
  if use_dreambooth_method:
    print("Use DreamBooth method.")
    train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
                                      tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
                                      args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
                                      args.bucket_reso_steps, args.bucket_no_upscale,
                                      args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
                                      args.random_crop, args.debug_dataset)
  else:
    print("Train with captions.")
    train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
                                      tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
                                      args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
                                      args.bucket_reso_steps, args.bucket_no_upscale,
                                      args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
                                      args.dataset_repeats, args.debug_dataset)

  # set caption/tag dropout rates for the training data
  # (--caption_dropout_rate and --caption_tag_dropout_rate are probabilities in 0~1.0;
  #  --caption_dropout_every_n_epochs drops all captions on every Nth epoch;
  #  tags protected by --keep_tokens are never dropped)
  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)

  train_dataset.make_buckets()

  if args.debug_dataset:
    train_util.debug_dataset(train_dataset)
    return
  if len(train_dataset) == 0:
    print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
    return

  # prepare the accelerator
  print("prepare accelerator")
  accelerator, unwrap_model = train_util.prepare_accelerator(args)

  # prepare dtypes that match the mixed precision setting and cast to them as needed
  weight_dtype, save_dtype = train_util.prepare_dtype(args)

  # load the models
  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)

  # work on low-RAM device
  if args.lowram:
    text_encoder.to("cuda")
    unet.to("cuda")

  # apply xformers or memory efficient attention to the model
  train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

  # prepare for training
  if cache_latents:
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.requires_grad_(False)
    vae.eval()
    with torch.no_grad():
      train_dataset.cache_latents(vae)
    vae.to("cpu")
    if torch.cuda.is_available():
      torch.cuda.empty_cache()
    gc.collect()

  # prepare the network
  print("import network module:", args.network_module)
  network_module = importlib.import_module(args.network_module)

  net_kwargs = {}
  if args.network_args is not None:
    for net_arg in args.network_args:
      key, value = net_arg.split('=')
      net_kwargs[key] = value
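  # Example (hypothetical keys/values): passing --network_args "dropout=0.1" "rank_ratio=0.5"
  # yields net_kwargs == {"dropout": "0.1", "rank_ratio": "0.5"}; values stay strings and
  # their interpretation is left entirely to the imported network module.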

  # if a new network is added in the future, add if ~ then blocks for each network (;'∀')
  network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
  if network is None:
    return

  if args.network_weights is not None:
    print("load network weights from:", args.network_weights)
    network.load_weights(args.network_weights)

  train_unet = not args.network_train_text_encoder_only
  train_text_encoder = not args.network_train_unet_only
  network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

  if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
    text_encoder.gradient_checkpointing_enable()
    network.enable_gradient_checkpointing()  # may have no effect

  # prepare the classes needed for training
  print("prepare optimizer, data loader etc.")

  # use 8-bit Adam if requested
  if args.use_8bit_adam:
    try:
      import bitsandbytes as bnb
    except ImportError:
      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
    print("use 8-bit Adam optimizer")
    optimizer_class = bnb.optim.AdamW8bit
  elif args.use_lion_optimizer:
    try:
      import lion_pytorch
    except ImportError:
      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
    print("use Lion optimizer")
    optimizer_class = lion_pytorch.Lion
  else:
    optimizer_class = torch.optim.AdamW

  optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__

  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)

  # beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD,
  # so the corresponding options are omitted for now
  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
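  # Usage note: with `--use_8bit_adam` the optimizer is bnb.optim.AdamW8bit, with
  # `--use_lion_optimizer` it is lion_pytorch.Lion, and with neither flag it falls back
  # to plain torch.optim.AdamW; in all cases only `lr` is passed explicitly here.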

  # prepare the DataLoader
  # number of DataLoader worker processes: 0 means loading runs in the main process
  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number
  train_dataloader = torch.utils.data.DataLoader(
      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

  # compute the number of training steps
  if args.max_train_epochs is not None:
    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # prepare the lr scheduler
  # lr_scheduler = diffusers.optimization.get_scheduler(
  lr_scheduler = get_scheduler_fix(
      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # experimental feature: perform fp16 training including gradients; cast the whole model to fp16
  if args.full_fp16:
    assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
    print("enable full fp16 training.")
    network.to(weight_dtype)

  # the accelerator apparently takes care of device placement and wrapping for us
  if train_unet and train_text_encoder:
    unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler)
  elif train_unet:
    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, network, optimizer, train_dataloader, lr_scheduler)
  elif train_text_encoder:
    text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, network, optimizer, train_dataloader, lr_scheduler)
  else:
    network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        network, optimizer, train_dataloader, lr_scheduler)

  unet.requires_grad_(False)
  unet.to(accelerator.device, dtype=weight_dtype)
  text_encoder.requires_grad_(False)
  text_encoder.to(accelerator.device)
  if args.gradient_checkpointing:  # according to the TI example in Diffusers, train() is required
    unet.train()
    text_encoder.train()

    # set requires_grad = True on the top-level parameters so that gradient checkpointing works
    if type(text_encoder) == DDP:
      text_encoder.module.text_model.embeddings.requires_grad_(True)
    else:
      text_encoder.text_model.embeddings.requires_grad_(True)
  else:
    unet.eval()
    text_encoder.eval()

  # support DistributedDataParallel
  if type(text_encoder) == DDP:
    text_encoder = text_encoder.module
    unet = unet.module
    network = network.module

  network.prepare_grad_etc(text_encoder, unet)

  if not cache_latents:
    vae.requires_grad_(False)
    vae.eval()
    vae.to(accelerator.device, dtype=weight_dtype)

  # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling with fp16
  if args.full_fp16:
    train_util.patch_accelerator_for_fp16_training(accelerator)

  # resume from a saved state
  if args.resume is not None:
    print(f"resume training from state: {args.resume}")
    accelerator.load_state(args.resume)

  # compute the number of epochs
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
  num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
  if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
    args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
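  # Worked example with hypothetical numbers: 500 dataloader batches per epoch and
  # gradient_accumulation_steps=2 give ceil(500/2)=250 update steps per epoch; with
  # max_train_steps=1600 this yields ceil(1600/250)=7 epochs.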

  # start training
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
  print("running training / 学習開始")
  print(f"num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
  print(f"num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
  print(f"num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
  print(f"num epochs / epoch数: {num_train_epochs}")
  print(f"batch size per device / バッチサイズ: {args.train_batch_size}")
  print(f"total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
  print(f"gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
  print(f"total optimization steps / 学習ステップ数: {args.max_train_steps}")

  metadata = {
      "ss_session_id": session_id,                       # random integer indicating which group of epochs the model came from
      "ss_training_started_at": training_started_at,     # unix timestamp
      "ss_output_name": args.output_name,
      "ss_learning_rate": args.learning_rate,
      "ss_text_encoder_lr": args.text_encoder_lr,
      "ss_unet_lr": args.unet_lr,
      "ss_num_train_images": train_dataset.num_train_images,          # includes repeating
      "ss_num_reg_images": train_dataset.num_reg_images,
      "ss_num_batches_per_epoch": len(train_dataloader),
      "ss_num_epochs": num_train_epochs,
      "ss_batch_size_per_device": args.train_batch_size,
      "ss_total_batch_size": total_batch_size,
      "ss_gradient_checkpointing": args.gradient_checkpointing,
      "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
      "ss_max_train_steps": args.max_train_steps,
      "ss_lr_warmup_steps": args.lr_warmup_steps,
      "ss_lr_scheduler": args.lr_scheduler,
      "ss_network_module": args.network_module,
      "ss_network_dim": args.network_dim,                # None means default because another network than LoRA may have another default dim
      "ss_network_alpha": args.network_alpha,            # some networks may not use this value
      "ss_mixed_precision": args.mixed_precision,
      "ss_full_fp16": bool(args.full_fp16),
      "ss_v2": bool(args.v2),
      "ss_resolution": args.resolution,
      "ss_clip_skip": args.clip_skip,
      "ss_max_token_length": args.max_token_length,
      "ss_color_aug": bool(args.color_aug),
      "ss_flip_aug": bool(args.flip_aug),
      "ss_random_crop": bool(args.random_crop),
      "ss_shuffle_caption": bool(args.shuffle_caption),
      "ss_cache_latents": bool(args.cache_latents),
      "ss_enable_bucket": bool(train_dataset.enable_bucket),
      "ss_min_bucket_reso": train_dataset.min_bucket_reso,
      "ss_max_bucket_reso": train_dataset.max_bucket_reso,
      "ss_seed": args.seed,
      "ss_keep_tokens": args.keep_tokens,
      "ss_noise_offset": args.noise_offset,
      "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
      "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
      "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
      "ss_bucket_info": json.dumps(train_dataset.bucket_info),
      "ss_training_comment": args.training_comment,      # will not be updated after training
      "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
      "ss_optimizer": optimizer_name
  }

  # uncomment if another network is added
  # for key, value in net_kwargs.items():
  #   metadata["ss_arg_" + key] = value

  if args.pretrained_model_name_or_path is not None:
    sd_model_name = args.pretrained_model_name_or_path
    if os.path.exists(sd_model_name):
      metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
      metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
      sd_model_name = os.path.basename(sd_model_name)
    metadata["ss_sd_model_name"] = sd_model_name

  if args.vae is not None:
    vae_name = args.vae
    if os.path.exists(vae_name):
      metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
      metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
      vae_name = os.path.basename(vae_name)
    metadata["ss_vae_name"] = vae_name

  metadata = {k: str(v) for k, v in metadata.items()}

  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
  global_step = 0

  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                  num_train_timesteps=1000, clip_sample=False)

  if accelerator.is_main_process:
    accelerator.init_trackers("network_train")

  loss_list = []
  loss_total = 0.0
  for epoch in range(num_train_epochs):
    print(f"epoch {epoch+1}/{num_train_epochs}")
    train_dataset.set_current_epoch(epoch + 1)

    metadata["ss_epoch"] = str(epoch + 1)

    network.on_epoch_start(text_encoder, unet)

    for step, batch in enumerate(train_dataloader):
      with accelerator.accumulate(network):
        with torch.no_grad():
          if "latents" in batch and batch["latents"] is not None:
            latents = batch["latents"].to(accelerator.device)
          else:
            # convert the images to latents
            latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215
          b_size = latents.shape[0]

        with torch.set_grad_enabled(train_text_encoder):
          # Get the text embedding for conditioning
          input_ids = batch["input_ids"].to(accelerator.device)
          encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents, device=latents.device)
        if args.noise_offset:
          # https://www.crosslabs.org//blog/diffusion-with-offset-noise
          # add a per-sample, per-channel constant offset to the noise
          noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
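        # Conceptually, add_noise mixes the clean latents with the sampled noise as
        #   sqrt(alpha_cumprod_t) * latents + sqrt(1 - alpha_cumprod_t) * noise
        # for the sampled timestep t (standard DDPM forward process).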

        # Predict the noise residual
        with autocast():
          noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        if args.v_parameterization:
          # v-parameterization training: the target is the velocity rather than the noise
          target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
          target = noise

        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
        loss = loss.mean([1, 2, 3])

        loss_weights = batch["loss_weights"]      # per-sample weight
        loss = loss * loss_weights

        loss = loss.mean()                        # already the mean, so no need to divide by batch_size

        accelerator.backward(loss)
        if accelerator.sync_gradients:
          params_to_clip = network.get_trainable_params()
          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)

      # Checks if the accelerator has performed an optimization step behind the scenes
      if accelerator.sync_gradients:
        progress_bar.update(1)
        global_step += 1

      current_loss = loss.detach().item()
      if epoch == 0:
        loss_list.append(current_loss)
      else:
        # maintain a rolling average over the last epoch's worth of step losses
        loss_total -= loss_list[step]
        loss_list[step] = current_loss
      loss_total += current_loss
      avr_loss = loss_total / len(loss_list)
      logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
      progress_bar.set_postfix(**logs)

      if args.logging_dir is not None:
        logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
        accelerator.log(logs, step=global_step)

      if global_step >= args.max_train_steps:
        break

    if args.logging_dir is not None:
      logs = {"loss/epoch": loss_total / len(loss_list)}
      accelerator.log(logs, step=epoch+1)

    accelerator.wait_for_everyone()

    if args.save_every_n_epochs is not None:
      model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name

      def save_func():
        ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)
        print(f"saving checkpoint: {ckpt_file}")
        unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)

      def remove_old_func(old_epoch_no):
        old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
        old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
        if os.path.exists(old_ckpt_file):
          print(f"removing old checkpoint: {old_ckpt_file}")
          os.remove(old_ckpt_file)

      saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
      if saving and args.save_state:
        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

    # end of epoch

  metadata["ss_epoch"] = str(num_train_epochs)

  is_main_process = accelerator.is_main_process
  if is_main_process:
    network = unwrap_model(network)

  accelerator.end_training()

  if args.save_state:
    train_util.save_state_on_train_end(args, accelerator)

  del accelerator                   # delete this to free memory, since memory is needed afterwards

  if is_main_process:
    os.makedirs(args.output_dir, exist_ok=True)
    model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
    ckpt_name = model_name + '.' + args.save_model_as
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    print(f"save trained model to {ckpt_file}")
    network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
    print("model saved.")


if __name__ == '__main__':
  parser = argparse.ArgumentParser()

  train_util.add_sd_models_arguments(parser)
  train_util.add_dataset_arguments(parser, True, True, True)
  train_util.add_training_arguments(parser, True)

  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
  parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
                      help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")

  parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
  parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
  parser.add_argument("--lr_scheduler_power", type=float, default=1,
                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

  parser.add_argument("--network_weights", type=str, default=None,
                      help="pretrained weights for network / 学習するネットワークの初期重み")
  parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
  parser.add_argument("--network_dim", type=int, default=None,
                      help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
  parser.add_argument("--network_alpha", type=float, default=1,
                      help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRAの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
  parser.add_argument("--network_args", type=str, default=None, nargs='*',
                      help='additional arguments for network (key=value) / ネットワークへの追加の引数')
  parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
  parser.add_argument("--network_train_text_encoder_only", action="store_true",
                      help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
  parser.add_argument("--training_comment", type=str, default=None,
                      help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")

  args = parser.parse_args()
  train(args)