Integrate new kohya sd-script

commit 641a168e55 (parent a1f6438f7b)
Author: bmaltais
Date:   2023-02-14 18:52:08 -05:00

7 changed files with 64 additions and 14 deletions

View File

@@ -145,6 +145,7 @@ Then redo the installation instruction within the kohya_ss venv.
 * 2023/02/15 (v20.7.3):
     - Update upgrade.ps1 script
+    - Integrate new kohya sd-script
 * 2023/02/11 (v20.7.2):
     - `lora_interrogator.py` is added in the `networks` folder. See `python networks\lora_interrogator.py -h` for usage.
         - For LoRAs whose activation word is unknown, this script compares the Text Encoder output with the LoRA applied against the output without it, to find which tokens the LoRA affects. Hopefully you can figure out the activation word from them. LoRAs trained with captions do not seem to be interrogatable.

View File

@@ -255,6 +255,9 @@ def train(args):
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents, device=latents.device)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
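This same noise-offset block is added to each training script touched by this commit. As a standalone illustration of the technique, here is a minimal sketch; the helper name and the 0.1 default are assumptions for illustration, not part of the commit:

```python
import torch

def sample_offset_noise(latents: torch.Tensor, noise_offset: float = 0.1) -> torch.Tensor:
    """Hypothetical helper: Gaussian noise with a per-(sample, channel) mean shift."""
    noise = torch.randn_like(latents)
    # One extra random scalar per sample and channel, broadcast over H and W,
    # shifts the mean of the noise so the model can also learn overall brightness.
    noise += noise_offset * torch.randn(
        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
    )
    return noise
```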

View File

@@ -12,6 +12,7 @@ import math
 import os
 import random
 import hashlib
+import subprocess

 from io import BytesIO
 from tqdm import tqdm
@@ -299,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset):
             if self.shuffle_keep_tokens is None:
                 if self.shuffle_caption:
                     random.shuffle(tokens)

                 tokens = dropout_tags(tokens)
             else:
                 if len(tokens) > self.shuffle_keep_tokens:
@@ -308,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset):
                     if self.shuffle_caption:
                         random.shuffle(tokens)

                     tokens = dropout_tags(tokens)

                     tokens = keep_tokens + tokens
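For orientation, this is roughly what the caption handling above does when `shuffle_keep_tokens` is set. `dropout_tags` is defined elsewhere in the file, so it is stubbed here; the function name and dropout rate below are illustrative assumptions:

```python
import random

def process_tags(tokens, keep_tokens_n, dropout_rate=0.1):
    # The first keep_tokens_n tags stay fixed in place; the rest are shuffled,
    # then each remaining tag survives with probability (1 - dropout_rate).
    keep, rest = tokens[:keep_tokens_n], tokens[keep_tokens_n:]
    random.shuffle(rest)
    rest = [t for t in rest if random.random() >= dropout_rate]
    return keep + rest
```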
@@ -1100,6 +1101,13 @@ def addnet_hash_safetensors(b):
     return hash_sha256.hexdigest()


+def get_git_revision_hash() -> str:
+    try:
+        return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
+    except:
+        return "(unknown)"


 # flash attention forwards and backwards
 # https://arxiv.org/abs/2205.14135
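The helper shells out to git and degrades gracefully when the command fails, e.g. outside a checkout. A sketch of the intended use, mirroring the `ss_sd_scripts_commit_hash` metadata field added further down in this commit:

```python
# Record training provenance in the saved-model metadata; the value is
# "(unknown)" when git or the .git directory is unavailable.
metadata = {
    "ss_sd_scripts_commit_hash": get_git_revision_hash(),
}
```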
@@ -1413,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
                         help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
     parser.add_argument("--lr_warmup_steps", type=int, default=0,
                         help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
+    parser.add_argument("--noise_offset", type=float, default=None,
+                        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")

     if support_dreambooth:
         # DreamBooth training
@@ -1620,9 +1630,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
     else:
         enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
         encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
-        if weight_dtype is not None:
-            # this is required for additional network training
-            encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
         encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

     # bs*3, 77, 768 or 1024
@@ -1649,6 +1656,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
             states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # <EOS>
             encoder_hidden_states = torch.cat(states_list, dim=1)

+    if weight_dtype is not None:
+        # this is required for additional network training
+        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
+
     return encoder_hidden_states
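The net effect of these two hunks is to move the `weight_dtype` cast from before `final_layer_norm` to the very end of the function, so every code path (with or without `clip_skip`, with or without the `<BOS>`/`<EOS>` re-assembly) returns states in the training dtype. A condensed sketch of the resulting order; the helper name and the fp16 example dtype are assumptions:

```python
import torch

def finalize_hidden_states(encoder_hidden_states: torch.Tensor,
                           weight_dtype=torch.float16) -> torch.Tensor:
    # ...clip_skip selection, final_layer_norm, and token re-assembly run first,
    # in the text encoder's native precision...
    if weight_dtype is not None:
        # Cast once, at the end, so additional-network training always
        # receives hidden states in the expected dtype.
        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
    return encoder_hidden_states
```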

View File

@@ -38,9 +38,10 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
         torch.save(model, file_name)


-def resize_lora_model(lora_sd, new_rank, save_dtype, device):
+def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
     network_alpha = None
     network_dim = None
+    verbose_str = "\n"

     CLAMP_QUANTILE = 0.99
@@ -96,6 +97,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
         U, S, Vh = torch.linalg.svd(full_weight_matrix)

+        if verbose:
+            s_sum = torch.sum(torch.abs(S))
+            s_rank = torch.sum(torch.abs(S[:new_rank]))
+            verbose_str += f"{block_down_name:76} | "
+            verbose_str += f"sum(S) retained: {s_rank/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
+
         U = U[:, :new_rank]
         S = S[:new_rank]
         U = U @ torch.diag(S)
@@ -113,7 +120,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
             U = U.unsqueeze(2).unsqueeze(3)
             Vh = Vh.unsqueeze(2).unsqueeze(3)

-        if args.device:
+        if device:
             U = U.to(org_device)
             Vh = Vh.to(org_device)
@@ -127,6 +134,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
         lora_up_weight = None
         weights_loaded = False

+    if verbose:
+        print(verbose_str)
     print("resizing complete")
     return o_lora_sd, network_dim, new_alpha
@@ -151,7 +160,7 @@ def resize(args):
     lora_sd, metadata = load_state_dict(args.model, merge_dtype)

     print("resizing rank...")
-    state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
+    state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)

     # update metadata
     if metadata is None:
@@ -182,6 +191,8 @@ if __name__ == '__main__':
     parser.add_argument("--model", type=str, default=None,
                         help="LoRA model to resize to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Display verbose resizing information / rank変更時の詳細情報を出力する")

     args = parser.parse_args()
     resize(args)
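Taken together, the new `--verbose` flag reports how much singular-value mass survives the truncation, a useful sanity check on the chosen rank. A self-contained sketch of that statistic on a single 2-D weight matrix; the function and variable names below are illustrative, not from the script:

```python
import torch

def truncated_svd_report(W: torch.Tensor, new_rank: int):
    # Full SVD, then keep the top new_rank singular values, as resize_lora_model does.
    U, S, Vh = torch.linalg.svd(W)
    s_sum = torch.sum(torch.abs(S))
    s_rank = torch.sum(torch.abs(S[:new_rank]))
    # Retained mass near 100% suggests the reduced rank loses little information;
    # a large max(S) ratio means the spectrum decays quickly past the cut.
    print(f"sum(S) retained: {s_rank/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}")
    # Fold S into U so (U_r, Vh_r) is a rank-new_rank approximation of W.
    return U[:, :new_rank] @ torch.diag(S[:new_rank]), Vh[:new_rank, :]
```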

View File

@@ -233,10 +233,13 @@ def train(args):
                 else:
                     latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                     latents = latents * 0.18215
+                b_size = latents.shape[0]

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents, device=latents.device)
-                b_size = latents.shape[0]
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                 # Get the text embedding for conditioning
                 with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
View File

@@ -1,5 +1,7 @@
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from torch.optim import Optimizer
+from torch.cuda.amp import autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
 from typing import Optional, Union
 import importlib
 import argparse
@@ -154,7 +156,9 @@ def train(args):
     # モデルを読み込む
     text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)

+    # unnecessary, but work on low-ram device
+    text_encoder.to("cuda")
+    unet.to("cuda")

     # モデルに xformers とか memory efficient attention を組み込む
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -258,17 +262,26 @@ def train(args):
         unet.requires_grad_(False)
         unet.to(accelerator.device, dtype=weight_dtype)
         text_encoder.requires_grad_(False)
-        text_encoder.to(accelerator.device, dtype=weight_dtype)
+        text_encoder.to(accelerator.device)
         if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
             unet.train()
             text_encoder.train()

             # set top parameter requires_grad = True for gradient checkpointing works
-            text_encoder.text_model.embeddings.requires_grad_(True)
+            if type(text_encoder) == DDP:
+                text_encoder.module.text_model.embeddings.requires_grad_(True)
+            else:
+                text_encoder.text_model.embeddings.requires_grad_(True)
         else:
             unet.eval()
             text_encoder.eval()

+        # support DistributedDataParallel
+        if type(text_encoder) == DDP:
+            text_encoder = text_encoder.module
+            unet = unet.module
+            network = network.module
+
         network.prepare_grad_etc(text_encoder, unet)

         if not cache_latents:
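Under `accelerate`, multi-GPU runs hand back models wrapped in `DistributedDataParallel`, which proxies `forward()` but hides other attributes behind `.module`; hence the unwrapping above before touching `text_model.embeddings` or calling `prepare_grad_etc`. A generic sketch of the pattern; the helper name is ours, not the script's:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model):
    # DDP forwards calls but not attribute access, so reach through .module
    # when (and only when) the model is actually wrapped.
    return model.module if isinstance(model, DDP) else model
```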
@@ -344,7 +357,8 @@ def train(args):
             "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
             "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
             "ss_bucket_info": json.dumps(train_dataset.bucket_info),
-            "ss_training_comment": args.training_comment  # will not be updated after training
+            "ss_training_comment": args.training_comment,  # will not be updated after training
+            "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
         }

         # uncomment if another network is added
# uncomment if another network is added # uncomment if another network is added
@@ -405,6 +419,9 @@ def train(args):
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents, device=latents.device)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -415,7 +432,8 @@ def train(args):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with autocast():
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                 if args.v_parameterization:
                     # v-parameterization training
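`torch.cuda.amp.autocast` runs the wrapped forward pass in mixed precision, so the UNet computes in reduced precision where safe. A minimal standalone sketch of the pattern, assuming a CUDA device is available; the toy layer and shapes are placeholders:

```python
import torch
from torch.cuda.amp import autocast

layer = torch.nn.Linear(320, 320).to("cuda")
x = torch.randn(4, 320, device="cuda")

with autocast():
    # Eligible ops run in reduced precision inside the context.
    y = layer(x)

print(y.dtype)  # typically torch.float16 under autocast on CUDA
```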

View File

@@ -320,6 +320,9 @@ def train(args):
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents, device=latents.device)
+        if args.noise_offset:
+            # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+            noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

         # Sample a random timestep for each image
         timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)