Merge branch 'dev' into macos_gui

This commit is contained in:
bmaltais 2023-04-01 15:14:28 -04:00 committed by GitHub
commit ea003eba45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 3484 additions and 1450 deletions

3
.gitignore vendored
View File

@ -239,4 +239,5 @@ fabric.properties
.idea/httpRequests .idea/httpRequests
# Android studio 3.1+ serialized cache file # Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser .idea/caches/build_file_checksums.ser
library/__init__.py

View File

@ -64,36 +64,19 @@ cd kohya_ss
bash ubuntu_setup.sh bash ubuntu_setup.sh
``` ```
then configure accelerate with the same answers as in the Windows instructions when prompted. then configure accelerate with the same answers as in the MacOS instructions when prompted.
### Windows ### Windows
In the terminal, run
Give unrestricted script access to powershell so venv can work: ```
- Run PowerShell as an administrator
- Run `Set-ExecutionPolicy Unrestricted` and answer 'A'
- Close PowerShell
Open a regular user Powershell terminal and run the following commands:
```powershell
git clone https://github.com/bmaltais/kohya_ss.git git clone https://github.com/bmaltais/kohya_ss.git
cd kohya_ss cd kohya_ss
setup.bat
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
``` ```
then configure accelerate with the same answers as in the MacOS instructions when prompted.
### Optional: CUDNN 8.6 ### Optional: CUDNN 8.6
This step is optional but can improve the learning speed for NVIDIA 30X0/40X0 owners. It allows for larger training batch size and faster training speed. This step is optional but can improve the learning speed for NVIDIA 30X0/40X0 owners. It allows for larger training batch size and faster training speed.
@ -125,11 +108,7 @@ Once the commands have completed successfully you should be ready to use the new
When a new release comes out, you can upgrade your repo with the following commands in the root directory: When a new release comes out, you can upgrade your repo with the following commands in the root directory:
```powershell ```powershell
git pull upgrade.bat
.\venv\Scripts\activate
pip install --use-pep517 --upgrade -r requirements.txt
``` ```
Once the commands have completed successfully you should be ready to use the new version. Once the commands have completed successfully you should be ready to use the new version.
@ -213,6 +192,22 @@ This will store your a backup file with your current locally installed pip packa
## Change History ## Change History
* 2023/04/01 (v21.4.0)
- Fix an issue that `merge_lora.py` does not work with the latest version.
- Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights.
- Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`.
- Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
- Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
- See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
- Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported.
- Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option.
- Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec!
- Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option.
* 2023/04/01 (v21.3.9)
- Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496
- Fix issue with WD14 caption script by applying a custom fix to kohya_ss code.
* 2023/03/30 (v21.3.8)
- Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481
* 2023/03/29 (v21.3.7) * 2023/03/29 (v21.3.7)
- Allow for 0.1 increment in Network and Conv alpha values: https://github.com/bmaltais/kohya_ss/pull/471 Thanks to @srndpty - Allow for 0.1 increment in Network and Conv alpha values: https://github.com/bmaltais/kohya_ss/pull/471 Thanks to @srndpty
- Updated Lycoris module version - Updated Lycoris module version

209
XTI_hijack.py Normal file
View File

@ -0,0 +1,209 @@
import torch
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
def unet_forward_XTI(self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
down_i = 0
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
)
down_i += 2
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
# 5. up
up_i = 7
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
upsample_size=upsample_size,
)
up_i += 3
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)
def downblock_forward_XTI(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
output_states = ()
i = 0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
output_states += (hidden_states,)
i += 1
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
def upblock_forward_XTI(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
):
i = 0
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
i += 1
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states

View File

@ -0,0 +1,217 @@
import argparse
import csv
import glob
import os
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
# import library.train_util as train_util
# from wd14 tagger
IMAGE_SIZE = 448
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
def glob_images(directory, base="*"):
img_paths = []
for ext in IMAGE_EXTENSIONS:
if base == "*":
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
img_paths = list(set(img_paths)) # 重複を排除
img_paths.sort()
return img_paths
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
image = image.astype(np.float32)
return image
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
else:
print("using existing wd14 tagger model")
# 画像を読み込む
image_paths = glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
print("loading model and labels")
model = load_model(args.model_dir)
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
# 推論する
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
tag_text = ""
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
if p >= args.thresh and i < len(tags):
tag_text += ", " + tags[i]
if len(tag_text) > 0:
tag_text = tag_text[2:] # 最初の ", " を消す
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(tag_text + '\n')
if args.debug:
print(image_path, tag_text)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
parser.add_argument("--force_download", action='store_true',
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)

View File

@ -95,6 +95,8 @@ import library.train_util as train_util
import tools.original_control_net as original_control_net import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo from tools.original_control_net import ControlNetInfo
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14" TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
@ -491,6 +493,9 @@ class PipelineLike:
# Textual Inversion # Textual Inversion
self.token_replacements = {} self.token_replacements = {}
# XTI
self.token_replacements_XTI = {}
# CLIP guidance # CLIP guidance
self.clip_guidance_scale = clip_guidance_scale self.clip_guidance_scale = clip_guidance_scale
self.clip_image_guidance_scale = clip_image_guidance_scale self.clip_image_guidance_scale = clip_image_guidance_scale
@ -514,15 +519,26 @@ class PipelineLike:
def add_token_replacement(self, target_token_id, rep_token_ids): def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids self.token_replacements[target_token_id] = rep_token_ids
def replace_token(self, tokens): def replace_token(self, tokens, layer=None):
new_tokens = [] new_tokens = []
for token in tokens: for token in tokens:
if token in self.token_replacements: if token in self.token_replacements:
new_tokens.extend(self.token_replacements[token]) replacer_ = self.token_replacements[token]
if layer:
replacer = []
for r in replacer_:
if r in self.token_replacements_XTI:
replacer.append(self.token_replacements_XTI[r][layer])
else:
replacer = replacer_
new_tokens.extend(replacer)
else: else:
new_tokens.append(token) new_tokens.append(token)
return new_tokens return new_tokens
def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
self.token_replacements_XTI[target_token_id] = rep_token_ids
def set_control_nets(self, ctrl_nets): def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets self.control_nets = ctrl_nets
@ -744,14 +760,15 @@ class PipelineLike:
" the batch size of `prompt`." " the batch size of `prompt`."
) )
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( if not self.token_replacements_XTI:
pipe=self, text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
prompt=prompt, pipe=self,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None, prompt=prompt,
max_embeddings_multiples=max_embeddings_multiples, uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
clip_skip=self.clip_skip, max_embeddings_multiples=max_embeddings_multiples,
**kwargs, clip_skip=self.clip_skip,
) **kwargs,
)
if negative_scale is not None: if negative_scale is not None:
_, real_uncond_embeddings, _ = get_weighted_text_embeddings( _, real_uncond_embeddings, _ = get_weighted_text_embeddings(
@ -763,11 +780,47 @@ class PipelineLike:
**kwargs, **kwargs,
) )
if do_classifier_free_guidance: if self.token_replacements_XTI:
if negative_scale is None: text_embeddings_concat = []
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) for layer in [
else: "IN01",
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) "IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
layer=layer,
**kwargs,
)
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
else:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))
text_embeddings = torch.stack(text_embeddings_concat)
else:
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
# CLIP guidanceで使用するembeddingsを取得する # CLIP guidanceで使用するembeddingsを取得する
if self.clip_guidance_scale > 0: if self.clip_guidance_scale > 0:
@ -1675,7 +1728,7 @@ def parse_prompt_attention(text):
return res return res
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None):
r""" r"""
Tokenize a list of prompts and return its tokens with weights of each token. Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included. No padding, starting or ending token is included.
@ -1691,7 +1744,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
# tokenize and discard the starting and the ending token # tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1] token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token) token = pipe.replace_token(token, layer=layer)
text_token += token text_token += token
# copy the weight by length of token # copy the weight by length of token
@ -1807,6 +1860,7 @@ def get_weighted_text_embeddings(
skip_parsing: Optional[bool] = False, skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False, skip_weighting: Optional[bool] = False,
clip_skip=None, clip_skip=None,
layer=None,
**kwargs, **kwargs,
): ):
r""" r"""
@ -1837,11 +1891,11 @@ def get_weighted_text_embeddings(
prompt = [prompt] prompt = [prompt]
if not skip_parsing: if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
if uncond_prompt is not None: if uncond_prompt is not None:
if isinstance(uncond_prompt, str): if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt] uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer)
else: else:
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens] prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
@ -2229,13 +2283,17 @@ def main(args):
if network is None: if network is None:
return return
network.apply_to(text_encoder, unet) if not args.network_merge:
network.apply_to(text_encoder, unet)
if args.opt_channels_last: if args.opt_channels_last:
network.to(memory_format=torch.channels_last) network.to(memory_format=torch.channels_last)
network.to(dtype).to(device) network.to(dtype).to(device)
networks.append(network)
else:
network.merge_to(text_encoder, unet, dtype, device)
networks.append(network)
else: else:
networks = [] networks = []
@ -2289,6 +2347,11 @@ def main(args):
if args.diffusers_xformers: if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# Textual Inversionを処理する # Textual Inversionを処理する
if args.textual_inversion_embeddings: if args.textual_inversion_embeddings:
token_ids_embeds = [] token_ids_embeds = []
@ -2335,6 +2398,71 @@ def main(args):
for token_id, embed in zip(token_ids, embeds): for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed token_embeds[token_id] = embed
if args.XTI_embeddings:
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
token_ids_embeds_XTI = []
for embeds_file in args.XTI_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
if set(data.keys()) != set(XTI_layers):
raise ValueError("NOT XTI")
embeds = torch.concat(list(data.values()))
num_vectors_per_token = data["MID"].size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == num_vectors_per_token
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
# if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
token_strings_XTI = []
for layer_name in XTI_layers:
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
token_ids_embeds_XTI.append((token_ids_XTI, embeds))
for t in token_ids:
t_XTI_dic = {}
for i, layer_name in enumerate(XTI_layers):
t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens
pipe.add_token_replacement_XTI(t, t_XTI_dic)
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds_XTI:
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
# promptを取得する # promptを取得する
if args.from_file is not None: if args.from_file is not None:
print(f"reading prompts from {args.from_file}") print(f"reading prompts from {args.from_file}")
@ -2983,6 +3111,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
) )
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument( parser.add_argument(
"--textual_inversion_embeddings", "--textual_inversion_embeddings",
type=str, type=str,
@ -2990,6 +3119,13 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*", nargs="*",
help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
) )
parser.add_argument(
"--XTI_embeddings",
type=str,
default=None,
nargs="*",
help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
)
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
parser.add_argument( parser.add_argument(
"--max_embeddings_multiples", "--max_embeddings_multiples",

View File

File diff suppressed because it is too large Load Diff

View File

@ -404,6 +404,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.token_padding_disabled = False self.token_padding_disabled = False
self.tag_frequency = {} self.tag_frequency = {}
self.XTI_layers = None
self.token_strings = None
self.enable_bucket = False self.enable_bucket = False
self.bucket_manager: BucketManager = None # not initialized self.bucket_manager: BucketManager = None # not initialized
@ -464,6 +466,10 @@ class BaseDataset(torch.utils.data.Dataset):
def disable_token_padding(self): def disable_token_padding(self):
self.token_padding_disabled = True self.token_padding_disabled = True
def enable_XTI(self, layers=None, token_strings=None):
self.XTI_layers = layers
self.token_strings = token_strings
def add_replacement(self, str_from, str_to): def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to self.replacements[str_from] = str_to
@ -909,9 +915,22 @@ class BaseDataset(torch.utils.data.Dataset):
latents_list.append(latents) latents_list.append(latents)
caption = self.process_caption(subset, image_info.caption) caption = self.process_caption(subset, image_info.caption)
captions.append(caption) if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption)) if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
else:
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)
example = {} example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights) example["loss_weights"] = torch.FloatTensor(loss_weights)
@ -1314,6 +1333,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets: # for dataset in self.datasets:
# dataset.make_buckets() # dataset.make_buckets()
def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1): def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
@ -2617,14 +2640,15 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype): def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
name_or_path = args.pretrained_model_name_or_path name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format: if load_stable_diffusion_format:
print("load StableDiffusion checkpoint") print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
else: else:
# Diffusers model is loaded to CPU
print("load Diffusers pretrained models") print("load Diffusers pretrained models")
try: try:
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)

View File

@ -34,7 +34,7 @@ def caption_images(
return return
print(f'Captioning files in {train_data_dir}...') print(f'Captioning files in {train_data_dir}...')
run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"' run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger_bmaltais.py"'
run_cmd += f' --batch_size="{int(batch_size)}"' run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --thresh="{thresh}"' run_cmd += f' --thresh="{thresh}"'
run_cmd += f' --caption_extension="{caption_extension}"' run_cmd += f' --caption_extension="{caption_extension}"'

View File

@ -13,386 +13,471 @@ from library import train_util
class LoRAModule(torch.nn.Module): class LoRAModule(torch.nn.Module):
""" """
replaces forward method of the original Linear, instead of replacing the original Linear module. replaces forward method of the original Linear, instead of replacing the original Linear module.
""" """
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
""" if alpha == 0 or None, alpha is rank (no scaling). """ """if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__() super().__init__()
self.lora_name = lora_name self.lora_name = lora_name
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels in_dim = org_module.in_channels
out_dim = org_module.out_channels out_dim = org_module.out_channels
else: else:
in_dim = org_module.in_features in_dim = org_module.in_features
out_dim = org_module.out_features out_dim = org_module.out_features
# if limit_rank: # if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim) # self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim: # if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else: # else:
self.lora_dim = lora_dim self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size kernel_size = org_module.kernel_size
stride = org_module.stride stride = org_module.stride
padding = org_module.padding padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else: else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor: if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
# same as microsoft's # same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight) torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier self.multiplier = multiplier
self.org_module = org_module # remove in applying self.org_module = org_module # remove in applying
self.region = None self.region = None
self.region_mask = None self.region_mask = None
def apply_to(self): def apply_to(self):
self.org_forward = self.org_module.forward self.org_forward = self.org_module.forward
self.org_module.forward = self.forward self.org_module.forward = self.forward
del self.org_module del self.org_module
def set_region(self, region): def merge_to(self, sd, dtype, device):
self.region = region # get up/down weight
self.region_mask = None up_weight = sd["lora_up.weight"].to(torch.float).to(device)
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
def forward(self, x): # extract weight from org_module
if self.region is None: org_sd = self.org_module.state_dict()
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale weight = org_sd["weight"].to(torch.float)
# regional LoRA FIXME same as additional-network extension # merge weight
if x.size()[1] % 77 == 0: if len(weight.size()) == 2:
# print(f"LoRA for context: {self.lora_name}") # linear
self.region = None weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# calculate region mask first time # set weight to org_module
if self.region_mask is None: org_sd["weight"] = weight.to(dtype)
if len(x.size()) == 4: self.org_module.load_state_dict(org_sd)
h, w = x.size()[2:4]
else:
seq_len = x.size()[1]
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
h = int(self.region.size()[0] / ratio + .5)
w = seq_len // h
r = self.region.to(x.device) def set_region(self, region):
if r.dtype == torch.bfloat16: self.region = region
r = r.to(torch.float) self.region_mask = None
r = r.unsqueeze(0).unsqueeze(1)
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
r = r.to(x.dtype)
if len(x.size()) == 3: def forward(self, x):
r = torch.reshape(r, (1, x.size()[1], -1)) if self.region is None:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
self.region_mask = r # regional LoRA FIXME same as additional-network extension
if x.size()[1] % 77 == 0:
# print(f"LoRA for context: {self.lora_name}")
self.region = None
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask # calculate region mask first time
if self.region_mask is None:
if len(x.size()) == 4:
h, w = x.size()[2:4]
else:
seq_len = x.size()[1]
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
h = int(self.region.size()[0] / ratio + 0.5)
w = seq_len // h
r = self.region.to(x.device)
if r.dtype == torch.bfloat16:
r = r.to(torch.float)
r = r.unsqueeze(0).unsqueeze(1)
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
r = r.to(x.dtype)
if len(x.size()) == 3:
r = torch.reshape(r, (1, x.size()[1], -1))
self.region_mask = r
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
# extract dim/alpha for conv2d, and block dim # extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get('conv_dim', None) conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get('conv_alpha', None) conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None: if conv_dim is not None:
conv_dim = int(conv_dim) conv_dim = int(conv_dim)
if conv_alpha is None: if conv_alpha is None:
conv_alpha = 1.0 conv_alpha = 1.0
else: else:
conv_alpha = float(conv_alpha) conv_alpha = float(conv_alpha)
""" """
block_dims = kwargs.get("block_dims") block_dims = kwargs.get("block_dims")
block_alphas = None block_alphas = None
if block_dims is not None: if block_dims is not None:
block_dims = [int(d) for d in block_dims.split(',')] block_dims = [int(d) for d in block_dims.split(',')]
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
block_alphas = kwargs.get("block_alphas") block_alphas = kwargs.get("block_alphas")
if block_alphas is None: if block_alphas is None:
block_alphas = [1] * len(block_dims) block_alphas = [1] * len(block_dims)
else: else:
block_alphas = [int(a) for a in block_alphas(',')] block_alphas = [int(a) for a in block_alphas(',')]
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
conv_block_dims = kwargs.get("conv_block_dims") conv_block_dims = kwargs.get("conv_block_dims")
conv_block_alphas = None conv_block_alphas = None
if conv_block_dims is not None: if conv_block_dims is not None:
conv_block_dims = [int(d) for d in conv_block_dims.split(',')] conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
conv_block_alphas = kwargs.get("conv_block_alphas") conv_block_alphas = kwargs.get("conv_block_alphas")
if conv_block_alphas is None: if conv_block_alphas is None:
conv_block_alphas = [1] * len(conv_block_dims) conv_block_alphas = [1] * len(conv_block_dims)
else: else:
conv_block_alphas = [int(a) for a in conv_block_alphas(',')] conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
""" """
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, network = LoRANetwork(
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha) text_encoder,
return network unet,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
if weights_sd is None: if weights_sd is None:
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location='cpu')
# get dim/alpha mapping weights_sd = load_file(file)
modules_dim = {} else:
modules_alpha = {} weights_sd = torch.load(file, map_location="cpu")
for key, value in weights_sd.items():
if '.' not in key:
continue
lora_name = key.split('.')[0] # get dim/alpha mapping
if 'alpha' in key: modules_dim = {}
modules_alpha[lora_name] = value modules_alpha = {}
elif 'lora_down' in key: for key, value in weights_sd.items():
dim = value.size()[0] if "." not in key:
modules_dim[lora_name] = dim continue
# print(lora_name, value.size(), dim)
# support old LoRA without alpha lora_name = key.split(".")[0]
for key in modules_dim.keys(): if "alpha" in key:
if key not in modules_alpha: modules_alpha[lora_name] = value
modules_alpha = modules_dim[key] elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) # support old LoRA without alpha
network.weights_sd = weights_sd for key in modules_dim.keys():
return network if key not in modules_alpha:
modules_alpha = modules_dim[key]
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
network.weights_sd = weights_sd
return network
class LoRANetwork(torch.nn.Module): class LoRANetwork(torch.nn.Module):
# is it possible to apply conv_in and conv_out? # is it possible to apply conv_in and conv_out?
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = "lora_te"
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None: def __init__(
super().__init__() self,
self.multiplier = multiplier text_encoder,
unet,
multiplier=1.0,
lora_dim=4,
alpha=1,
conv_lora_dim=None,
conv_alpha=None,
modules_dim=None,
modules_alpha=None,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim self.lora_dim = lora_dim
self.alpha = alpha self.alpha = alpha
self.conv_lora_dim = conv_lora_dim self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha self.conv_alpha = conv_alpha
if modules_dim is not None: if modules_dim is not None:
print(f"create LoRA network from weights") print(f"create LoRA network from weights")
else: else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
if self.apply_to_conv2d_3x3: if self.apply_to_conv2d_3x3:
if self.conv_alpha is None: if self.conv_alpha is None:
self.conv_alpha = self.alpha self.conv_alpha = self.alpha
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances # create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
loras = [] loras = []
for name, module in root_module.named_modules(): for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules: if module.__class__.__name__ in target_replace_modules:
# TODO get block index here # TODO get block index here
for child_name, child_module in module.named_modules(): for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear" is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d: if is_linear or is_conv2d:
lora_name = prefix + '.' + name + '.' + child_name lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace('.', '_') lora_name = lora_name.replace(".", "_")
if modules_dim is not None: if modules_dim is not None:
if lora_name not in modules_dim: if lora_name not in modules_dim:
continue # no LoRA module in this weights file continue # no LoRA module in this weights file
dim = modules_dim[lora_name] dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name] alpha = modules_alpha[lora_name]
else: else:
if is_linear or is_conv2d_1x1: if is_linear or is_conv2d_1x1:
dim = self.lora_dim dim = self.lora_dim
alpha = self.alpha alpha = self.alpha
elif self.apply_to_conv2d_3x3: elif self.apply_to_conv2d_3x3:
dim = self.conv_lora_dim dim = self.conv_lora_dim
alpha = self.conv_alpha alpha = self.conv_alpha
else: else:
continue continue
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
loras.append(lora) loras.append(lora)
return loras return loras
self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, self.text_encoder_loras = create_modules(
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") )
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
if modules_dim is not None or self.conv_lora_dim is not None: if modules_dim is not None or self.conv_lora_dim is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
self.weights_sd = None self.weights_sd = None
# assertion # assertion
names = set() names = set()
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name) names.add(lora.lora_name)
def set_multiplier(self, multiplier): def set_multiplier(self, multiplier):
self.multiplier = multiplier self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier lora.multiplier = self.multiplier
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open from safetensors.torch import load_file, safe_open
self.weights_sd = load_file(file)
else:
self.weights_sd = torch.load(file, map_location='cpu')
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): self.weights_sd = load_file(file)
if self.weights_sd: else:
weights_has_text_encoder = weights_has_unet = False self.weights_sd = torch.load(file, map_location="cpu")
for key in self.weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
weights_has_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
weights_has_unet = True
if apply_text_encoder is None: def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
apply_text_encoder = weights_has_text_encoder if self.weights_sd:
else: weights_has_text_encoder = weights_has_unet = False
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" for key in self.weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
weights_has_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
weights_has_unet = True
if apply_unet is None: if apply_text_encoder is None:
apply_unet = weights_has_unet apply_text_encoder = weights_has_text_encoder
else: else:
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" assert (
else: apply_text_encoder == weights_has_text_encoder
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
if apply_text_encoder: if apply_unet is None:
print("enable LoRA for text encoder") apply_unet = weights_has_unet
else: else:
self.text_encoder_loras = [] assert (
apply_unet == weights_has_unet
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
else:
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
if apply_unet: if apply_text_encoder:
print("enable LoRA for U-Net") print("enable LoRA for text encoder")
else: else:
self.unet_loras = [] self.text_encoder_loras = []
for lora in self.text_encoder_loras + self.unet_loras: if apply_unet:
lora.apply_to() print("enable LoRA for U-Net")
self.add_module(lora.lora_name, lora) else:
self.unet_loras = []
if self.weights_sd: for lora in self.text_encoder_loras + self.unet_loras:
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) lora.apply_to()
info = self.load_state_dict(self.weights_sd, False) self.add_module(lora.lora_name, lora)
print(f"weights are loaded: {info}")
def enable_gradient_checkpointing(self): if self.weights_sd:
# not supported # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
pass info = self.load_state_dict(self.weights_sd, False)
print(f"weights are loaded: {info}")
def prepare_optimizer_params(self, text_encoder_lr, unet_lr): # TODO refactor to common function with apply_to
def enumerate_params(loras): def merge_to(self, text_encoder, unet, dtype, device):
params = [] assert self.weights_sd is not None, "weights are not loaded"
for lora in loras:
params.extend(lora.parameters())
return params
self.requires_grad_(True) apply_text_encoder = apply_unet = False
all_params = [] for key in self.weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
apply_unet = True
if self.text_encoder_loras: if apply_text_encoder:
param_data = {'params': enumerate_params(self.text_encoder_loras)} print("enable LoRA for text encoder")
if text_encoder_lr is not None: else:
param_data['lr'] = text_encoder_lr self.text_encoder_loras = []
all_params.append(param_data)
if self.unet_loras: if apply_unet:
param_data = {'params': enumerate_params(self.unet_loras)} print("enable LoRA for U-Net")
if unet_lr is not None: else:
param_data['lr'] = unet_lr self.unet_loras = []
all_params.append(param_data)
return all_params for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in self.weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
def prepare_grad_etc(self, text_encoder, unet): def enable_gradient_checkpointing(self):
self.requires_grad_(True) # not supported
pass
def on_epoch_start(self, text_encoder, unet): def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
self.train() def enumerate_params(loras):
params = []
for lora in loras:
params.extend(lora.parameters())
return params
def get_trainable_params(self): self.requires_grad_(True)
return self.parameters() all_params = []
def save_weights(self, file, dtype, metadata): if self.text_encoder_loras:
if metadata is not None and len(metadata) == 0: param_data = {"params": enumerate_params(self.text_encoder_loras)}
metadata = None if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
state_dict = self.state_dict() if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
if dtype is not None: return all_params
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == '.safetensors': def prepare_grad_etc(self, text_encoder, unet):
from safetensors.torch import save_file self.requires_grad_(True)
# Precalculate model hashes to save time on indexing def on_epoch_start(self, text_encoder, unet):
if metadata is None: self.train()
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata) def get_trainable_params(self):
else: return self.parameters()
torch.save(state_dict, file)
@ staticmethod def save_weights(self, file, dtype, metadata):
def set_regions(networks, image): if metadata is not None and len(metadata) == 0:
image = image.astype(np.float32) / 255.0 metadata = None
for i, network in enumerate(networks[:3]):
# NOTE: consider averaging overwrapping area
region = image[:, :, i]
if region.max() == 0:
continue
region = torch.tensor(region)
network.set_region(region)
def set_region(self, region): state_dict = self.state_dict()
for lora in self.unet_loras:
lora.set_region(region) if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
@staticmethod
def set_regions(networks, image):
image = image.astype(np.float32) / 255.0
for i, network in enumerate(networks[:3]):
# NOTE: consider averaging overwrapping area
region = image[:, :, i]
if region.max() == 0:
continue
region = torch.tensor(region)
network.set_region(region)
def set_region(self, region):
for lora in self.unet_loras:
lora.set_region(region)

View File

@ -1,4 +1,3 @@
import math import math
import argparse import argparse
import os import os
@ -9,216 +8,236 @@ import lora
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name) sd = load_file(file_name)
else: else:
sd = torch.load(file_name, map_location='cpu') sd = torch.load(file_name, map_location="cpu")
for key in list(sd.keys()): for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor: if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype) sd[key] = sd[key].to(dtype)
return sd return sd
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None: if dtype is not None:
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name) save_file(model, file_name)
else: else:
torch.save(model, file_name) torch.save(model, file_name)
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
text_encoder.to(merge_dtype) text_encoder.to(merge_dtype)
unet.to(merge_dtype) unet.to(merge_dtype)
# create module map # create module map
name_to_module = {} name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]): for i, root_module in enumerate([text_encoder, unet]):
if i == 0: if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
else: else:
# conv2d 3x3 prefix = lora.LoRANetwork.LORA_PREFIX_UNET
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) target_replace_modules = (
# print(conved.size(), weight.size(), module.stride, module.padding) lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
weight = weight + ratio * conved * scale )
module.weight = torch.nn.Parameter(weight) for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype): def merge_lora_models(models, ratios, merge_dtype):
base_alphas = {} # alpha for merged model base_alphas = {} # alpha for merged model
base_dims = {} base_dims = {}
merged_sd = {} merged_sd = {}
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype) lora_sd = load_state_dict(model, merge_dtype)
# get alpha and dim # get alpha and dim
alphas = {} # alpha for current model alphas = {} # alpha for current model
dims = {} # dims for current model dims = {} # dims for current model
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if "alpha" in key:
lora_module_name = key[:key.rfind(".alpha")] lora_module_name = key[: key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy()) alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas: if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha base_alphas[lora_module_name] = alpha
elif "lora_down" in key: elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")] lora_module_name = key[: key.rfind(".lora_down")]
dim = lora_sd[key].size()[0] dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim dims[lora_module_name] = dim
if lora_module_name not in base_dims: if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim base_dims[lora_module_name] = dim
for lora_module_name in dims.keys(): for lora_module_name in dims.keys():
if lora_module_name not in alphas: if lora_module_name not in alphas:
alpha = dims[lora_module_name] alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas: if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge # merge
print(f"merging...") print(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if "alpha" in key:
continue continue
lora_module_name = key[:key.rfind(".lora_")] lora_module_name = key[: key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name] base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name] alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd: if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size( assert (
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" merged_sd[key].size() == lora_sd[key].size()
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
merged_sd[key] = lora_sd[key] * scale else:
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd # set alpha to sd
for lora_module_name, alpha in base_alphas.items(): for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha" key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha) merged_sd[key] = torch.tensor(alpha)
print("merged model") print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
return merged_sd return merged_sd
def merge(args): def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p): def str_to_dtype(p):
if p == 'float': if p == "float":
return torch.float return torch.float
if p == 'fp16': if p == "fp16":
return torch.float16 return torch.float16
if p == 'bf16': if p == "bf16":
return torch.bfloat16 return torch.bfloat16
return None return None
merge_dtype = str_to_dtype(args.precision) merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision) save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None: if save_dtype is None:
save_dtype = merge_dtype save_dtype = merge_dtype
if args.sd_model is not None: if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}") print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}") print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
args.sd_model, 0, 0, save_dtype, vae) else:
else: state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}") print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') parser.add_argument(
parser.add_argument("--save_precision", type=str, default=None, "--save_precision",
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") type=str,
parser.add_argument("--precision", type=str, default="float", default=None,
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨") choices=[None, "float", "fp16", "bf16"],
parser.add_argument("--sd_model", type=str, default=None, help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") )
parser.add_argument("--save_to", type=str, default=None, parser.add_argument(
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") "--precision",
parser.add_argument("--models", type=str, nargs='*', type=str,
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") default="float",
parser.add_argument("--ratios", type=float, nargs='*', choices=["float", "fp16", "bf16"],
help="ratios for each model / それぞれのLoRAモデルの比率") help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--sd_model",
type=str,
default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
return parser return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
merge(args) merge(args)

View File

@ -25,6 +25,7 @@ timm==0.6.12
huggingface-hub==0.13.0 huggingface-hub==0.13.0
tensorflow==2.10.1 tensorflow==2.10.1
# For locon support # For locon support
lycoris_lora==0.1.4 lycoris-lora @ git+https://github.com/KohakuBlueleaf/LyCORIS.git@c3d925421209a22a60d863ffa3de0b3e7e89f047
# lycoris_lora==0.1.4
# for kohya_ss library # for kohya_ss library
. .

17
setup.bat Normal file
View File

@ -0,0 +1,17 @@
@echo off
IF NOT EXIST venv (
python -m venv venv
) ELSE (
echo venv folder already exists, skipping creation...
)
call .\venv\Scripts\activate.bat
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config

587
setup.sh Executable file
View File

@ -0,0 +1,587 @@
#!/usr/bin/env bash
# This file will be the host environment setup file for all operating systems other than base Windows.
display_help() {
cat <<EOF
Kohya_SS Installation Script for POSIX operating systems.
Usage:
# Specifies custom branch, install directory, and git repo
setup.sh -b dev -d /workspace/kohya_ss -g https://mycustom.repo.tld/custom_fork.git
# Same as example 1, but uses long options
setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git
# Maximum verbosity, fully automated installation in a runpod environment skipping the runpod env checks
setup.sh -vvv --skip-space-check --runpod
Options:
-b BRANCH, --branch=BRANCH Select which branch of kohya to check out on new installs.
-d DIR, --dir=DIR The full path you want kohya_ss installed to.
-g REPO, --git_repo=REPO You can optionally provide a git repo to check out for runpod installation. Useful for custom forks.
-h, --help Show this screen.
-i, --interactive Interactively configure accelerate instead of using default config file.
-n, --no-git-update Do not update kohya_ss repo. No git pull or clone operations.
-p, --public Expose public URL in runpod mode. Won't have an effect in other modes.
-r, --runpod Forces a runpod installation. Useful if detection fails for any reason.
-s, --skip-space-check Skip the 10Gb minimum storage space check.
-v, --verbose Increase verbosity levels up to 3.
EOF
}
# Checks to see if variable is set and non-empty.
# This is defined first, so we can use the function for some default variable values
env_var_exists() {
if [[ -n "${!1}" ]]; then
return 0
else
return 1
fi
}
# Need RUNPOD to have a default value before first access
RUNPOD=false
if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then
RUNPOD=true
fi
# This gets the directory the script is run from so pathing can work relative to the script where needed.
SCRIPT_DIR="$(cd -- $(dirname -- "$0") && pwd)"
# Variables defined before the getopts loop, so we have sane default values.
# Default installation locations based on OS and environment
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
if [ "$RUNPOD" = true ]; then
DIR="/workspace/kohya_ss"
elif [ -d "$SCRIPT_DIR/.git" ]; then
DIR="$SCRIPT_DIR"
elif [ -w "/opt" ]; then
DIR="/opt/kohya_ss"
elif env_var_exists HOME; then
DIR="$HOME/kohya_ss"
else
# The last fallback is simply PWD
DIR="$(PWD)"
fi
else
if [ -d "$SCRIPT_DIR/.git" ]; then
DIR="$SCRIPT_DIR"
elif env_var_exists HOME; then
DIR="$HOME/kohya_ss"
else
# The last fallback is simply PWD
DIR="$(PWD)"
fi
fi
VERBOSITY=2 #Start counting at 2 so that any increase to this will result in a minimum of file descriptor 3. You should leave this alone.
MAXVERBOSITY=6 #The highest verbosity we use / allow to be displayed. Feel free to adjust.
BRANCH="master"
GIT_REPO="https://github.com/bmaltais/kohya_ss.git"
INTERACTIVE=false
PUBLIC=false
SKIP_SPACE_CHECK=false
SKIP_GIT_UPDATE=false
while getopts ":vb:d:g:inprs-:" opt; do
# support long options: https://stackoverflow.com/a/28466267/519360
if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG
opt="${OPTARG%%=*}" # extract long option name
OPTARG="${OPTARG#$opt}" # extract long option argument (may be empty)
OPTARG="${OPTARG#=}" # if long option argument, remove assigning `=`
fi
case $opt in
b | branch) BRANCH="$OPTARG" ;;
d | dir) DIR="$OPTARG" ;;
g | git-repo) GIT_REPO="$OPTARG" ;;
i | interactive) INTERACTIVE=true ;;
n | no-git-update) SKIP_GIT_UPDATE=true ;;
p | public) PUBLIC=true ;;
r | runpod) RUNPOD=true ;;
s | skip-space-check) SKIP_SPACE_CHECK=true ;;
v) ((VERBOSITY = VERBOSITY + 1)) ;;
h) display_help && exit 0 ;;
*) display_help && exit 0 ;;
esac
done
shift $((OPTIND - 1))
# Just in case someone puts in a relative path into $DIR,
# we're going to get the absolute path of that.
if [[ "$DIR" != /* ]] && [[ "$DIR" != ~* ]]; then
DIR="$(
cd "$(dirname "$DIR")" || exit 1
pwd
)/$(basename "$DIR")"
fi
for v in $( #Start counting from 3 since 1 and 2 are standards (stdout/stderr).
seq 3 $VERBOSITY
); do
(("$v" <= "$MAXVERBOSITY")) && eval exec "$v>&2" #Don't change anything higher than the maximum verbosity allowed.
done
for v in $( #From the verbosity level one higher than requested, through the maximum;
seq $((VERBOSITY + 1)) $MAXVERBOSITY
); do
(("$v" > "2")) && eval exec "$v>/dev/null" #Redirect these to bitbucket, provided that they don't match stdout and stderr.
done
# Example of how to use the verbosity levels.
# printf "%s\n" "This message is seen at verbosity level 1 and above." >&3
# printf "%s\n" "This message is seen at verbosity level 2 and above." >&4
# printf "%s\n" "This message is seen at verbosity level 3 and above." >&5
# Debug variable dump at max verbosity
echo "BRANCH: $BRANCH
DIR: $DIR
GIT_REPO: $GIT_REPO
INTERACTIVE: $INTERACTIVE
PUBLIC: $PUBLIC
RUNPOD: $RUNPOD
SKIP_SPACE_CHECK: $SKIP_SPACE_CHECK
VERBOSITY: $VERBOSITY
Script directory is ${SCRIPT_DIR}." >&5
# This must be set after the getopts loop to account for $DIR changes.
PARENT_DIR="$(dirname "${DIR}")"
VENV_DIR="$DIR/venv"
if [ -w "$PARENT_DIR" ] && [ ! -d "$DIR" ]; then
echo "Creating install folder ${DIR}."
mkdir "$DIR"
fi
if [ ! -w "$DIR" ]; then
echo "We cannot write to ${DIR}."
echo "Please ensure the install directory is accurate and you have the correct permissions."
exit 1
fi
# Shared functions
# This checks for free space on the installation drive and returns that in Gb.
size_available() {
local folder
if [ -d "$DIR" ]; then
folder="$DIR"
elif [ -d "$PARENT_DIR" ]; then
folder="$PARENT_DIR"
elif [ -d "$(echo "$DIR" | cut -d "/" -f2)" ]; then
folder="$(echo "$DIR" | cut -d "/" -f2)"
else
echo "We are assuming a root drive install for space-checking purposes."
folder='/'
fi
local FREESPACEINKB
FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')"
echo "Detected available space in Kb: $FREESPACEINKB" >&5
local FREESPACEINGB
FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024))
echo "$FREESPACEINGB"
}
# The expected usage is create_symlinks symlink target_file
create_symlinks() {
echo "Checking symlinks now."
# Next line checks for valid symlink
if [ -L "$1" ]; then
# Check if the linked file exists and points to the expected file
if [ -e "$1" ] && [ "$(readlink "$1")" == "$2" ]; then
echo "$(basename "$1") symlink looks fine. Skipping."
else
if [ -f "$2" ]; then
echo "Broken symlink detected. Recreating $(basename "$1")."
rm "$1" &&
ln -s "$2" "$1"
else
echo "$2 does not exist. Nothing to link."
fi
fi
else
echo "Linking $(basename "$1")."
ln -s "$2" "$1"
fi
}
install_python_dependencies() {
# Switch to local virtual env
echo "Switching to virtual Python environment."
if command -v python3 >/dev/null; then
python3 -m venv "$DIR/venv"
elif command -v python3.10 >/dev/null; then
python3.10 -m venv "$DIR/venv"
else
echo "Valid python3 or python3.10 binary not found."
echo "Cannot proceed with the python steps."
return 1
fi
# Activate the virtual environment
source "$DIR/venv/bin/activate"
# Updating pip if there is one
echo "Checking for pip updates before Python operations."
pip install --upgrade pip >&3
echo "Installing python dependencies. This could take a few minutes as it downloads files."
echo "If this operation ever runs too long, you can rerun this script in verbose mode to check."
case "$OSTYPE" in
"linux-gnu"*) pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 \
--extra-index-url https://download.pytorch.org/whl/cu116 >&3 &&
pip install -U -I --no-deps \
https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/downloadlinux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 ;;
"darwin"*) pip install torch==2.0.0 torchvision==0.15.1 \
-f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 ;;
"cygwin")
:
;;
"msys")
:
;;
esac
if [ "$RUNPOD" = true ]; then
echo "Installing tenssort."
pip install tensorrt >&3
fi
# DEBUG ONLY (Update this version number to whatever PyCharm recommends)
# pip install pydevd-pycharm~=223.8836.43
#This will copy our requirements.txt file out, make the khoya_ss lib a dynamic location then cleanup.
echo "Copying $DIR/requirements.txt to /tmp/requirements_tmp.txt" >&3
echo "Replacing the . for lib to our DIR variable in tmp/requirements_tmp.txt." >&3
awk -v dir="$DIR" '/#.*kohya_ss.*library/{print; getline; sub(/^\.$/, dir)}1' "$DIR/requirements.txt" >/tmp/requirements_tmp.txt
if [ $VERBOSITY == 2 ]; then
python -m pip install --quiet --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3
else
python -m pip install --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3
fi
echo "Removing the temp requirements file."
if [ -f /tmp/requirements_tmp.txt ]; then
rm /tmp/requirements_tmp.txt
fi
if [ -n "$VIRTUAL_ENV" ]; then
if command -v deactivate >/dev/null; then
echo "Exiting Python virtual environment."
deactivate
else
echo "deactivate command not found. Could still be in the Python virtual environment."
fi
fi
}
# Attempt to non-interactively install a default accelerate config file unless specified otherwise.
# Documentation for order of precedence locations for configuration file for automated installation:
# https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations
configure_accelerate() {
echo "Source accelerate config location: $DIR/config_files/accelerate/default_config.yaml" >&3
if [ "$INTERACTIVE" = true ]; then
accelerate config
else
if env_var_exists HF_HOME; then
if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then
mkdir -p "$HF_HOME/accelerate/" &&
echo "Target accelerate config location: $HF_HOME/accelerate/default_config.yaml" >&3
cp "$DIR/config_files/accelerate/default_config.yaml" "$HF_HOME/accelerate/default_config.yaml" &&
echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml"
fi
elif env_var_exists XDG_CACHE_HOME; then
if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then
mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" &&
echo "Target accelerate config location: $XDG_CACHE_HOME/accelerate/default_config.yaml" >&3
cp "$DIR/config_files/accelerate/default_config.yaml" "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" &&
echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml"
fi
elif env_var_exists HOME; then
if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then
mkdir -p "$HOME/.cache/huggingface/accelerate" &&
echo "Target accelerate config location: $HOME/accelerate/default_config.yaml" >&3
cp "$DIR/config_files/accelerate/default_config.yaml" "$HOME/.cache/huggingface/accelerate/default_config.yaml" &&
echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml"
fi
else
echo "Could not place the accelerate configuration file. Please configure manually."
sleep 2
accelerate config
fi
fi
}
# Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected
check_storage_space() {
if [ "$SKIP_SPACE_CHECK" = false ]; then
if [ "$(size_available)" -lt 10 ]; then
echo "You have less than 10Gb of free space. This installation may fail."
MSGTIMEOUT=10 # In seconds
MESSAGE="Continuing in..."
echo "Press control-c to cancel the installation."
for ((i = MSGTIMEOUT; i >= 0; i--)); do
printf "\r${MESSAGE} %ss. " "${i}"
sleep 1
done
fi
fi
}
# These are the git operations that will run to update or clone the repo
update_kohya_ss() {
if [ "$SKIP_GIT_UPDATE" = false ]; then
if command -v git >/dev/null; then
# First, we make sure there are no changes that need to be made in git, so no work is lost.
if [ "$(git -C "$DIR" status --porcelain=v1 2>/dev/null | wc -l)" -gt 0 ] &&
echo "These files need to be committed or discarded: " >&4 &&
git -C "$DIR" status >&4; then
echo "There are changes that need to be committed or discarded in the repo in $DIR."
echo "Commit those changes or run this script with -n to skip git operations entirely."
exit 1
fi
echo "Attempting to clone $GIT_REPO."
if [ ! -d "$DIR/.git" ]; then
echo "Cloning and switching to $GIT_REPO:$BRANCH" >&4
git -C "$PARENT_DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3
git -C "$DIR" switch "$BRANCH" >&4
else
echo "git repo detected. Attempting to update repository instead."
echo "Updating: $GIT_REPO"
git -C "$DIR" pull "$GIT_REPO" "$BRANCH" >&3
if ! git -C "$DIR" switch "$BRANCH" >&4; then
echo "Branch $BRANCH did not exist. Creating it." >&4
git -C "$DIR" switch -c "$BRANCH" >&4
fi
fi
else
echo "You need to install git."
echo "Rerun this after installing git or run this script with -n to skip the git operations."
fi
else
echo "Skipping git operations."
fi
}
# Start OS-specific detection and work
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
# Check if root or sudo
root=false
if [ "$EUID" = 0 ]; then
root=true
elif command -v id >/dev/null && [ "$(id -u)" = 0 ]; then
root=true
elif [ "$UID" = 0 ]; then
root=true
fi
get_distro_name() {
local line
if [ -f /etc/os-release ]; then
# We search for the line starting with ID=
# Then we remove the ID= prefix to get the name itself
line="$(grep -Ei '^ID=' /etc/os-release)"
echo "Raw detected os-release distro line: $line" >&5
line=${line##*=}
echo "$line"
return 0
elif command -v python >/dev/null; then
line="$(python -mplatform)"
echo "$line"
return 0
elif command -v python3 >/dev/null; then
line="$(python3 -mplatform)"
echo "$line"
return 0
else
line="None"
echo "$line"
return 1
fi
}
# We search for the line starting with ID_LIKE=
# Then we remove the ID_LIKE= prefix to get the name itself
# This is the "type" of distro. For example, Ubuntu returns "debian".
get_distro_family() {
local line
if [ -f /etc/os-release ]; then
if grep -Eiq '^ID_LIKE=' /etc/os-release >/dev/null; then
line="$(grep -Ei '^ID_LIKE=' /etc/os-release)"
echo "Raw detected os-release distro family line: $line" >&5
line=${line##*=}
echo "$line"
return 0
else
line="None"
echo "$line"
return 1
fi
else
line="None"
echo "$line"
return 1
fi
}
check_storage_space
update_kohya_ss
distro=get_distro_name
family=get_distro_family
echo "Raw detected distro string: $distro" >&4
echo "Raw detected distro family string: $family" >&4
echo "Installing Python TK if not found on the system."
if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then
echo "Ubuntu detected."
if [ $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") = 0 ]; then
if [ "$root" = true ]; then
apt update -y >&3 && apt install -y python3-tk >&3
else
echo "This script needs to be run as root or via sudo to install packages."
exit 1
fi
else
echo "Python TK found! Skipping install!"
fi
elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then
echo "Redhat or Redhat base detected."
if ! rpm -qa | grep -qi python3-tkinter; then
if [ "$root" = true ]; then
dnf install python3-tkinter -y >&3
else
echo "This script needs to be run as root or via sudo to install packages."
exit 1
fi
fi
elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then
echo "Arch Linux or Arch base detected."
if ! pacman -Qi tk >/dev/null; then
if [ "$root" = true ]; then
pacman --noconfirm -S tk >&3
else
echo "This script needs to be run as root or via sudo to install packages."
exit 1
fi
fi
elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then
echo "OpenSUSE detected."
if ! rpm -qa | grep -qi python-tk; then
if [ "$root" = true ]; then
zypper install -y python-tk >&3
else
echo "This script needs to be run as root or via sudo to install packages."
exit 1
fi
fi
elif [ "$distro" = "None" ] || [ "$family" = "None" ]; then
if [ "$distro" = "None" ]; then
echo "We could not detect your distribution of Linux. Please file a bug report on github with the contents of your /etc/os-release file."
fi
if [ "$family" = "None" ]; then
echo "We could not detect the family of your Linux distribution. Please file a bug report on github with the contents of your /etc/os-release file."
fi
fi
install_python_dependencies
# We need just a little bit more setup for non-interactive environments
if [ "$RUNPOD" = true ]; then
# Symlink paths
libnvinfer_plugin_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7"
libnvinfer_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7"
libcudart_symlink="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0"
#Target file paths
libnvinfer_plugin_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8"
libnvinfer_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.8"
libcudart_target="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12"
echo "Checking symlinks now."
create_symlinks "$libnvinfer_plugin_symlink" "$libnvinfer_plugin_target"
create_symlinks "$libnvinfer_symlink" "$libnvinfer_target"
create_symlinks "$libcudart_symlink" "$libcudart_target"
if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/tensorrt/"
else
echo "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/ not found; not linking library."
fi
if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/"
else
echo "${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ not found; not linking library."
fi
configure_accelerate
# This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete.
if command -v bash >/dev/null; then
if [ "$PUBLIC" = false ]; then
bash "$DIR"/gui.sh
else
bash "$DIR"/gui.sh --share
fi
else
# This shouldn't happen, but we're going to try to help.
if [ "$PUBLIC" = false ]; then
sh "$DIR"/gui.sh
else
sh "$DIR"/gui.sh --share
fi
fi
fi
echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start."
echo "Please note if you'd like to expose your public server you need to run ./gui.sh --share"
elif [[ "$OSTYPE" == "darwin"* ]]; then
# The initial setup script to prep the environment on macOS
# xformers has been omitted as that is for Nvidia GPUs only
if ! command -v brew >/dev/null; then
echo "Please install homebrew first. This is a requirement for the remaining setup."
echo "You can find that here: https://brew.sh"
#shellcheck disable=SC2016
echo 'The "brew" command should be in $PATH to be detected.'
exit 1
fi
check_storage_space
# Install base python packages
echo "Installing Python 3.10 if not found."
if ! brew ls --versions python@3.10 >/dev/null; then
echo "Installing Python 3.10."
brew install python@3.10 >&3
else
echo "Python 3.10 found!"
fi
echo "Installing Python-TK 3.10 if not found."
if ! brew ls --versions python-tk@3.10 >/dev/null; then
echo "Installing Python TK 3.10."
brew install python-tk@3.10 >&3
else
echo "Python Tkinter 3.10 found!"
fi
update_kohya_ss
if ! install_python_dependencies; then
echo "You may need to install Python. The command for this is brew install python@3.10."
fi
configure_accelerate
echo -e "Setup finished! Run ./gui.sh to start."
elif [[ "$OSTYPE" == "cygwin" ]]; then
# Cygwin is a standalone suite of Linux utilies on Windows
echo "This hasn't been validated on cygwin yet."
elif [[ "$OSTYPE" == "msys" ]]; then
# MinGW has the msys environment which is a standalone suite of Linux utilies on Windows
# "git bash" on Windows may also be detected as msys.
echo "This hasn't been validated in msys (mingw) on Windows yet."
fi

80
tools/merge_lycoris.py Normal file
View File

@ -0,0 +1,80 @@
import os
import sys
import argparse
import torch
from lycoris.utils import merge_loha, merge_locon
from lycoris.kohya_model_utils import (
load_models_from_stable_diffusion_checkpoint,
save_stable_diffusion_checkpoint,
load_file
)
import gradio as gr
def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight):
base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
if lycoris_model.rsplit('.', 1)[-1] == 'safetensors':
lyco = load_file(lycoris_model)
else:
lyco = torch.load(lycoris_model)
algo = None
for key in lyco:
if 'hada' in key:
algo = 'loha'
break
elif 'lora_up' in key:
algo = 'lora'
break
else:
raise NotImplementedError('Cannot find the algo for this lycoris model file.')
dtype_str = dtype.replace('fp', 'float').replace('bf', 'bfloat')
dtype = {
'float': torch.float,
'float16': torch.float16,
'float32': torch.float32,
'float64': torch.float64,
'bfloat': torch.bfloat16,
'bfloat16': torch.bfloat16,
}.get(dtype_str, None)
if dtype is None:
raise ValueError(f'Cannot Find the dtype "{dtype}"')
if algo == 'loha':
merge_loha(base, lyco, weight, device)
elif algo == 'lora':
merge_locon(base, lyco, weight, device)
save_stable_diffusion_checkpoint(
is_v2, output_name,
base[0], base[2],
None, 0, 0, dtype,
base[1]
)
return output_name
def main():
iface = gr.Interface(
fn=merge_models,
inputs=[
gr.inputs.Textbox(label="Base Model Path"),
gr.inputs.Textbox(label="Lycoris Model Path"),
gr.inputs.Textbox(label="Output Model Path", default='./out.pt'),
gr.inputs.Checkbox(label="Is base model SD V2?", default=False),
gr.inputs.Textbox(label="Device", default='cpu'),
gr.inputs.Dropdown(choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], label="Dtype", default='float'),
gr.inputs.Number(label="Weight", default=1.0)
],
outputs=gr.outputs.Textbox(label="Merged Model Path"),
title="Model Merger",
description="Merge Lycoris and Stable Diffusion models",
)
iface.launch()
if __name__ == '__main__':
main()

View File

@ -25,7 +25,19 @@ for requirement in requirements:
try: try:
pkg_resources.require(requirement) pkg_resources.require(requirement)
except pkg_resources.DistributionNotFound: except pkg_resources.DistributionNotFound:
missing_requirements.append(requirement) # Check if the requirement contains a VCS URL
if "@" in requirement:
# If it does, split the requirement into two parts: the package name and the VCS URL
package_name, vcs_url = requirement.split("@", 1)
# Use pip to install the package from the VCS URL
os.system(f"pip install -e {vcs_url}")
# Try to require the package again
try:
pkg_resources.require(package_name)
except pkg_resources.DistributionNotFound:
missing_requirements.append(requirement)
else:
missing_requirements.append(requirement)
except pkg_resources.VersionConflict as e: except pkg_resources.VersionConflict as e:
wrong_version_requirements.append((requirement, str(e.req), e.dist.version)) wrong_version_requirements.append((requirement, str(e.req), e.dist.version))

View File

@ -1,30 +1,31 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse import argparse
import gc import gc
import importlib
import json
import math import math
import os import os
import random import random
import time import time
import json
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm
import torch import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import library.config_ml_util as config_util
import library.custom_train_functions as custom_train_functions
import library.train_util as train_util import library.train_util as train_util
from library.config_ml_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.custom_train_functions import apply_snr_weight
from library.train_util import ( from library.train_util import (
DreamBoothDataset, DreamBoothDataset,
) )
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
# TODO 他のスクリプトと共通化する # TODO 他のスクリプトと共通化する
@ -126,12 +127,25 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む # モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) for pi in range(accelerator.state.num_processes):
# TODO: modify other training scripts as well
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
# work on low-ram device
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@ -188,7 +202,7 @@ def train(args):
# dataloaderを準備する # dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
@ -555,9 +569,9 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

View File

@ -0,0 +1,644 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
print(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# Convert the init_word to token_id
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == args.num_vectors_per_token
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
token_strings_XTI = []
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
for layer_name in XTI_layers:
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
print(f"tokens are added (XTI): {token_ids_XTI}")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_ids is not None:
for i, token_id in enumerate(token_ids_XTI):
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids_XTI, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print("use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
for tmpl in templates:
captions.append(tmpl.format(replace_to))
train_dataset_group.add_replacement("", captions)
if args.num_vectors_per_token > 1:
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
else:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
else:
unet.eval()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = torch.stack(
[
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
for s in torch.split(input_ids, 1, dim=1)
]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )
# end of epoch
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
print("model saved.")
def save_weights(file, updated_embs, save_dtype):
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
updated_embs = updated_embs.chunk(16)
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
state_dict = {}
for i, layer_name in enumerate(XTI_layers):
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
# if save_dtype is not None:
# for key in list(state_dict.keys()):
# v = state_dict[key]
# v = v.detach().clone().to("cpu").to(save_dtype)
# state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file) # can be loaded in Web UI
def load_weights(file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
raise ValueError(f"NOT XTI: {file}")
if len(data.values()) != 16:
raise ValueError(f"NOT XTI: {file}")
emb = torch.concat([x for x in data.values()])
return emb
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="pt",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
parser.add_argument(
"--token_string",
type=str,
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",
action="store_true",
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
)
parser.add_argument(
"--use_style_template",
action="store_true",
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

16
upgrade.bat Normal file
View File

@ -0,0 +1,16 @@
@echo off
:: Check if there are any changes that need to be committed
git status --short
if %errorlevel%==1 (
echo There are changes that need to be committed. Please stash or undo your changes before running this script.
exit
)
:: Pull the latest changes from the remote repository
git pull
:: Activate the virtual environment
call .\venv\Scripts\activate.baT
:: Upgrade the required packages
pip install --upgrade -r requirements.txt