v20.4.0:
Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab.
This commit is contained in:
parent
fcad6bfd98
commit
2ca17f69dd
18
README.md
18
README.md
@ -116,7 +116,23 @@ Once you have created the LoRA network you can generate images via auto1111 by i
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
* 2023/01/16 (v20.3.0)
|
* 2023/01/22 (v20.4.0):
|
||||||
|
- Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab.
|
||||||
|
- Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe!
|
||||||
|
- Details of the issue are described in https://github.com/kohya-ss/sd-webui-additional-networks/issues/49 .
|
||||||
|
- The default value is ``1``, scale ``1 / rank (or dimension)``. Set same value as ``network_dim`` for same behavior to old version.
|
||||||
|
- LoRA with a large dimension (rank) seems to require a higher learning rate with ``alpha=1`` (e.g. 1e-3 for 128-dim, still investigating).
|
||||||
|
- For generating images in Web UI, __the latest version of the extension ``sd-webui-additional-networks`` (v0.3.0 or later) is required for the models trained with this release or later.__
|
||||||
|
- Add logging for the learning rate for U-Net and Text Encoder independently, and for running average epoch loss. Thanks to mgz-dev!
|
||||||
|
- Add more metadata such as dataset/reg image dirs, session ID, output name etc... See https://github.com/kohya-ss/sd-scripts/pull/77 for details. Thanks to space-nuko!
|
||||||
|
- __Now the metadata includes the folder name (the basename of the folder contains image files, not fullpath).__ If you do not want it, disable metadata storing with ``--no_metadata`` option.
|
||||||
|
- Add ``--training_comment`` option. You can specify an arbitrary string and refer to it by the extension.
|
||||||
|
|
||||||
|
It seems that the Stable Diffusion web UI now supports image generation using the LoRA model learned in this repository.
|
||||||
|
|
||||||
|
Note: At this time, it appears that models learned with version 0.4.0 are not supported. If you want to use the generation function of the web UI, please continue to use version 0.3.2. Also, it seems that LoRA models for SD2.x are not supported.
|
||||||
|
|
||||||
|
* 2023/01/16 (v20.3.0):
|
||||||
- Fix a part of LoRA modules are not trained when ``gradient_checkpointing`` is enabled.
|
- Fix a part of LoRA modules are not trained when ``gradient_checkpointing`` is enabled.
|
||||||
- Add ``--save_last_n_epochs_state`` option. You can specify how many state folders to keep, apart from how many models to keep. Thanks to shirayu!
|
- Add ``--save_last_n_epochs_state`` option. You can specify how many state folders to keep, apart from how many models to keep. Thanks to shirayu!
|
||||||
- Fix Text Encoder training stops at ``max_train_steps`` even if ``max_train_epochs`` is set in `train_db.py``.
|
- Fix Text Encoder training stops at ``max_train_steps`` even if ``max_train_epochs`` is set in `train_db.py``.
|
||||||
|
@ -1981,7 +1981,6 @@ def main(args):
|
|||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
|
|
||||||
|
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.network_args and i < len(args.network_args):
|
if args.network_args and i < len(args.network_args):
|
||||||
@ -1992,22 +1991,22 @@ def main(args):
|
|||||||
key, value = net_arg.split("=")
|
key, value = net_arg.split("=")
|
||||||
net_kwargs[key] = value
|
net_kwargs[key] = value
|
||||||
|
|
||||||
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
|
|
||||||
if network is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.network_weights and i < len(args.network_weights):
|
if args.network_weights and i < len(args.network_weights):
|
||||||
network_weight = args.network_weights[i]
|
network_weight = args.network_weights[i]
|
||||||
print("load network weights from:", network_weight)
|
print("load network weights from:", network_weight)
|
||||||
|
|
||||||
if os.path.splitext(network_weight)[1] == '.safetensors':
|
if model_util.is_safetensors(network_weight):
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
with safe_open(network_weight, framework="pt") as f:
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
print(f"metadata for: {network_weight}: {metadata}")
|
print(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
network.load_weights(network_weight)
|
network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("No weight. Weight is required.")
|
||||||
|
if network is None:
|
||||||
|
return
|
||||||
|
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
|
|
||||||
@ -2518,16 +2517,14 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
|
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
|
||||||
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
|
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
|
||||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||||
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
||||||
parser.add_argument("--opt_channels_last", action='store_true',
|
parser.add_argument("--opt_channels_last", action='store_true',
|
||||||
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
|
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
||||||
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
||||||
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||||
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||||
help='Hypernetwork weights to load / Hypernetworkの重み')
|
help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||||
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
|
|
||||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
|
||||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||||
|
@ -11,6 +11,7 @@ import glob
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import hashlib
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@ -79,6 +80,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.debug_dataset = debug_dataset
|
self.debug_dataset = debug_dataset
|
||||||
self.random_crop = random_crop
|
self.random_crop = random_crop
|
||||||
self.token_padding_disabled = False
|
self.token_padding_disabled = False
|
||||||
|
self.dataset_dirs_info = {}
|
||||||
|
self.reg_dataset_dirs_info = {}
|
||||||
|
self.enable_bucket = False
|
||||||
|
self.min_bucket_reso = None
|
||||||
|
self.max_bucket_reso = None
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
|
|
||||||
@ -463,6 +469,8 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||||
|
self.min_bucket_reso = min_bucket_reso
|
||||||
|
self.max_bucket_reso = max_bucket_reso
|
||||||
else:
|
else:
|
||||||
self.bucket_resos = [(self.width, self.height)]
|
self.bucket_resos = [(self.width, self.height)]
|
||||||
self.bucket_aspect_ratios = [self.width / self.height]
|
self.bucket_aspect_ratios = [self.width / self.height]
|
||||||
@ -523,6 +531,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||||
self.register_image(info)
|
self.register_image(info)
|
||||||
|
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
print(f"{num_train_images} train images with repeating.")
|
print(f"{num_train_images} train images with repeating.")
|
||||||
self.num_train_images = num_train_images
|
self.num_train_images = num_train_images
|
||||||
|
|
||||||
@ -539,6 +548,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||||
reg_infos.append(info)
|
reg_infos.append(info)
|
||||||
|
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
|
|
||||||
print(f"{num_reg_images} reg images.")
|
print(f"{num_reg_images} reg images.")
|
||||||
if num_train_images < num_reg_images:
|
if num_train_images < num_reg_images:
|
||||||
@ -611,6 +621,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.num_train_images = len(metadata) * dataset_repeats
|
self.num_train_images = len(metadata) * dataset_repeats
|
||||||
self.num_reg_images = 0
|
self.num_reg_images = 0
|
||||||
|
|
||||||
|
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||||
|
|
||||||
# check existence of all npz files
|
# check existence of all npz files
|
||||||
if not self.color_aug:
|
if not self.color_aug:
|
||||||
npz_any = False
|
npz_any = False
|
||||||
@ -653,6 +665,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||||
|
self.min_bucket_reso = min_bucket_reso
|
||||||
|
self.max_bucket_reso = max_bucket_reso
|
||||||
else:
|
else:
|
||||||
self.bucket_resos = [(self.width, self.height)]
|
self.bucket_resos = [(self.width, self.height)]
|
||||||
self.bucket_aspect_ratios = [self.width / self.height]
|
self.bucket_aspect_ratios = [self.width / self.height]
|
||||||
@ -665,6 +679,9 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.bucket_resos.sort()
|
self.bucket_resos.sort()
|
||||||
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
|
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
|
||||||
|
|
||||||
|
self.min_bucket_reso = min([min(reso) for reso in resos])
|
||||||
|
self.max_bucket_reso = max([max(reso) for reso in resos])
|
||||||
|
|
||||||
def image_key_to_npz_file(self, image_key):
|
def image_key_to_npz_file(self, image_key):
|
||||||
base_name = os.path.splitext(image_key)[0]
|
base_name = os.path.splitext(image_key)[0]
|
||||||
npz_file_norm = base_name + '.npz'
|
npz_file_norm = base_name + '.npz'
|
||||||
@ -749,9 +766,9 @@ def default(val, d):
|
|||||||
|
|
||||||
|
|
||||||
def model_hash(filename):
|
def model_hash(filename):
|
||||||
|
"""Old model hash used by stable-diffusion-webui"""
|
||||||
try:
|
try:
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
import hashlib
|
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
|
|
||||||
file.seek(0x100000)
|
file.seek(0x100000)
|
||||||
@ -761,6 +778,18 @@ def model_hash(filename):
|
|||||||
return 'NOFILE'
|
return 'NOFILE'
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_sha256(filename):
|
||||||
|
"""New model hash used by stable-diffusion-webui"""
|
||||||
|
hash_sha256 = hashlib.sha256()
|
||||||
|
blksize = 1024 * 1024
|
||||||
|
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(blksize), b""):
|
||||||
|
hash_sha256.update(chunk)
|
||||||
|
|
||||||
|
return hash_sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
# flash attention forwards and backwards
|
# flash attention forwards and backwards
|
||||||
|
|
||||||
# https://arxiv.org/abs/2205.14135
|
# https://arxiv.org/abs/2205.14135
|
||||||
@ -1029,7 +1058,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||||
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
||||||
|
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
@ -1048,8 +1078,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
|
|
||||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||||
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
parser.add_argument("--max_train_epochs", type=int, default=None,
|
||||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||||
|
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
||||||
|
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||||
|
28
lora_gui.py
28
lora_gui.py
@ -88,6 +88,8 @@ def save_configuration(
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
network_alpha,
|
||||||
|
training_comment,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -175,6 +177,8 @@ def open_configuration(
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
network_alpha,
|
||||||
|
training_comment,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -246,6 +250,8 @@ def train_model(
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
network_alpha,
|
||||||
|
training_comment,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -358,6 +364,9 @@ def train_model(
|
|||||||
run_cmd += f' --resolution={max_resolution}'
|
run_cmd += f' --resolution={max_resolution}'
|
||||||
run_cmd += f' --output_dir="{output_dir}"'
|
run_cmd += f' --output_dir="{output_dir}"'
|
||||||
run_cmd += f' --logging_dir="{logging_dir}"'
|
run_cmd += f' --logging_dir="{logging_dir}"'
|
||||||
|
run_cmd += f' --network_alpha="{network_alpha}"'
|
||||||
|
if not training_comment == '':
|
||||||
|
run_cmd += f' --training_comment="{training_comment}"'
|
||||||
if not stop_text_encoder_training == 0:
|
if not stop_text_encoder_training == 0:
|
||||||
run_cmd += (
|
run_cmd += (
|
||||||
f' --stop_text_encoder_training={stop_text_encoder_training}'
|
f' --stop_text_encoder_training={stop_text_encoder_training}'
|
||||||
@ -518,10 +527,15 @@ def lora_tab(
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_name = gr.Textbox(
|
output_name = gr.Textbox(
|
||||||
label='Model output name',
|
label='Model output name',
|
||||||
placeholder='Name of the model to output',
|
placeholder='(Name of the model to output)',
|
||||||
value='last',
|
value='last',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
|
training_comment = gr.Textbox(
|
||||||
|
label='Training comment',
|
||||||
|
placeholder='(Optional) Add training comment to be included in metadata',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
train_data_dir.change(
|
train_data_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[train_data_dir],
|
inputs=[train_data_dir],
|
||||||
@ -588,11 +602,19 @@ def lora_tab(
|
|||||||
network_dim = gr.Slider(
|
network_dim = gr.Slider(
|
||||||
minimum=1,
|
minimum=1,
|
||||||
maximum=128,
|
maximum=128,
|
||||||
label='Network Dimension',
|
label='Network Rank (Dimension)',
|
||||||
value=8,
|
value=8,
|
||||||
step=1,
|
step=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
|
network_alpha = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=128,
|
||||||
|
label='Network Alpha',
|
||||||
|
value=1,
|
||||||
|
step=1,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
max_resolution = gr.Textbox(
|
max_resolution = gr.Textbox(
|
||||||
label='Max resolution',
|
label='Max resolution',
|
||||||
@ -703,6 +725,8 @@ def lora_tab(
|
|||||||
max_token_length,
|
max_token_length,
|
||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
|
network_alpha,
|
||||||
|
training_comment,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -15,12 +15,13 @@ def main(file):
|
|||||||
|
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if 'lora_up' in key:
|
if 'lora_up' in key or 'lora_down' in key:
|
||||||
values.append((key, sd[key]))
|
values.append((key, sd[key]))
|
||||||
print(f"number of LoRA-up modules: {len(values)}")
|
print(f"number of LoRA modules: {len(values)}")
|
||||||
|
|
||||||
for key, value in values:
|
for key, value in values:
|
||||||
print(f"{key},{torch.mean(torch.abs(value))}")
|
value = value.to(torch.float32)
|
||||||
|
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -44,9 +44,9 @@ def svd(args):
|
|||||||
print(f"loading SD model : {args.model_tuned}")
|
print(f"loading SD model : {args.model_tuned}")
|
||||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
||||||
|
|
||||||
# create LoRA network to extract weights
|
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||||
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o)
|
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
|
||||||
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t)
|
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
|
||||||
assert len(lora_network_o.text_encoder_loras) == len(
|
assert len(lora_network_o.text_encoder_loras) == len(
|
||||||
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||||
|
|
||||||
@ -77,10 +77,10 @@ def svd(args):
|
|||||||
module_t = lora_t.org_module
|
module_t = lora_t.org_module
|
||||||
diff = module_t.weight - module_o.weight
|
diff = module_t.weight - module_o.weight
|
||||||
diff = diff.float()
|
diff = diff.float()
|
||||||
|
|
||||||
if args.device:
|
if args.device:
|
||||||
diff = diff.to(args.device)
|
diff = diff.to(args.device)
|
||||||
|
|
||||||
diffs[lora_name] = diff
|
diffs[lora_name] = diff
|
||||||
|
|
||||||
# make LoRA with svd
|
# make LoRA with svd
|
||||||
@ -116,6 +116,9 @@ def svd(args):
|
|||||||
print(f"LoRA has {len(lora_sd)} weights.")
|
print(f"LoRA has {len(lora_sd)} weights.")
|
||||||
|
|
||||||
for key in list(lora_sd.keys()):
|
for key in list(lora_sd.keys()):
|
||||||
|
if "alpha" in key:
|
||||||
|
continue
|
||||||
|
|
||||||
lora_name = key.split('.')[0]
|
lora_name = key.split('.')[0]
|
||||||
i = 0 if "lora_up" in key else 1
|
i = 0 if "lora_up" in key else 1
|
||||||
|
|
||||||
@ -124,7 +127,7 @@ def svd(args):
|
|||||||
if len(lora_sd[key].size()) == 4:
|
if len(lora_sd[key].size()) == 4:
|
||||||
weights = weights.unsqueeze(2).unsqueeze(3)
|
weights = weights.unsqueeze(2).unsqueeze(3)
|
||||||
|
|
||||||
assert weights.size() == lora_sd[key].size()
|
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
||||||
lora_sd[key] = weights
|
lora_sd[key] = weights
|
||||||
|
|
||||||
# load state dict to LoRA and save it
|
# load state dict to LoRA and save it
|
||||||
@ -135,7 +138,10 @@ def svd(args):
|
|||||||
if dir_name and not os.path.exists(dir_name):
|
if dir_name and not os.path.exists(dir_name):
|
||||||
os.makedirs(dir_name, exist_ok=True)
|
os.makedirs(dir_name, exist_ok=True)
|
||||||
|
|
||||||
lora_network_o.save_weights(args.save_to, save_dtype, {})
|
# minimum metadata
|
||||||
|
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||||
|
|
||||||
|
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
||||||
print(f"LoRA weights are saved to: {args.save_to}")
|
print(f"LoRA weights are saved to: {args.save_to}")
|
||||||
|
|
||||||
|
|
||||||
@ -151,8 +157,8 @@ if __name__ == '__main__':
|
|||||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
|
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
|
||||||
parser.add_argument("--save_to", type=str, default=None,
|
parser.add_argument("--save_to", type=str, default=None,
|
||||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||||
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)")
|
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
svd(args)
|
svd(args)
|
||||||
|
@ -13,9 +13,11 @@ class LoRAModule(torch.nn.Module):
|
|||||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4):
|
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||||
|
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lora_name = lora_name
|
self.lora_name = lora_name
|
||||||
|
self.lora_dim = lora_dim
|
||||||
|
|
||||||
if org_module.__class__.__name__ == 'Conv2d':
|
if org_module.__class__.__name__ == 'Conv2d':
|
||||||
in_dim = org_module.in_channels
|
in_dim = org_module.in_channels
|
||||||
@ -28,6 +30,12 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
||||||
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
||||||
|
|
||||||
|
if type(alpha) == torch.Tensor:
|
||||||
|
alpha = alpha.detach().numpy()
|
||||||
|
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||||
|
self.scale = alpha / self.lora_dim
|
||||||
|
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
||||||
|
|
||||||
# same as microsoft's
|
# same as microsoft's
|
||||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||||
torch.nn.init.zeros_(self.lora_up.weight)
|
torch.nn.init.zeros_(self.lora_up.weight)
|
||||||
@ -41,13 +49,37 @@ class LoRAModule(torch.nn.Module):
|
|||||||
del self.org_module
|
del self.org_module
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
|
|
||||||
def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs):
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||||
if network_dim is None:
|
if network_dim is None:
|
||||||
network_dim = 4 # default
|
network_dim = 4 # default
|
||||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim)
|
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
||||||
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
|
from safetensors.torch import load_file, safe_open
|
||||||
|
weights_sd = load_file(file)
|
||||||
|
else:
|
||||||
|
weights_sd = torch.load(file, map_location='cpu')
|
||||||
|
|
||||||
|
# get dim (rank)
|
||||||
|
network_alpha = None
|
||||||
|
network_dim = None
|
||||||
|
for key, value in weights_sd.items():
|
||||||
|
if network_alpha is None and 'alpha' in key:
|
||||||
|
network_alpha = value
|
||||||
|
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
||||||
|
network_dim = value.size()[0]
|
||||||
|
|
||||||
|
if network_alpha is None:
|
||||||
|
network_alpha = network_dim
|
||||||
|
|
||||||
|
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
||||||
|
network.weights_sd = weights_sd
|
||||||
return network
|
return network
|
||||||
|
|
||||||
|
|
||||||
@ -57,10 +89,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
LORA_PREFIX_UNET = 'lora_unet'
|
LORA_PREFIX_UNET = 'lora_unet'
|
||||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||||
|
|
||||||
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
|
self.alpha = alpha
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
||||||
@ -71,7 +104,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
||||||
lora_name = prefix + '.' + name + '.' + child_name
|
lora_name = prefix + '.' + name + '.' + child_name
|
||||||
lora_name = lora_name.replace('.', '_')
|
lora_name = lora_name.replace('.', '_')
|
||||||
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim)
|
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
@ -149,21 +182,21 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
params = []
|
all_params = []
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if self.text_encoder_loras:
|
||||||
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
||||||
if text_encoder_lr is not None:
|
if text_encoder_lr is not None:
|
||||||
param_data['lr'] = text_encoder_lr
|
param_data['lr'] = text_encoder_lr
|
||||||
params.append(param_data)
|
all_params.append(param_data)
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
param_data = {'params': enumerate_params(self.unet_loras)}
|
param_data = {'params': enumerate_params(self.unet_loras)}
|
||||||
if unet_lr is not None:
|
if unet_lr is not None:
|
||||||
param_data['lr'] = unet_lr
|
param_data['lr'] = unet_lr
|
||||||
params.append(param_data)
|
all_params.append(param_data)
|
||||||
|
|
||||||
return params
|
return all_params
|
||||||
|
|
||||||
def prepare_grad_etc(self, text_encoder, unet):
|
def prepare_grad_etc(self, text_encoder, unet):
|
||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
|
@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if "lora_down" in key:
|
if "lora_down" in key:
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
|
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
||||||
|
|
||||||
# find original module for this lora
|
# find original module for this lora
|
||||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||||
@ -73,14 +74,18 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
down_weight = lora_sd[key]
|
down_weight = lora_sd[key]
|
||||||
up_weight = lora_sd[up_key]
|
up_weight = lora_sd[up_key]
|
||||||
|
|
||||||
|
dim = down_weight.size()[0]
|
||||||
|
alpha = lora_sd.get(alpha_key, dim)
|
||||||
|
scale = alpha / dim
|
||||||
|
|
||||||
# W <- W + U * D
|
# W <- W + U * D
|
||||||
weight = module.weight
|
weight = module.weight
|
||||||
if len(weight.size()) == 2:
|
if len(weight.size()) == 2:
|
||||||
# linear
|
# linear
|
||||||
weight = weight + ratio * (up_weight @ down_weight)
|
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||||
else:
|
else:
|
||||||
# conv2d
|
# conv2d
|
||||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
|
||||||
|
|
||||||
module.weight = torch.nn.Parameter(weight)
|
module.weight = torch.nn.Parameter(weight)
|
||||||
|
|
||||||
@ -88,20 +93,35 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
def merge_lora_models(models, ratios, merge_dtype):
|
def merge_lora_models(models, ratios, merge_dtype):
|
||||||
merged_sd = {}
|
merged_sd = {}
|
||||||
|
|
||||||
|
alpha = None
|
||||||
|
dim = None
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
print(f"loading: {model}")
|
||||||
lora_sd = load_state_dict(model, merge_dtype)
|
lora_sd = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
print(f"merging...")
|
print(f"merging...")
|
||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if key in merged_sd:
|
if 'alpha' in key:
|
||||||
assert merged_sd[key].size() == lora_sd[key].size(
|
if key in merged_sd:
|
||||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
|
||||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
else:
|
||||||
|
alpha = lora_sd[key].detach().numpy()
|
||||||
|
merged_sd[key] = lora_sd[key]
|
||||||
else:
|
else:
|
||||||
merged_sd[key] = lora_sd[key] * ratio
|
if key in merged_sd:
|
||||||
|
assert merged_sd[key].size() == lora_sd[key].size(
|
||||||
|
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||||
|
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
||||||
|
else:
|
||||||
|
if "lora_down" in key:
|
||||||
|
dim = lora_sd[key].size()[0]
|
||||||
|
merged_sd[key] = lora_sd[key] * ratio
|
||||||
|
|
||||||
return merged_sd
|
print(f"dim (rank): {dim}, alpha: {alpha}")
|
||||||
|
if alpha is None:
|
||||||
|
alpha = dim
|
||||||
|
|
||||||
|
return merged_sd, dim, alpha
|
||||||
|
|
||||||
|
|
||||||
def merge(args):
|
def merge(args):
|
||||||
@ -132,7 +152,7 @@ def merge(args):
|
|||||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||||
args.sd_model, 0, 0, save_dtype, vae)
|
args.sd_model, 0, 0, save_dtype, vae)
|
||||||
else:
|
else:
|
||||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
print(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||||
@ -145,7 +165,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--save_precision", type=str, default=None,
|
parser.add_argument("--save_precision", type=str, default=None,
|
||||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||||
parser.add_argument("--precision", type=str, default="float",
|
parser.add_argument("--precision", type=str, default="float",
|
||||||
choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度")
|
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||||
parser.add_argument("--sd_model", type=str, default=None,
|
parser.add_argument("--sd_model", type=str, default=None,
|
||||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
||||||
parser.add_argument("--save_to", type=str, default=None,
|
parser.add_argument("--save_to", type=str, default=None,
|
||||||
|
@ -3,6 +3,9 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@ -18,7 +21,23 @@ def collate_fn(examples):
|
|||||||
return examples[0]
|
return examples[0]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||||
|
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||||
|
|
||||||
|
if args.network_train_unet_only:
|
||||||
|
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
||||||
|
elif args.network_train_text_encoder_only:
|
||||||
|
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
||||||
|
else:
|
||||||
|
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
||||||
|
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
||||||
|
|
||||||
|
return logs
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
session_id = random.randint(0, 2**32)
|
||||||
|
training_started_at = time.time()
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
|
||||||
@ -88,7 +107,8 @@ def train(args):
|
|||||||
key, value = net_arg.split('=')
|
key, value = net_arg.split('=')
|
||||||
net_kwargs[key] = value
|
net_kwargs[key] = value
|
||||||
|
|
||||||
network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
|
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
||||||
|
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
|
||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -206,21 +226,26 @@ def train(args):
|
|||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
|
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
||||||
|
"ss_training_started_at": training_started_at, # unix timestamp
|
||||||
|
"ss_output_name": args.output_name,
|
||||||
"ss_learning_rate": args.learning_rate,
|
"ss_learning_rate": args.learning_rate,
|
||||||
"ss_text_encoder_lr": args.text_encoder_lr,
|
"ss_text_encoder_lr": args.text_encoder_lr,
|
||||||
"ss_unet_lr": args.unet_lr,
|
"ss_unet_lr": args.unet_lr,
|
||||||
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
|
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
|
||||||
"ss_num_reg_images": train_dataset.num_reg_images,
|
"ss_num_reg_images": train_dataset.num_reg_images,
|
||||||
"ss_num_batches_per_epoch": len(train_dataloader),
|
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||||
"ss_num_epochs": num_train_epochs,
|
"ss_num_epochs": num_train_epochs,
|
||||||
"ss_batch_size_per_device": args.train_batch_size,
|
"ss_batch_size_per_device": args.train_batch_size,
|
||||||
"ss_total_batch_size": total_batch_size,
|
"ss_total_batch_size": total_batch_size,
|
||||||
|
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
||||||
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||||
"ss_max_train_steps": args.max_train_steps,
|
"ss_max_train_steps": args.max_train_steps,
|
||||||
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
||||||
"ss_lr_scheduler": args.lr_scheduler,
|
"ss_lr_scheduler": args.lr_scheduler,
|
||||||
"ss_network_module": args.network_module,
|
"ss_network_module": args.network_module,
|
||||||
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
||||||
|
"ss_network_alpha": args.network_alpha, # some networks may not use this value
|
||||||
"ss_mixed_precision": args.mixed_precision,
|
"ss_mixed_precision": args.mixed_precision,
|
||||||
"ss_full_fp16": bool(args.full_fp16),
|
"ss_full_fp16": bool(args.full_fp16),
|
||||||
"ss_v2": bool(args.v2),
|
"ss_v2": bool(args.v2),
|
||||||
@ -232,10 +257,14 @@ def train(args):
|
|||||||
"ss_random_crop": bool(args.random_crop),
|
"ss_random_crop": bool(args.random_crop),
|
||||||
"ss_shuffle_caption": bool(args.shuffle_caption),
|
"ss_shuffle_caption": bool(args.shuffle_caption),
|
||||||
"ss_cache_latents": bool(args.cache_latents),
|
"ss_cache_latents": bool(args.cache_latents),
|
||||||
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
|
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
||||||
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
|
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
||||||
"ss_max_bucket_reso": args.max_bucket_reso,
|
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
||||||
"ss_seed": args.seed
|
"ss_seed": args.seed,
|
||||||
|
"ss_keep_tokens": args.keep_tokens,
|
||||||
|
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||||
|
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||||
|
"ss_training_comment": args.training_comment # will not be updated after training
|
||||||
}
|
}
|
||||||
|
|
||||||
# uncomment if another network is added
|
# uncomment if another network is added
|
||||||
@ -246,6 +275,7 @@ def train(args):
|
|||||||
sd_model_name = args.pretrained_model_name_or_path
|
sd_model_name = args.pretrained_model_name_or_path
|
||||||
if os.path.exists(sd_model_name):
|
if os.path.exists(sd_model_name):
|
||||||
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
||||||
|
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
|
||||||
sd_model_name = os.path.basename(sd_model_name)
|
sd_model_name = os.path.basename(sd_model_name)
|
||||||
metadata["ss_sd_model_name"] = sd_model_name
|
metadata["ss_sd_model_name"] = sd_model_name
|
||||||
|
|
||||||
@ -253,6 +283,7 @@ def train(args):
|
|||||||
vae_name = args.vae
|
vae_name = args.vae
|
||||||
if os.path.exists(vae_name):
|
if os.path.exists(vae_name):
|
||||||
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
||||||
|
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
|
||||||
vae_name = os.path.basename(vae_name)
|
vae_name = os.path.basename(vae_name)
|
||||||
metadata["ss_vae_name"] = vae_name
|
metadata["ss_vae_name"] = vae_name
|
||||||
|
|
||||||
@ -333,20 +364,20 @@ def train(args):
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if args.logging_dir is not None:
|
|
||||||
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
loss_total += current_loss
|
loss_total += current_loss
|
||||||
avr_loss = loss_total / (step+1)
|
avr_loss = loss_total / (step+1)
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||||
accelerator.log(logs, step=epoch+1)
|
accelerator.log(logs, step=epoch+1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@ -417,11 +448,15 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
|
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
|
||||||
parser.add_argument("--network_dim", type=int, default=None,
|
parser.add_argument("--network_dim", type=int, default=None,
|
||||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||||
|
parser.add_argument("--network_alpha", type=float, default=1,
|
||||||
|
help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
|
||||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
|
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
|
||||||
parser.add_argument("--network_train_text_encoder_only", action="store_true",
|
parser.add_argument("--network_train_text_encoder_only", action="store_true",
|
||||||
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
||||||
|
parser.add_argument("--training_comment", type=str, default=None,
|
||||||
|
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
@ -138,7 +138,7 @@ v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマ
|
|||||||
|
|
||||||
## 当リポジトリ内の画像生成スクリプトで生成する
|
## 当リポジトリ内の画像生成スクリプトで生成する
|
||||||
|
|
||||||
gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim(省略可)の各オプションを追加してください。意味は学習時と同様です。
|
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
|
||||||
|
|
||||||
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user