Merge pull request #227 from bmaltais/dev

v20.8.1
This commit is contained in:
bmaltais 2023-02-23 19:26:16 -05:00 committed by GitHub
commit c7e99eb54b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 538 additions and 42 deletions

View File

@ -163,6 +163,15 @@ This will store your a backup file with your current locally installed pip packa
## Change History ## Change History
* 2023/02/23 (v20.8.1):
- Fix instability training issue in `train_network.py`.
- `fp16` training is probably not affected by this issue.
- Training with `float` for SD2.x models will work now. Also training with bf16 might be improved.
- This issue seems to have occurred in [PR#190](https://github.com/kohya-ss/sd-scripts/pull/190).
- Add some metadata to LoRA model. Thanks to space-nuko!
- Raise an error if optimizer options conflict (e.g. `--optimizer_type` and `--use_8bit_adam`.)
- Support ControlNet in `gen_img_diffusers.py` (no documentation yet.)
* 2023/02/22 (v20.8.0): * 2023/02/22 (v20.8.0):
- Add gui support for optimizers: `AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor` - Add gui support for optimizers: `AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor`
- Add gui support for `--noise_offset` - Add gui support for `--noise_offset`

View File

@ -285,8 +285,14 @@ def train(args):
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
logs = {"avr_loss": loss_total / (step+1)}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
# print(lr_scheduler.optimizers)
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
logs["d"] = lr_scheduler.optimizers[0].param_groups[0]['d']
logs["lrD"] = lr_scheduler.optimizers[0].param_groups[0]['lr']
logs["gsq_weighted"] = lr_scheduler.optimizers[0].param_groups[0]['gsq_weighted']
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
# TODO moving averageにする # TODO moving averageにする

View File

@ -47,7 +47,7 @@ VGG(
""" """
import json import json
from typing import List, Optional, Union from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob import glob
import importlib import importlib
import inspect import inspect
@ -60,7 +60,6 @@ import math
import os import os
import random import random
import re import re
from typing import Any, Callable, List, Optional, Union
import diffusers import diffusers
import numpy as np import numpy as np
@ -81,6 +80,8 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util import library.model_util as model_util
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14" TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@ -487,6 +488,9 @@ class PipelineLike():
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
# ControlNet
self.control_nets: List[ControlNetInfo] = []
# Textual Inversion # Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids): def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids self.token_replacements[target_token_id] = rep_token_ids
@ -500,7 +504,11 @@ class PipelineLike():
new_tokens.append(token) new_tokens.append(token)
return new_tokens return new_tokens
def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets
# region xformersとか使う部分独自に書き換えるので関係なし # region xformersとか使う部分独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self): def enable_xformers_memory_efficient_attention(self):
r""" r"""
Enable memory efficient attention as implemented in xformers. Enable memory efficient attention as implemented in xformers.
@ -752,7 +760,7 @@ class PipelineLike():
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None: if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
if isinstance(clip_guide_images, PIL.Image.Image): if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images] clip_guide_images = [clip_guide_images]
@ -765,7 +773,7 @@ class PipelineLike():
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
if len(image_embeddings_clip) == 1: if len(image_embeddings_clip) == 1:
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
else: elif self.vgg16_guidance_scale > 0:
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に小さいか? size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に小さいか?
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
clip_guide_images = torch.cat(clip_guide_images, dim=0) clip_guide_images = torch.cat(clip_guide_images, dim=0)
@ -774,6 +782,10 @@ class PipelineLike():
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat'] image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
if len(image_embeddings_vgg16) == 1: if len(image_embeddings_vgg16) == 1:
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
else:
# ControlNetのhintにguide imageを流用する
# 前処理はControlNet側で行う
pass
# set timesteps # set timesteps
self.scheduler.set_timesteps(num_inference_steps, self.device) self.scheduler.set_timesteps(num_inference_steps, self.device)
@ -864,11 +876,20 @@ class PipelineLike():
extra_step_kwargs["eta"] = eta extra_step_kwargs["eta"] = eta
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
if self.control_nets:
noise_pred = original_control_net.call_unet_and_control_net(
i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
else:
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance # perform guidance
@ -1817,6 +1838,34 @@ def preprocess_mask(mask):
# return text_encoder # return text_encoder
class BatchDataBase(NamedTuple):
# バッチ分割が必要ないデータ
step: int
prompt: str
negative_prompt: str
seed: int
init_image: Any
mask_image: Any
clip_prompt: str
guide_image: Any
class BatchDataExt(NamedTuple):
# バッチ分割が必要なデータ
width: int
height: int
steps: int
scale: float
negative_scale: float
strength: float
network_muls: Tuple[float]
class BatchData(NamedTuple):
base: BatchDataBase
ext: BatchDataExt
def main(args): def main(args):
if args.fp16: if args.fp16:
dtype = torch.float16 dtype = torch.float16
@ -1995,11 +2044,13 @@ def main(args):
# networkを組み込む # networkを組み込む
if args.network_module: if args.network_module:
networks = [] networks = []
network_default_muls = []
for i, network_module in enumerate(args.network_module): for i, network_module in enumerate(args.network_module):
print("import network module:", network_module) print("import network module:", network_module)
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@ -2014,7 +2065,7 @@ def main(args):
network_weight = args.network_weights[i] network_weight = args.network_weights[i]
print("load network weights from:", network_weight) print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight): if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f: with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
@ -2037,6 +2088,18 @@ def main(args):
else: else:
networks = [] networks = []
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
if args.opt_channels_last: if args.opt_channels_last:
print(f"set optimizing: channels last") print(f"set optimizing: channels last")
text_encoder.to(memory_format=torch.channels_last) text_encoder.to(memory_format=torch.channels_last)
@ -2050,9 +2113,14 @@ def main(args):
if vgg16_model is not None: if vgg16_model is not None:
vgg16_model.to(memory_format=torch.channels_last) vgg16_model.to(memory_format=torch.channels_last)
for cn in control_nets:
cn.unet.to(memory_format=torch.channels_last)
cn.net.to(memory_format=torch.channels_last)
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip, pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale, clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer) vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
pipe.set_control_nets(control_nets)
print("pipeline is ready.") print("pipeline is ready.")
if args.diffusers_xformers: if args.diffusers_xformers:
@ -2186,9 +2254,12 @@ def main(args):
prev_image = None # for VGG16 guided prev_image = None # for VGG16 guided
if args.guide_image_path is not None: if args.guide_image_path is not None:
print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}") print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
guide_images = load_images(args.guide_image_path) guide_images = []
print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance") for p in args.guide_image_path:
guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0: if len(guide_images) == 0:
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
guide_images = None guide_images = None
@ -2219,33 +2290,37 @@ def main(args):
iter_seed = random.randint(0, 0x7fffffff) iter_seed = random.randint(0, 0x7fffffff)
# バッチ処理の関数 # バッチ処理の関数
def process_batch(batch, highres_fix, highres_1st=False): def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
batch_size = len(batch) batch_size = len(batch)
# highres_fixの処理 # highres_fixの処理
if highres_fix and not highres_1st: if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出す # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
print("process 1st stage1") print("process 1st stage1")
batch_1st = [] batch_1st = []
for params1, (width, height, steps, scale, negative_scale, strength) in batch: for base, ext in batch:
width_1st = int(width * args.highres_fix_scale + .5) width_1st = int(ext.width * args.highres_fix_scale + .5)
height_1st = int(height * args.highres_fix_scale + .5) height_1st = int(ext.height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32 width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32 height_1st = height_1st - height_1st % 32
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
ext.negative_scale, ext.strength, ext.network_muls)
batch_1st.append(BatchData(base, ext_1st))
images_1st = process_batch(batch_1st, True, True) images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する # 2nd stageのバッチを作成して以下処理する
print("process 2nd stage1") print("process 2nd stage1")
batch_2nd = [] batch_2nd = []
for i, (b1, image) in enumerate(zip(batch, images_1st)): for i, (bd, image) in enumerate(zip(batch, images_1st)):
image = image.resize((width, height), resample=PIL.Image.LANCZOS) image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
(step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1 bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2)) batch_2nd.append(bd_2nd)
batch = batch_2nd batch = batch_2nd
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width, # このバッチの情報を取り出す
height, steps, scale, negative_scale, strength) = batch[0] (step_first, _, _, _, init_image, mask_image, _, guide_image), \
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = [] prompts = []
@ -2295,6 +2370,10 @@ def main(args):
all_masks_are_same = mask_images[-2] is mask_image all_masks_are_same = mask_images[-2] is mask_image
if guide_image is not None: if guide_image is not None:
if type(guide_image) is list:
guide_images.extend(guide_image)
all_guide_images_are_same = False
else:
guide_images.append(guide_image) guide_images.append(guide_image)
if i > 0 and all_guide_images_are_same: if i > 0 and all_guide_images_are_same:
all_guide_images_are_same = guide_images[-2] is guide_image all_guide_images_are_same = guide_images[-2] is guide_image
@ -2320,7 +2399,19 @@ def main(args):
if guide_images is not None and all_guide_images_are_same: if guide_images is not None and all_guide_images_are_same:
guide_images = guide_images[0] guide_images = guide_images[0]
# ControlNet使用時はguide imageをリサイズする
if control_nets:
# TODO resampleのメソッド
guide_images = guide_images if type(guide_images) == list else [guide_images]
guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
if len(guide_images) == 1:
guide_images = guide_images[0]
# generate # generate
if networks:
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code, images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
if highres_1st and not args.highres_fix_save_1st: if highres_1st and not args.highres_fix_save_1st:
@ -2398,6 +2489,7 @@ def main(args):
strength = 0.8 if args.strength is None else args.strength strength = 0.8 if args.strength is None else args.strength
negative_prompt = "" negative_prompt = ""
clip_prompt = None clip_prompt = None
network_muls = None
prompt_args = prompt.strip().split(' --') prompt_args = prompt.strip().split(' --')
prompt = prompt_args[0] prompt = prompt_args[0]
@ -2461,6 +2553,15 @@ def main(args):
clip_prompt = m.group(1) clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}") print(f"clip prompt: {clip_prompt}")
continue continue
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
if m: # network multiplies
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}")
continue
except ValueError as ex: except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}") print(f"Exception in parsing / 解析エラー: {parg}")
print(ex) print(ex)
@ -2498,6 +2599,11 @@ def main(args):
mask_image = mask_images[global_step % len(mask_images)] mask_image = mask_images[global_step % len(mask_images)]
if guide_images is not None: if guide_images is not None:
if control_nets: # 複数件の場合あり
c = len(control_nets)
p = global_step % (len(guide_images) // c)
guide_image = guide_images[p * c:p * c + c]
else:
guide_image = guide_images[global_step % len(guide_images)] guide_image = guide_images[global_step % len(guide_images)]
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
if prev_image is None: if prev_image is None:
@ -2506,9 +2612,8 @@ def main(args):
print("Use previous image as guide image.") print("Use previous image as guide image.")
guide_image = prev_image guide_image = prev_image
# TODO named tupleか何かにする b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
(width, height, steps, scale, negative_scale, strength))
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
process_batch(batch_data, highres_fix) process_batch(batch_data, highres_fix)
batch_data.clear() batch_data.clear()
@ -2578,12 +2683,15 @@ if __name__ == '__main__':
parser.add_argument("--opt_channels_last", action='store_true', parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannels lastを指定し最適化する') help='set channels last option to model / モデルにchannels lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, nargs='*', parser.add_argument("--network_module", type=str, default=None, nargs='*',
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, nargs='*', parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み') help='additional network weights to load / 追加ネットワークの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_mul", type=float, default=None, nargs='*',
help='additional network multiplier / 追加ネットワークの効果の倍率')
parser.add_argument("--network_args", type=str, default=None, nargs='*', parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数') help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--network_show_meta", action='store_true',
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*', parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings') help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@ -2597,7 +2705,8 @@ if __name__ == '__main__':
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する') help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
parser.add_argument("--vgg16_guidance_layer", type=int, default=20, parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)') help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像") parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
parser.add_argument("--highres_fix_scale", type=float, default=None, parser.add_argument("--highres_fix_scale", type=float, default=None,
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする") help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
parser.add_argument("--highres_fix_steps", type=int, default=28, parser.add_argument("--highres_fix_steps", type=int, default=28,
@ -2607,5 +2716,13 @@ if __name__ == '__main__':
parser.add_argument("--negative_scale", type=float, default=None, parser.add_argument("--negative_scale", type=float, default=None,
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する") help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
help='ControlNet models to use / 使用するControlNetのモデル名')
parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -1372,8 +1372,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
def add_optimizer_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--optimizer_type", type=str, default="AdamW", parser.add_argument("--optimizer_type", type=str, default="",
help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor") help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
# backward compatibility # backward compatibility
parser.add_argument("--use_8bit_adam", action="store_true", parser.add_argument("--use_8bit_adam", action="store_true",
@ -1532,11 +1532,16 @@ def get_optimizer(args, trainable_params):
optimizer_type = args.optimizer_type optimizer_type = args.optimizer_type
if args.use_8bit_adam: if args.use_8bit_adam:
print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます") assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
optimizer_type = "AdamW8bit" optimizer_type = "AdamW8bit"
elif args.use_lion_optimizer: elif args.use_lion_optimizer:
print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます") assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
optimizer_type = "Lion" optimizer_type = "Lion"
if optimizer_type is None or optimizer_type == "":
optimizer_type = "AdamW"
optimizer_type = optimizer_type.lower() optimizer_type = optimizer_type.lower()
# 引数を分解するboolとfloat、tupleのみ対応 # 引数を分解するboolとfloat、tupleのみ対応
@ -1557,7 +1562,7 @@ def get_optimizer(args, trainable_params):
value = tuple(value) value = tuple(value)
optimizer_kwargs[key] = value optimizer_kwargs[key] = value
print("optkwargs:", optimizer_kwargs) # print("optkwargs:", optimizer_kwargs)
lr = args.learning_rate lr = args.learning_rate
@ -1633,7 +1638,7 @@ def get_optimizer(args, trainable_params):
if optimizer_kwargs["relative_step"]: if optimizer_kwargs["relative_step"]:
print(f"relative_step is true / relative_stepがtrueです") print(f"relative_step is true / relative_stepがtrueです")
if lr != 0.0: if lr != 0.0:
print(f"learning rate is used as initial_lr / 指定したlearning rate はinitial_lrとして使用されます: {lr}") print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
args.learning_rate = None args.learning_rate = None
# trainable_paramsがgroupだった時の処理lrを削除する # trainable_paramsがgroupだった時の処理lrを削除する

View File

@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name) names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open from safetensors.torch import load_file, safe_open

View File

@ -14,6 +14,7 @@ altair==4.2.2
easygui==0.98.3 easygui==0.98.3
tk==0.1.0 tk==0.1.0
lion-pytorch==0.0.6 lion-pytorch==0.0.6
dadaptation==1.5
# for BLIP captioning # for BLIP captioning
requests==2.28.2 requests==2.28.2
timm==0.6.12 timm==0.6.12

24
tools/canny.py Normal file
View File

@ -0,0 +1,24 @@
import argparse
import cv2
def canny(args):
img = cv2.imread(args.input)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
canny_img = cv2.Canny(img, args.thres1, args.thres2)
# canny_img = 255 - canny_img
cv2.imwrite(args.output, canny_img)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default=None, help="input path")
parser.add_argument("--output", type=str, default=None, help="output path")
parser.add_argument("--thres1", type=int, default=32, help="thres1")
parser.add_argument("--thres2", type=int, default=224, help="thres2")
args = parser.parse_args()
canny(args)

View File

@ -0,0 +1,320 @@
from typing import List, NamedTuple, Any
import numpy as np
import cv2
import torch
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
import library.model_util as model_util
class ControlNetInfo(NamedTuple):
unet: Any
net: Any
prep: Any
weight: float
ratio: float
class ControlNet(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# make control model
self.control_model = torch.nn.Module()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
zero_convs = torch.nn.ModuleList()
for i, dim in enumerate(dims):
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
zero_convs.append(sub_list)
self.control_model.add_module("zero_convs", zero_convs)
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
dims = [16, 16, 32, 32, 96, 96, 256, 320]
strides = [1, 1, 2, 1, 2, 1, 2, 1]
prev_dim = 3
input_hint_block = torch.nn.Sequential()
for i, (dim, stride) in enumerate(zip(dims, strides)):
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
if i < len(dims) - 1:
input_hint_block.append(torch.nn.SiLU())
prev_dim = dim
self.control_model.add_module("input_hint_block", input_hint_block)
def load_control_net(v2, unet, model):
device = unet.device
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
else:
ctrl_sd_sd = torch.load(model, map_location='cpu')
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference")
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
for key in list(ctrl_unet_sd_sd.keys()):
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
zero_conv_sd = {}
for key in list(ctrl_sd_sd.keys()):
if key.startswith("control_"):
unet_key = "model.diffusion_" + key[len("control_"):]
if unet_key not in ctrl_unet_sd_sd: # zero conv
zero_conv_sd[key] = ctrl_sd_sd[key]
continue
if is_difference: # Transfer Control
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
else:
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
unet_config = model_util.create_unet_diffusers_config(v2)
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
return ctrl_unet, ctrl_net
def load_preprocess(prep_type: str):
if prep_type is None or prep_type.lower() == "none":
return None
if prep_type.startswith("canny"):
args = prep_type.split("_")
th1 = int(args[1]) if len(args) >= 2 else 63
th2 = int(args[2]) if len(args) >= 3 else 191
def canny(img):
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
return cv2.Canny(img, th1, th2)
return canny
print("Unsupported prep type:", prep_type)
return None
def preprocess_ctrl_net_hint_image(image):
image = np.array(image).astype(np.float32) / 255.0
image = image[:, :, ::-1].copy() # rgb to bgr
image = image[None].transpose(0, 3, 1, 2) # nchw
image = torch.from_numpy(image)
return image # 0 to 1
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
guided_hints = []
for i, cnet_info in enumerate(control_nets):
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
b_hints = []
if len(hints) == 1: # すべて同じ画像をhintとして使う
hint = hints[0]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints = [hint for _ in range(b_size)]
else:
for bi in range(b_size):
hint = hints[(bi * len(control_nets) + i) % len(hints)]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints.append(hint)
b_hints = torch.cat(b_hints, dim=0)
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
guided_hints.append(guided_hint)
return guided_hints
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
# ControlNet
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
cnet_cnt = len(control_nets)
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
guided_hint = guided_hints[cnet_idx]
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
outs = [o * cnet_info.weight for o in outs]
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
"""
# これはmergeのバージョン
# ControlNet
cnet_outs_list = []
for i, cnet_info in enumerate(control_nets):
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
continue
guided_hint = guided_hints[i]
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
for i in range(len(outs)):
outs[i] *= cnet_info.weight
cnet_outs_list.append(outs)
count = len(cnet_outs_list)
if count == 0:
return original_unet(sample, timestep, encoder_hidden_states)
# sum of controlnets
for i in range(1, count):
cnet_outs_list[0] += cnet_outs_list[i]
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
"""
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
# copy from UNet2DConditionModel
default_overall_up_factor = 2**unet.num_upsamplers
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 0. center input if necessary
if unet.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = unet.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=unet.dtype)
emb = unet.time_embedding(t_emb)
outs = [] # output of ControlNet
zc_idx = 0
# 2. pre-process
sample = unet.conv_in(sample)
if is_control_net:
sample += guided_hint
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
zc_idx += 1
# 3. down
down_block_res_samples = (sample,)
for downsample_block in unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_control_net:
for rs in res_samples:
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
zc_idx += 1
down_block_res_samples += res_samples
# 4. mid
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
if is_control_net:
outs.append(control_net.control_model.middle_block_out[0](sample))
return outs
if not is_control_net:
sample += ctrl_outs.pop()
# 5. up
for i, upsample_block in enumerate(unet.up_blocks):
is_final_block = i == len(unet.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if not is_control_net and len(ctrl_outs) > 0:
res_samples = list(res_samples)
apply_ctrl_outs = ctrl_outs[-len(res_samples):]
ctrl_outs = ctrl_outs[:-len(res_samples)]
for j in range(len(res_samples)):
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
res_samples = tuple(res_samples)
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = unet.conv_norm_out(sample)
sample = unet.conv_act(sample)
sample = unet.conv_out(sample)
return UNet2DConditionOutput(sample=sample)

View File

@ -36,8 +36,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
logs["lr/d*lr-textencoder"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr'] logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
logs["lr/d*lr-unet"] = lr_scheduler.optimizers[-1].param_groups[1]['d']*lr_scheduler.optimizers[-1].param_groups[1]['lr']
return logs return logs
@ -276,9 +275,11 @@ def train(args):
"ss_shuffle_caption": bool(args.shuffle_caption), "ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents), "ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket), "ss_enable_bucket": bool(train_dataset.enable_bucket),
"ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale),
"ss_min_bucket_reso": train_dataset.min_bucket_reso, "ss_min_bucket_reso": train_dataset.min_bucket_reso,
"ss_max_bucket_reso": train_dataset.max_bucket_reso, "ss_max_bucket_reso": train_dataset.max_bucket_reso,
"ss_seed": args.seed, "ss_seed": args.seed,
"ss_lowram": args.lowram,
"ss_keep_tokens": args.keep_tokens, "ss_keep_tokens": args.keep_tokens,
"ss_noise_offset": args.noise_offset, "ss_noise_offset": args.noise_offset,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
@ -287,7 +288,13 @@ def train(args):
"ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment, # will not be updated after training "ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "") "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
"ss_max_grad_norm": args.max_grad_norm,
"ss_caption_dropout_rate": args.caption_dropout_rate,
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
} }
# uncomment if another network is added # uncomment if another network is added
@ -362,7 +369,7 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual # Predict the noise residual
with autocast(): with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization: if args.v_parameterization:
@ -423,6 +430,7 @@ def train(args):
def save_func(): def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name) ckpt_file = os.path.join(args.output_dir, ckpt_name)
metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}") print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
@ -440,6 +448,7 @@ def train(args):
# end of epoch # end of epoch
metadata["ss_epoch"] = str(num_train_epochs) metadata["ss_epoch"] = str(num_train_epochs)
metadata["ss_training_finished_at"] = str(time.time())
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process: