From 641a168e55f429c79f9114bcdb123a13bc9b2167 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 14 Feb 2023 18:52:08 -0500 Subject: [PATCH] Integrate new kohya sd-script --- README.md | 1 + fine_tune.py | 3 +++ library/train_util.py | 21 ++++++++++++++++----- networks/resize_lora.py | 17 ++++++++++++++--- train_db.py | 5 ++++- train_network.py | 28 +++++++++++++++++++++++----- train_textual_inversion.py | 3 +++ 7 files changed, 64 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index e18d3b8..6165627 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,7 @@ Then redo the installation instruction within the kohya_ss venv. * 2023/02/15 (v20.7.3): - Update upgrade.ps1 script + - Integrate new kohya sd-script * 2023/02/11 (v20.7.2): - `lora_interrogator.py` is added in `networks` folder. See `python networks\lora_interrogator.py -h` for usage. - For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate. diff --git a/fine_tune.py b/fine_tune.py index 5292153..3ba6306 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -255,6 +255,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/library/train_util.py b/library/train_util.py index 24e15d1..415f9b7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import math import os import random import hashlib +import subprocess from io import BytesIO from tqdm import tqdm @@ -299,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.shuffle_keep_tokens is None: if self.shuffle_caption: random.shuffle(tokens) - + tokens = dropout_tags(tokens) else: if len(tokens) > self.shuffle_keep_tokens: @@ -308,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.shuffle_caption: random.shuffle(tokens) - + tokens = dropout_tags(tokens) tokens = keep_tokens + tokens @@ -1100,6 +1101,13 @@ def addnet_hash_safetensors(b): return hash_sha256.hexdigest() +def get_git_revision_hash() -> str: + try: + return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + except: + return "(unknown)" + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 @@ -1413,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + parser.add_argument("--noise_offset", type=float, default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") if support_dreambooth: # DreamBooth training @@ -1620,9 +1630,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024 @@ -1649,6 +1656,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + return encoder_hidden_states diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 7beeb25..271de8e 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -38,9 +38,10 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) -def resize_lora_model(lora_sd, new_rank, save_dtype, device): +def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): network_alpha = None network_dim = None + verbose_str = "\n" CLAMP_QUANTILE = 0.99 @@ -96,6 +97,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device): U, S, Vh = torch.linalg.svd(full_weight_matrix) + if verbose: + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + verbose_str+=f"{block_down_name:76} | " + verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" + U = U[:, :new_rank] S = S[:new_rank] U = U @ torch.diag(S) @@ -113,7 +120,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device): U = U.unsqueeze(2).unsqueeze(3) Vh = Vh.unsqueeze(2).unsqueeze(3) - if args.device: + if device: U = U.to(org_device) Vh = Vh.to(org_device) @@ -127,6 +134,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device): lora_up_weight = None weights_loaded = False + if verbose: + print(verbose_str) print("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -151,7 +160,7 @@ def resize(args): lora_sd, metadata = load_state_dict(args.model, merge_dtype) print("resizing rank...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device) + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) # update metadata if metadata is None: @@ -182,6 +191,8 @@ if __name__ == '__main__': parser.add_argument("--model", type=str, default=None, help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument("--verbose", action="store_true", + help="Display verbose resizing information / rank変更時の詳細情報を出力する") args = parser.parse_args() resize(args) diff --git a/train_db.py b/train_db.py index c210767..4a50dc9 100644 --- a/train_db.py +++ b/train_db.py @@ -233,10 +233,13 @@ def train(args): else: latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 + b_size = latents.shape[0] # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): diff --git a/train_network.py b/train_network.py index bb3159f..1b8046d 100644 --- a/train_network.py +++ b/train_network.py @@ -1,5 +1,7 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer +from torch.cuda.amp import autocast +from torch.nn.parallel import DistributedDataParallel as DDP from typing import Optional, Union import importlib import argparse @@ -154,7 +156,9 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - + # unnecessary, but work on low-ram device + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -258,17 +262,26 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() text_encoder.train() # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) else: unet.eval() text_encoder.eval() + # support DistributedDataParallel + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + network.prepare_grad_etc(text_encoder, unet) if not cache_latents: @@ -344,7 +357,8 @@ def train(args): "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), "ss_tag_frequency": json.dumps(train_dataset.tag_frequency), "ss_bucket_info": json.dumps(train_dataset.bucket_info), - "ss_training_comment": args.training_comment # will not be updated after training + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash() } # uncomment if another network is added @@ -405,6 +419,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) @@ -415,7 +432,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4aa91ee..010bd04 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -320,6 +320,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)