Integrate new kohya sd-script
This commit is contained in:
parent
a1f6438f7b
commit
641a168e55
@ -145,6 +145,7 @@ Then redo the installation instruction within the kohya_ss venv.
|
||||
|
||||
* 2023/02/15 (v20.7.3):
|
||||
- Update upgrade.ps1 script
|
||||
- Integrate new kohya sd-script
|
||||
* 2023/02/11 (v20.7.2):
|
||||
- `lora_interrogator.py` is added in `networks` folder. See `python networks\lora_interrogator.py -h` for usage.
|
||||
- For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate.
|
||||
|
@ -255,6 +255,9 @@ def train(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
|
@ -12,6 +12,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
|
||||
from tqdm import tqdm
|
||||
@ -299,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if self.shuffle_keep_tokens is None:
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
@ -308,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
|
||||
tokens = keep_tokens + tokens
|
||||
@ -1100,6 +1101,13 @@ def addnet_hash_safetensors(b):
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
def get_git_revision_hash() -> str:
|
||||
try:
|
||||
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
|
||||
except:
|
||||
return "(unknown)"
|
||||
|
||||
|
||||
# flash attention forwards and backwards
|
||||
|
||||
# https://arxiv.org/abs/2205.14135
|
||||
@ -1413,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
||||
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
||||
parser.add_argument("--noise_offset", type=float, default=None,
|
||||
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
@ -1620,9 +1630,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
else:
|
||||
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||||
if weight_dtype is not None:
|
||||
# this is required for additional network training
|
||||
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||
|
||||
# bs*3, 77, 768 or 1024
|
||||
@ -1649,6 +1656,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
|
||||
if weight_dtype is not None:
|
||||
# this is required for additional network training
|
||||
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
|
@ -38,9 +38,10 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
verbose_str = "\n"
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
@ -96,6 +97,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
||||
|
||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||
|
||||
if verbose:
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
verbose_str+=f"{block_down_name:76} | "
|
||||
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
|
||||
|
||||
U = U[:, :new_rank]
|
||||
S = S[:new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
@ -113,7 +120,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
||||
U = U.unsqueeze(2).unsqueeze(3)
|
||||
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
if args.device:
|
||||
if device:
|
||||
U = U.to(org_device)
|
||||
Vh = Vh.to(org_device)
|
||||
|
||||
@ -127,6 +134,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
print("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
@ -151,7 +160,7 @@ def resize(args):
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
print("resizing rank...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
||||
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
@ -182,6 +191,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--model", type=str, default=None,
|
||||
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
||||
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
|
@ -233,10 +233,13 @@ def train(args):
|
||||
else:
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
b_size = latents.shape[0]
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
|
@ -1,5 +1,7 @@
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from torch.optim import Optimizer
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from typing import Optional, Union
|
||||
import importlib
|
||||
import argparse
|
||||
@ -154,7 +156,9 @@ def train(args):
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
||||
|
||||
# unnecessary, but work on low-ram device
|
||||
text_encoder.to("cuda")
|
||||
unet.to("cuda")
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
@ -258,17 +262,26 @@ def train(args):
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device)
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
text_encoder.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder.module.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
unet.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
# support DistributedDataParallel
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder = text_encoder.module
|
||||
unet = unet.module
|
||||
network = network.module
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if not cache_latents:
|
||||
@ -344,7 +357,8 @@ def train(args):
|
||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
||||
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
||||
"ss_training_comment": args.training_comment # will not be updated after training
|
||||
"ss_training_comment": args.training_comment, # will not be updated after training
|
||||
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
|
||||
}
|
||||
|
||||
# uncomment if another network is added
|
||||
@ -405,6 +419,9 @@ def train(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
@ -415,7 +432,8 @@ def train(args):
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
with autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
|
@ -320,6 +320,9 @@ def train(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
|
Loading…
Reference in New Issue
Block a user