Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab.
This commit is contained in:
bmaltais 2023-01-22 10:18:00 -05:00
parent fcad6bfd98
commit 2ca17f69dd
10 changed files with 227 additions and 63 deletions

View File

@ -116,7 +116,23 @@ Once you have created the LoRA network you can generate images via auto1111 by i
## Change history ## Change history
* 2023/01/16 (v20.3.0) * 2023/01/22 (v20.4.0):
- Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab.
- Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe!
- Details of the issue are described in https://github.com/kohya-ss/sd-webui-additional-networks/issues/49 .
- The default value is ``1``, scale ``1 / rank (or dimension)``. Set same value as ``network_dim`` for same behavior to old version.
- LoRA with a large dimension (rank) seems to require a higher learning rate with ``alpha=1`` (e.g. 1e-3 for 128-dim, still investigating). 
- For generating images in Web UI, __the latest version of the extension ``sd-webui-additional-networks`` (v0.3.0 or later) is required for the models trained with this release or later.__
- Add logging for the learning rate for U-Net and Text Encoder independently, and for running average epoch loss. Thanks to mgz-dev!
- Add more metadata such as dataset/reg image dirs, session ID, output name etc... See https://github.com/kohya-ss/sd-scripts/pull/77 for details. Thanks to space-nuko!
- __Now the metadata includes the folder name (the basename of the folder contains image files, not fullpath).__ If you do not want it, disable metadata storing with ``--no_metadata`` option.
- Add ``--training_comment`` option. You can specify an arbitrary string and refer to it by the extension.
It seems that the Stable Diffusion web UI now supports image generation using the LoRA model learned in this repository.
Note: At this time, it appears that models learned with version 0.4.0 are not supported. If you want to use the generation function of the web UI, please continue to use version 0.3.2. Also, it seems that LoRA models for SD2.x are not supported.
* 2023/01/16 (v20.3.0):
- Fix a part of LoRA modules are not trained when ``gradient_checkpointing`` is enabled. - Fix a part of LoRA modules are not trained when ``gradient_checkpointing`` is enabled.
- Add ``--save_last_n_epochs_state`` option. You can specify how many state folders to keep, apart from how many models to keep. Thanks to shirayu! - Add ``--save_last_n_epochs_state`` option. You can specify how many state folders to keep, apart from how many models to keep. Thanks to shirayu!
- Fix Text Encoder training stops at ``max_train_steps`` even if ``max_train_epochs`` is set in `train_db.py``. - Fix Text Encoder training stops at ``max_train_steps`` even if ``max_train_epochs`` is set in `train_db.py``.

View File

@ -1981,7 +1981,6 @@ def main(args):
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@ -1992,22 +1991,22 @@ def main(args):
key, value = net_arg.split("=") key, value = net_arg.split("=")
net_kwargs[key] = value net_kwargs[key] = value
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
if args.network_weights and i < len(args.network_weights): if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i] network_weight = args.network_weights[i]
print("load network weights from:", network_weight) print("load network weights from:", network_weight)
if os.path.splitext(network_weight)[1] == '.safetensors': if model_util.is_safetensors(network_weight):
from safetensors.torch import safe_open from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f: with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
if metadata is not None: if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}") print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight) network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return
network.apply_to(text_encoder, unet) network.apply_to(text_encoder, unet)
@ -2518,16 +2517,14 @@ if __name__ == '__main__':
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
parser.add_argument("--diffusers_xformers", action='store_true', parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用するHypernetwork利用不可') help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用するHypernetwork利用不可')
parser.add_argument("--opt_channels_last", action='store_true', parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannles lastを指定し最適化する') help='set channels last option to model / モデルにchannels lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, nargs='*', parser.add_argument("--network_module", type=str, default=None, nargs='*',
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, nargs='*', parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み') help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_args", type=str, default=None, nargs='*', parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数') help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

View File

@ -11,6 +11,7 @@ import glob
import math import math
import os import os
import random import random
import hashlib
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@ -79,6 +80,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.debug_dataset = debug_dataset self.debug_dataset = debug_dataset
self.random_crop = random_crop self.random_crop = random_crop
self.token_padding_disabled = False self.token_padding_disabled = False
self.dataset_dirs_info = {}
self.reg_dataset_dirs_info = {}
self.enable_bucket = False
self.min_bucket_reso = None
self.max_bucket_reso = None
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@ -463,6 +469,8 @@ class DreamBoothDataset(BaseDataset):
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso) (self.width, self.height), min_bucket_reso, max_bucket_reso)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
else: else:
self.bucket_resos = [(self.width, self.height)] self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height] self.bucket_aspect_ratios = [self.width / self.height]
@ -523,6 +531,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path) info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info) self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.") print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images self.num_train_images = num_train_images
@ -539,6 +548,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path) info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info) reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.") print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images: if num_train_images < num_reg_images:
@ -611,6 +621,8 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0 self.num_reg_images = 0
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
# check existence of all npz files # check existence of all npz files
if not self.color_aug: if not self.color_aug:
npz_any = False npz_any = False
@ -653,6 +665,8 @@ class FineTuningDataset(BaseDataset):
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso) (self.width, self.height), min_bucket_reso, max_bucket_reso)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
else: else:
self.bucket_resos = [(self.width, self.height)] self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height] self.bucket_aspect_ratios = [self.width / self.height]
@ -665,6 +679,9 @@ class FineTuningDataset(BaseDataset):
self.bucket_resos.sort() self.bucket_resos.sort()
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos] self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
self.min_bucket_reso = min([min(reso) for reso in resos])
self.max_bucket_reso = max([max(reso) for reso in resos])
def image_key_to_npz_file(self, image_key): def image_key_to_npz_file(self, image_key):
base_name = os.path.splitext(image_key)[0] base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + '.npz' npz_file_norm = base_name + '.npz'
@ -749,9 +766,9 @@ def default(val, d):
def model_hash(filename): def model_hash(filename):
"""Old model hash used by stable-diffusion-webui"""
try: try:
with open(filename, "rb") as file: with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256() m = hashlib.sha256()
file.seek(0x100000) file.seek(0x100000)
@ -761,6 +778,18 @@ def model_hash(filename):
return 'NOFILE' return 'NOFILE'
def calculate_sha256(filename):
"""New model hash used by stable-diffusion-webui"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
# flash attention forwards and backwards # flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135 # https://arxiv.org/abs/2205.14135
@ -1029,7 +1058,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--save_every_n_epochs", type=int, default=None, parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_state", action="store_true", parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@ -1048,8 +1078,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします") parser.add_argument("--max_train_epochs", type=int, default=None,
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります") help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true", parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする") help="enable gradient checkpointing / grandient checkpointingを有効にする")

View File

@ -88,6 +88,8 @@ def save_configuration(
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha,
training_comment,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -175,6 +177,8 @@ def open_configuration(
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha,
training_comment,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -246,6 +250,8 @@ def train_model(
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha,
training_comment,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -358,6 +364,9 @@ def train_model(
run_cmd += f' --resolution={max_resolution}' run_cmd += f' --resolution={max_resolution}'
run_cmd += f' --output_dir="{output_dir}"' run_cmd += f' --output_dir="{output_dir}"'
run_cmd += f' --logging_dir="{logging_dir}"' run_cmd += f' --logging_dir="{logging_dir}"'
run_cmd += f' --network_alpha="{network_alpha}"'
if not training_comment == '':
run_cmd += f' --training_comment="{training_comment}"'
if not stop_text_encoder_training == 0: if not stop_text_encoder_training == 0:
run_cmd += ( run_cmd += (
f' --stop_text_encoder_training={stop_text_encoder_training}' f' --stop_text_encoder_training={stop_text_encoder_training}'
@ -518,10 +527,15 @@ def lora_tab(
with gr.Row(): with gr.Row():
output_name = gr.Textbox( output_name = gr.Textbox(
label='Model output name', label='Model output name',
placeholder='Name of the model to output', placeholder='(Name of the model to output)',
value='last', value='last',
interactive=True, interactive=True,
) )
training_comment = gr.Textbox(
label='Training comment',
placeholder='(Optional) Add training comment to be included in metadata',
interactive=True,
)
train_data_dir.change( train_data_dir.change(
remove_doublequote, remove_doublequote,
inputs=[train_data_dir], inputs=[train_data_dir],
@ -588,11 +602,19 @@ def lora_tab(
network_dim = gr.Slider( network_dim = gr.Slider(
minimum=1, minimum=1,
maximum=128, maximum=128,
label='Network Dimension', label='Network Rank (Dimension)',
value=8, value=8,
step=1, step=1,
interactive=True, interactive=True,
) )
network_alpha = gr.Slider(
minimum=1,
maximum=128,
label='Network Alpha',
value=1,
step=1,
interactive=True,
)
with gr.Row(): with gr.Row():
max_resolution = gr.Textbox( max_resolution = gr.Textbox(
label='Max resolution', label='Max resolution',
@ -703,6 +725,8 @@ def lora_tab(
max_token_length, max_token_length,
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha,
training_comment,
] ]
button_open_config.click( button_open_config.click(

View File

@ -15,12 +15,13 @@ def main(file):
keys = list(sd.keys()) keys = list(sd.keys())
for key in keys: for key in keys:
if 'lora_up' in key: if 'lora_up' in key or 'lora_down' in key:
values.append((key, sd[key])) values.append((key, sd[key]))
print(f"number of LoRA-up modules: {len(values)}") print(f"number of LoRA modules: {len(values)}")
for key, value in values: for key, value in values:
print(f"{key},{torch.mean(torch.abs(value))}") value = value.to(torch.float32)
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -44,9 +44,9 @@ def svd(args):
print(f"loading SD model : {args.model_tuned}") print(f"loading SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
# create LoRA network to extract weights # create LoRA network to extract weights: Use dim (rank) as alpha
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o) lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t) lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
assert len(lora_network_o.text_encoder_loras) == len( assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース " lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@ -116,6 +116,9 @@ def svd(args):
print(f"LoRA has {len(lora_sd)} weights.") print(f"LoRA has {len(lora_sd)} weights.")
for key in list(lora_sd.keys()): for key in list(lora_sd.keys()):
if "alpha" in key:
continue
lora_name = key.split('.')[0] lora_name = key.split('.')[0]
i = 0 if "lora_up" in key else 1 i = 0 if "lora_up" in key else 1
@ -124,7 +127,7 @@ def svd(args):
if len(lora_sd[key].size()) == 4: if len(lora_sd[key].size()) == 4:
weights = weights.unsqueeze(2).unsqueeze(3) weights = weights.unsqueeze(2).unsqueeze(3)
assert weights.size() == lora_sd[key].size() assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
lora_sd[key] = weights lora_sd[key] = weights
# load state dict to LoRA and save it # load state dict to LoRA and save it
@ -135,7 +138,10 @@ def svd(args):
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype, {}) # minimum metadata
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {args.save_to}")
@ -151,8 +157,8 @@ if __name__ == '__main__':
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors") help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors")
parser.add_argument("--save_to", type=str, default=None, parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数デフォルト4") parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args() args = parser.parse_args()
svd(args) svd(args)

View File

@ -13,9 +13,11 @@ class LoRAModule(torch.nn.Module):
replaces forward method of the original Linear, instead of replacing the original Linear module. replaces forward method of the original Linear, instead of replacing the original Linear module.
""" """
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4): def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
""" if alpha == 0 or None, alpha is rank (no scaling). """
super().__init__() super().__init__()
self.lora_name = lora_name self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels in_dim = org_module.in_channels
@ -28,6 +30,12 @@ class LoRAModule(torch.nn.Module):
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
# same as microsoft's # same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight) torch.nn.init.zeros_(self.lora_up.weight)
@ -41,13 +49,37 @@ class LoRAModule(torch.nn.Module):
del self.org_module del self.org_module
def forward(self, x): def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim) network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location='cpu')
# get dim (rank)
network_alpha = None
network_dim = None
for key, value in weights_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is None:
network_alpha = network_dim
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
network.weights_sd = weights_sd
return network return network
@ -57,10 +89,11 @@ class LoRANetwork(torch.nn.Module):
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = 'lora_te'
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
super().__init__() super().__init__()
self.multiplier = multiplier self.multiplier = multiplier
self.lora_dim = lora_dim self.lora_dim = lora_dim
self.alpha = alpha
# create module instances # create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@ -71,7 +104,7 @@ class LoRANetwork(torch.nn.Module):
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_') lora_name = lora_name.replace('.', '_')
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim) lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
loras.append(lora) loras.append(lora)
return loras return loras
@ -149,21 +182,21 @@ class LoRANetwork(torch.nn.Module):
return params return params
self.requires_grad_(True) self.requires_grad_(True)
params = [] all_params = []
if self.text_encoder_loras: if self.text_encoder_loras:
param_data = {'params': enumerate_params(self.text_encoder_loras)} param_data = {'params': enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None: if text_encoder_lr is not None:
param_data['lr'] = text_encoder_lr param_data['lr'] = text_encoder_lr
params.append(param_data) all_params.append(param_data)
if self.unet_loras: if self.unet_loras:
param_data = {'params': enumerate_params(self.unet_loras)} param_data = {'params': enumerate_params(self.unet_loras)}
if unet_lr is not None: if unet_lr is not None:
param_data['lr'] = unet_lr param_data['lr'] = unet_lr
params.append(param_data) all_params.append(param_data)
return params return all_params
def prepare_grad_etc(self, text_encoder, unet): def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True) self.requires_grad_(True)

View File

@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
for key in lora_sd.keys(): for key in lora_sd.keys():
if "lora_down" in key: if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up") up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora # find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
@ -73,14 +74,18 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
down_weight = lora_sd[key] down_weight = lora_sd[key]
up_weight = lora_sd[up_key] up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D # W <- W + U * D
weight = module.weight weight = module.weight
if len(weight.size()) == 2: if len(weight.size()) == 2:
# linear # linear
weight = weight + ratio * (up_weight @ down_weight) weight = weight + ratio * (up_weight @ down_weight) * scale
else: else:
# conv2d # conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight) module.weight = torch.nn.Parameter(weight)
@ -88,20 +93,35 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype): def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {} merged_sd = {}
alpha = None
dim = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype) lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...") print(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
else:
if key in merged_sd: if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size( assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else: else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio merged_sd[key] = lora_sd[key] * ratio
return merged_sd print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
return merged_sd, dim, alpha
def merge(args): def merge(args):
@ -132,7 +152,7 @@ def merge(args):
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae) args.sd_model, 0, 0, save_dtype, vae)
else: else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}") print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)
@ -145,7 +165,7 @@ if __name__ == '__main__':
parser.add_argument("--save_precision", type=str, default=None, parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float", parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度") choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None, parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None, parser.add_argument("--save_to", type=str, default=None,

View File

@ -3,6 +3,9 @@ import argparse
import gc import gc
import math import math
import os import os
import random
import time
import json
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@ -18,7 +21,23 @@ def collate_fn(examples):
return examples[0] return examples[0]
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if args.network_train_unet_only:
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
else:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
return logs
def train(args): def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True) train_util.prepare_dataset_args(args, True)
@ -88,7 +107,8 @@ def train(args):
key, value = net_arg.split('=') key, value = net_arg.split('=')
net_kwargs[key] = value net_kwargs[key] = value
network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs) # if a new network is added in future, add if ~ then blocks for each network (;'∀')
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None: if network is None:
return return
@ -206,21 +226,26 @@ def train(args):
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = { metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate, "ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr, "ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr, "ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data "ss_num_train_images": train_dataset.num_train_images, # includes repeating
"ss_num_reg_images": train_dataset.num_reg_images, "ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader), "ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs, "ss_num_epochs": num_train_epochs,
"ss_batch_size_per_device": args.train_batch_size, "ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size, "ss_total_batch_size": total_batch_size,
"ss_gradient_checkpointing": args.gradient_checkpointing,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps, "ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps, "ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler, "ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module, "ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not use this value
"ss_mixed_precision": args.mixed_precision, "ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16), "ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2), "ss_v2": bool(args.v2),
@ -232,10 +257,14 @@ def train(args):
"ss_random_crop": bool(args.random_crop), "ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption), "ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents), "ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT "ss_enable_bucket": bool(train_dataset.enable_bucket),
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset "ss_min_bucket_reso": train_dataset.min_bucket_reso,
"ss_max_bucket_reso": args.max_bucket_reso, "ss_max_bucket_reso": train_dataset.max_bucket_reso,
"ss_seed": args.seed "ss_seed": args.seed,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_training_comment": args.training_comment # will not be updated after training
} }
# uncomment if another network is added # uncomment if another network is added
@ -246,6 +275,7 @@ def train(args):
sd_model_name = args.pretrained_model_name_or_path sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name): if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
sd_model_name = os.path.basename(sd_model_name) sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name metadata["ss_sd_model_name"] = sd_model_name
@ -253,6 +283,7 @@ def train(args):
vae_name = args.vae vae_name = args.vae
if os.path.exists(vae_name): if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name) metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
vae_name = os.path.basename(vae_name) vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name metadata["ss_vae_name"] = vae_name
@ -333,20 +364,20 @@ def train(args):
global_step += 1 global_step += 1
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
loss_total += current_loss loss_total += current_loss
avr_loss = loss_total / (step+1) avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)} logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1) accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@ -417,11 +448,15 @@ if __name__ == '__main__':
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
parser.add_argument("--network_dim", type=int, default=None, parser.add_argument("--network_dim", type=int, default=None,
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_alpha", type=float, default=1,
help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定')
parser.add_argument("--network_args", type=str, default=None, nargs='*', parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数') help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument("--network_train_text_encoder_only", action="store_true", parser.add_argument("--network_train_text_encoder_only", action="store_true",
help="only training Text Encoder part / Text Encoder関連部分のみ学習する") help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
parser.add_argument("--training_comment", type=str, default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)

View File

@ -138,7 +138,7 @@ v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマ
## 当リポジトリ内の画像生成スクリプトで生成する ## 当リポジトリ内の画像生成スクリプトで生成する
gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim省略可の各オプションを追加してください。意味は学習時と同様です。 gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。 --network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。