diff --git a/fine_tune.py b/fine_tune.py
index 80290e7..426fb09 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -285,8 +285,14 @@ def train(args):
       current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
       if args.logging_dir is not None:
         logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        logs = {"avr_loss": loss_total / (step+1)}
         if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          # print(lr_scheduler.optimizers)
           logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
+          logs["d"] = lr_scheduler.optimizers[0].param_groups[0]['d']
+          logs["lrD"] = lr_scheduler.optimizers[0].param_groups[0]['lr']
+          logs["gsq_weighted"] = lr_scheduler.optimizers[0].param_groups[0]['gsq_weighted']
+
         accelerator.log(logs, step=global_step)

       # TODO moving averageにする
diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 25a5b2d..a2d5b94 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -47,7 +47,7 @@ VGG(
 """

 import json
-from typing import List, Optional, Union
+from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
 import glob
 import importlib
 import inspect
@@ -60,7 +60,6 @@ import math
 import os
 import random
 import re
-from typing import Any, Callable, List, Optional, Union

 import diffusers
 import numpy as np
@@ -81,6 +80,8 @@ from PIL import Image
 from PIL.PngImagePlugin import PngInfo

 import library.model_util as model_util
+import tools.original_control_net as original_control_net
+from tools.original_control_net import ControlNetInfo

 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -487,6 +488,9 @@ class PipelineLike():
     self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
     self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)

+    # ControlNet
+    self.control_nets: List[ControlNetInfo] = []
+
   # Textual Inversion
   def add_token_replacement(self, target_token_id, rep_token_ids):
     self.token_replacements[target_token_id] = rep_token_ids
@@ -500,7 +504,11 @@ class PipelineLike():
       new_tokens.append(token)
     return new_tokens

+  def set_control_nets(self, ctrl_nets):
+    self.control_nets = ctrl_nets
+
 # region xformersとか使う部分:独自に書き換えるので関係なし
+
   def enable_xformers_memory_efficient_attention(self):
     r"""
     Enable memory efficient attention as implemented in xformers.
@@ -752,7 +760,7 @@ class PipelineLike():
       text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
       text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)  # prompt複数件でもOK

-    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
+    if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
       if isinstance(clip_guide_images, PIL.Image.Image):
         clip_guide_images = [clip_guide_images]

@@ -765,7 +773,7 @@ class PipelineLike():
         image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
         if len(image_embeddings_clip) == 1:
           image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
-      else:
+      elif self.vgg16_guidance_scale > 0:
         size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV)  # とりあえず1/4に(小さいか?)
         clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
         clip_guide_images = torch.cat(clip_guide_images, dim=0)
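A reading aid for the guard rewritten in the @@ -752 hunk above (the names below are stand-ins, not part of the patch): Python groups `a or b and c or d` as `a or (b and c) or d`, so configured ControlNets alone are enough to enter the guide-image branch, while the VGG16 term still requires a guide image.

def needs_guide_images(clip_scale, vgg16_scale, guide_images, control_nets):
    # mirrors: clip_scale > 0 or (vgg16_scale > 0 and guide_images is not None) or control_nets
    return clip_scale > 0 or vgg16_scale > 0 and guide_images is not None or bool(control_nets)

assert needs_guide_images(0.0, 0.0, None, [object()])  # ControlNet alone takes the branch
assert not needs_guide_images(0.0, 1.0, None, [])      # VGG16 scale set, but no guide image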
@@ -774,6 +782,10 @@ class PipelineLike():
         image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
         if len(image_embeddings_vgg16) == 1:
           image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
+      else:
+        # ControlNetのhintにguide imageを流用する
+        # 前処理はControlNet側で行う
+        pass

     # set timesteps
     self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -864,12 +876,21 @@ class PipelineLike():
       extra_step_kwargs["eta"] = eta

     num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
+
+    if self.control_nets:
+      guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+
     for i, t in enumerate(tqdm(timesteps)):
       # expand the latents if we are doing classifier free guidance
       latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
       latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
       # predict the noise residual
-      noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+      if self.control_nets:
+        noise_pred = original_control_net.call_unet_and_control_net(
+            i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
+      else:
+        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

       # perform guidance
       if do_classifier_free_guidance:
@@ -1817,6 +1838,34 @@ def preprocess_mask(mask):
 #   return text_encoder


+class BatchDataBase(NamedTuple):
+  # バッチ分割が必要ないデータ
+  step: int
+  prompt: str
+  negative_prompt: str
+  seed: int
+  init_image: Any
+  mask_image: Any
+  clip_prompt: str
+  guide_image: Any
+
+
+class BatchDataExt(NamedTuple):
+  # バッチ分割が必要なデータ
+  width: int
+  height: int
+  steps: int
+  scale: float
+  negative_scale: float
+  strength: float
+  network_muls: Tuple[float]
+
+
+class BatchData(NamedTuple):
+  base: BatchDataBase
+  ext: BatchDataExt
+
+
 def main(args):
   if args.fp16:
     dtype = torch.float16
@@ -1995,11 +2044,13 @@ def main(args):
   # networkを組み込む
   if args.network_module:
     networks = []
+    network_default_muls = []
    for i, network_module in enumerate(args.network_module):
       print("import network module:", network_module)
       imported_module = importlib.import_module(network_module)

       network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+      network_default_muls.append(network_mul)

       net_kwargs = {}
       if args.network_args and i < len(args.network_args):
@@ -2014,7 +2065,7 @@ def main(args):
         network_weight = args.network_weights[i]
         print("load network weights from:", network_weight)

-        if model_util.is_safetensors(network_weight):
+        if model_util.is_safetensors(network_weight) and args.network_show_meta:
           from safetensors.torch import safe_open
           with safe_open(network_weight, framework="pt") as f:
             metadata = f.metadata()
@@ -2037,6 +2088,18 @@ def main(args):
   else:
     networks = []

+  # ControlNetの処理
+  control_nets: List[ControlNetInfo] = []
+  if args.control_net_models:
+    for i, model in enumerate(args.control_net_models):
+      prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+      weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+      ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+
+      ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
+      prep = original_control_net.load_preprocess(prep_type)
+      control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+
   if args.opt_channels_last:
     print(f"set optimizing: channels last")
     text_encoder.to(memory_format=torch.channels_last)
@@ -2050,9 +2113,14 @@ def main(args):
     if vgg16_model is not None:
       vgg16_model.to(memory_format=torch.channels_last)

+    for cn in control_nets:
+      cn.unet.to(memory_format=torch.channels_last)
+      cn.net.to(memory_format=torch.channels_last)
+
   pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
                       clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
                       vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
+  pipe.set_control_nets(control_nets)
   print("pipeline is ready.")

   if args.diffusers_xformers:
@@ -2186,9 +2254,12 @@ def main(args):

   prev_image = None  # for VGG16 guided
   if args.guide_image_path is not None:
-    print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
-    guide_images = load_images(args.guide_image_path)
-    print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
+    print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
+    guide_images = []
+    for p in args.guide_image_path:
+      guide_images.extend(load_images(p))
+
+    print(f"loaded {len(guide_images)} guide images for guidance")
     if len(guide_images) == 0:
       print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
       guide_images = None
@@ -2219,33 +2290,37 @@ def main(args):
       iter_seed = random.randint(0, 0x7fffffff)

   # バッチ処理の関数
-  def process_batch(batch, highres_fix, highres_1st=False):
+  def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
     batch_size = len(batch)

     # highres_fixの処理
     if highres_fix and not highres_1st:
-      # 1st stageのバッチを作成して呼び出す
+      # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
       print("process 1st stage1")
       batch_1st = []
-      for params1, (width, height, steps, scale, negative_scale, strength) in batch:
-        width_1st = int(width * args.highres_fix_scale + .5)
-        height_1st = int(height * args.highres_fix_scale + .5)
+      for base, ext in batch:
+        width_1st = int(ext.width * args.highres_fix_scale + .5)
+        height_1st = int(ext.height * args.highres_fix_scale + .5)
         width_1st = width_1st - width_1st % 32
         height_1st = height_1st - height_1st % 32
-        batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
+
+        ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
+                               ext.negative_scale, ext.strength, ext.network_muls)
+        batch_1st.append(BatchData(base, ext_1st))
       images_1st = process_batch(batch_1st, True, True)

       # 2nd stageのバッチを作成して以下処理する
       print("process 2nd stage1")
       batch_2nd = []
-      for i, (b1, image) in enumerate(zip(batch, images_1st)):
-        image = image.resize((width, height), resample=PIL.Image.LANCZOS)
-        (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
-        batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
+      for i, (bd, image) in enumerate(zip(batch, images_1st)):
+        image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
+        bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
+        batch_2nd.append(bd_2nd)
       batch = batch_2nd

-    (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
-     height, steps, scale, negative_scale, strength) = batch[0]
+    # このバッチの情報を取り出す
+    (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+        (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]

     noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

     prompts = []
@@ -2295,9 +2370,13 @@ def main(args):
           all_masks_are_same = mask_images[-2] is mask_image

       if guide_image is not None:
-        guide_images.append(guide_image)
-        if i > 0 and all_guide_images_are_same:
-          all_guide_images_are_same = guide_images[-2] is guide_image
+        if type(guide_image) is list:
+          guide_images.extend(guide_image)
+          all_guide_images_are_same = False
+        else:
+          guide_images.append(guide_image)
+          if i > 0 and all_guide_images_are_same:
+            all_guide_images_are_same = guide_images[-2] is guide_image

       # make start code
       torch.manual_seed(seed)
@@ -2320,7 +2399,19 @@ def main(args):
       if guide_images is not None and all_guide_images_are_same:
         guide_images = guide_images[0]

+      # ControlNet使用時はguide imageをリサイズする
+      if control_nets:
+        # TODO resampleのメソッド
+        guide_images = guide_images if type(guide_images) == list else [guide_images]
+        guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
+        if len(guide_images) == 1:
+          guide_images = guide_images[0]
+
       # generate
+      if networks:
+        for n, m in zip(networks, network_muls if network_muls else network_default_muls):
+          n.set_multiplier(m)
+
       images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
                     output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]

       if highres_1st and not args.highres_fix_save_1st:
@@ -2398,6 +2489,7 @@ def main(args):
       strength = 0.8 if args.strength is None else args.strength
       negative_prompt = ""
       clip_prompt = None
+      network_muls = None

       prompt_args = prompt.strip().split(' --')
       prompt = prompt_args[0]
@@ -2461,6 +2553,15 @@ def main(args):
            clip_prompt = m.group(1)
            print(f"clip prompt: {clip_prompt}")
            continue
+
+          m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
+          if m:  # network multiplies
+            network_muls = [float(v) for v in m.group(1).split(",")]
+            while len(network_muls) < len(networks):
+              network_muls.append(network_muls[-1])
+            print(f"network mul: {network_muls}")
+            continue
+
         except ValueError as ex:
           print(f"Exception in parsing / 解析エラー: {parg}")
           print(ex)
@@ -2498,7 +2599,12 @@ def main(args):
        mask_image = mask_images[global_step % len(mask_images)]

      if guide_images is not None:
-        guide_image = guide_images[global_step % len(guide_images)]
+        if control_nets:  # 複数件の場合あり
+          c = len(control_nets)
+          p = global_step % (len(guide_images) // c)
+          guide_image = guide_images[p * c:p * c + c]
+        else:
+          guide_image = guide_images[global_step % len(guide_images)]
      elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
        if prev_image is None:
          print("Generate 1st image without guide image.")
@@ -2506,9 +2612,8 @@ def main(args):
          print("Use previous image as guide image.")
          guide_image = prev_image

-      # TODO named tupleか何かにする
-      b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
-            (width, height, steps, scale, negative_scale, strength))
+      b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+                     BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
      if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # バッチ分割必要?
        process_batch(batch_data, highres_fix)
        batch_data.clear()
@@ -2578,12 +2683,15 @@ if __name__ == '__main__':
   parser.add_argument("--opt_channels_last", action='store_true',
                       help='set channels last option to model / モデルにchannels lastを指定し最適化する')
   parser.add_argument("--network_module", type=str, default=None, nargs='*',
-                      help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
+                      help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
   parser.add_argument("--network_weights", type=str, default=None, nargs='*',
-                      help='Hypernetwork weights to load / Hypernetworkの重み')
-  parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
+                      help='additional network weights to load / 追加ネットワークの重み')
+  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network multiplier / 追加ネットワークの効果の倍率')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
+  parser.add_argument("--network_show_meta", action='store_true',
+                      help='show metadata of network model / ネットワークモデルのメタデータを表示する')
   parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
                       help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
   parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2597,7 +2705,8 @@ if __name__ == '__main__':
                       help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
   parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
                       help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
-  parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
+  parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
+                      help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
   parser.add_argument("--highres_fix_scale", type=float, default=None,
                       help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
   parser.add_argument("--highres_fix_steps", type=int, default=28,
@@ -2607,5 +2716,13 @@ if __name__ == '__main__':
   parser.add_argument("--negative_scale", type=float, default=None,
                       help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")

+  parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
+                      help='ControlNet models to use / 使用するControlNetのモデル名')
+  parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
+                      help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
+  parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
+  parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
+                      help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
+
   args = parser.parse_args()
   main(args)
diff --git a/library/train_util.py b/library/train_util.py
index b1b9900..a02207b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1372,8 +1372,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


 def add_optimizer_arguments(parser: argparse.ArgumentParser):
-  parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
+  parser.add_argument("--optimizer_type", type=str, default="",
+                      help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")

   # backward compatibility
   parser.add_argument("--use_8bit_adam", action="store_true",
@@ -1532,11 +1532,16 @@ def get_optimizer(args, trainable_params):
   optimizer_type = args.optimizer_type

   if args.use_8bit_adam:
-    print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
+    assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
+    assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
     optimizer_type = "AdamW8bit"
+
   elif args.use_lion_optimizer:
-    print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
+    assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
     optimizer_type = "Lion"
+
+  if optimizer_type is None or optimizer_type == "":
+    optimizer_type = "AdamW"
   optimizer_type = optimizer_type.lower()

   # 引数を分解する:boolとfloat、tupleのみ対応
@@ -1557,7 +1562,7 @@ def get_optimizer(args, trainable_params):
         value = tuple(value)

       optimizer_kwargs[key] = value
-  print("optkwargs:", optimizer_kwargs)
+  # print("optkwargs:", optimizer_kwargs)

   lr = args.learning_rate
@@ -1633,7 +1638,7 @@ def get_optimizer(args, trainable_params):
     if optimizer_kwargs["relative_step"]:
       print(f"relative_step is true / relative_stepがtrueです")
       if lr != 0.0:
-        print(f"learning rate is used as initial_lr / 指定したlearning rate はinitial_lrとして使用されます: {lr}")
+        print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
       args.learning_rate = None

       # trainable_paramsがgroupだった時の処理:lrを削除する
diff --git a/networks/lora.py b/networks/lora.py
index a1f38c1..24b107b 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
       assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
       names.add(lora.lora_name)

+  def set_multiplier(self, multiplier):
+    self.multiplier = multiplier
+    for lora in self.text_encoder_loras + self.unet_loras:
+      lora.multiplier = self.multiplier
+
   def load_weights(self, file):
     if os.path.splitext(file)[1] == '.safetensors':
       from safetensors.torch import load_file, safe_open
diff --git a/requirements.txt b/requirements.txt
index bfbe8d9..6d2d6c5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,6 +14,7 @@ altair==4.2.2
 easygui==0.98.3
 tk==0.1.0
 lion-pytorch==0.0.6
+dadaptation==1.5
 # for BLIP captioning
 requests==2.28.2
 timm==0.6.12
diff --git a/tools/canny.py b/tools/canny.py
new file mode 100644
index 0000000..2f01bbf
--- /dev/null
+++ b/tools/canny.py
@@ -0,0 +1,24 @@
+import argparse
+import cv2
+
+
+def canny(args):
+  img = cv2.imread(args.input)
+  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+  canny_img = cv2.Canny(img, args.thres1, args.thres2)
+  # canny_img = 255 - canny_img
+
+  cv2.imwrite(args.output, canny_img)
+  print("done!")
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--input", type=str, default=None, help="input path")
+  parser.add_argument("--output", type=str, default=None, help="output path")
+  parser.add_argument("--thres1", type=int, default=32, help="thres1")
+  parser.add_argument("--thres2", type=int, default=224, help="thres2")
+
+  args = parser.parse_args()
+  canny(args)
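tools/canny.py above is a small standalone helper that turns an input image into a Canny edge map, which can later serve as a ControlNet hint. A rough sketch of the equivalent OpenCV calls, assuming hypothetical file names hint.png / hint_canny.png:

import cv2

img = cv2.imread("hint.png")                  # BGR image from disk
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # Canny expects a single channel
edges = cv2.Canny(gray, 32, 224)              # same defaults as --thres1 / --thres2
cv2.imwrite("hint_canny.png", edges)          # white edges on black background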
diff --git a/tools/original_control_net.py b/tools/original_control_net.py
new file mode 100644
index 0000000..4484ce9
--- /dev/null
+++ b/tools/original_control_net.py
@@ -0,0 +1,320 @@
+from typing import List, NamedTuple, Any
+import numpy as np
+import cv2
+import torch
+from safetensors.torch import load_file
+
+from diffusers import UNet2DConditionModel
+from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+
+import library.model_util as model_util
+
+
+class ControlNetInfo(NamedTuple):
+  unet: Any
+  net: Any
+  prep: Any
+  weight: float
+  ratio: float
+
+
+class ControlNet(torch.nn.Module):
+  def __init__(self) -> None:
+    super().__init__()
+
+    # make control model
+    self.control_model = torch.nn.Module()
+
+    dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
+    zero_convs = torch.nn.ModuleList()
+    for i, dim in enumerate(dims):
+      sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
+      zero_convs.append(sub_list)
+    self.control_model.add_module("zero_convs", zero_convs)
+
+    middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
+    self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
+
+    dims = [16, 16, 32, 32, 96, 96, 256, 320]
+    strides = [1, 1, 2, 1, 2, 1, 2, 1]
+    prev_dim = 3
+    input_hint_block = torch.nn.Sequential()
+    for i, (dim, stride) in enumerate(zip(dims, strides)):
+      input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
+      if i < len(dims) - 1:
+        input_hint_block.append(torch.nn.SiLU())
+      prev_dim = dim
+    self.control_model.add_module("input_hint_block", input_hint_block)
+
+
+def load_control_net(v2, unet, model):
+  device = unet.device
+
+  # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
+  # state dictを読み込む
+  print(f"ControlNet: loading control SD model : {model}")
+
+  if model_util.is_safetensors(model):
+    ctrl_sd_sd = load_file(model)
+  else:
+    ctrl_sd_sd = torch.load(model, map_location='cpu')
+    ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
+
+  # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
+  is_difference = "difference" in ctrl_sd_sd
+  print("ControlNet: loading difference")
+
+  # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
+  # またTransfer Controlの元weightとなる
+  ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
+
+  # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
+  for key in list(ctrl_unet_sd_sd.keys()):
+    ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
+
+  zero_conv_sd = {}
+  for key in list(ctrl_sd_sd.keys()):
+    if key.startswith("control_"):
+      unet_key = "model.diffusion_" + key[len("control_"):]
+      if unet_key not in ctrl_unet_sd_sd:  # zero conv
+        zero_conv_sd[key] = ctrl_sd_sd[key]
+        continue
+      if is_difference:  # Transfer Control
+        ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
+      else:
+        ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
+
+  unet_config = model_util.create_unet_diffusers_config(v2)
+  ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config)  # DiffUsers版ControlNetのstate dict
+
+  # ControlNetのU-Netを作成する
+  ctrl_unet = UNet2DConditionModel(**unet_config)
+  info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
+  print("ControlNet: loading Control U-Net:", info)
+
+  # U-Net以外のControlNetを作成する
+  # TODO support middle only
+  ctrl_net = ControlNet()
+  info = ctrl_net.load_state_dict(zero_conv_sd)
+  print("ControlNet: loading ControlNet:", info)
+
+  ctrl_unet.to(unet.device, dtype=unet.dtype)
+  ctrl_net.to(unet.device, dtype=unet.dtype)
+  return ctrl_unet, ctrl_net
+
+
+def load_preprocess(prep_type: str):
+  if prep_type is None or prep_type.lower() == "none":
+    return None
+
+  if prep_type.startswith("canny"):
+    args = prep_type.split("_")
+    th1 = int(args[1]) if len(args) >= 2 else 63
+    th2 = int(args[2]) if len(args) >= 3 else 191
+
+    def canny(img):
+      img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+      return cv2.Canny(img, th1, th2)
+    return canny
+
+  print("Unsupported prep type:", prep_type)
+  return None
+
+
+def preprocess_ctrl_net_hint_image(image):
+  image = np.array(image).astype(np.float32) / 255.0
+  image = image[:, :, ::-1].copy()  # rgb to bgr
+  image = image[None].transpose(0, 3, 1, 2)  # nchw
+  image = torch.from_numpy(image)
+  return image  # 0 to 1
+
+
+def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
+  guided_hints = []
+  for i, cnet_info in enumerate(control_nets):
+    # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
+    b_hints = []
+    if len(hints) == 1:  # すべて同じ画像をhintとして使う
+      hint = hints[0]
+      if cnet_info.prep is not None:
+        hint = cnet_info.prep(hint)
+      hint = preprocess_ctrl_net_hint_image(hint)
+      b_hints = [hint for _ in range(b_size)]
+    else:
+      for bi in range(b_size):
+        hint = hints[(bi * len(control_nets) + i) % len(hints)]
+        if cnet_info.prep is not None:
+          hint = cnet_info.prep(hint)
+        hint = preprocess_ctrl_net_hint_image(hint)
+        b_hints.append(hint)
+    b_hints = torch.cat(b_hints, dim=0)
+    b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
+
+    guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
+    guided_hints.append(guided_hint)
+  return guided_hints
+
+
+def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
+  # ControlNet
+  # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
+  cnet_cnt = len(control_nets)
+  cnet_idx = step % cnet_cnt
+  cnet_info = control_nets[cnet_idx]
+
+  # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+  if cnet_info.ratio < current_ratio:
+    return original_unet(sample, timestep, encoder_hidden_states)
+
+  guided_hint = guided_hints[cnet_idx]
+  guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
+  outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
+  outs = [o * cnet_info.weight for o in outs]
+
+  # U-Net
+  return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
+
+
+"""
+  # これはmergeのバージョン
+  # ControlNet
+  cnet_outs_list = []
+  for i, cnet_info in enumerate(control_nets):
+    # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+    if cnet_info.ratio < current_ratio:
+      continue
+    guided_hint = guided_hints[i]
+    outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
+    for i in range(len(outs)):
+      outs[i] *= cnet_info.weight
+
+    cnet_outs_list.append(outs)
+
+  count = len(cnet_outs_list)
+  if count == 0:
+    return original_unet(sample, timestep, encoder_hidden_states)
+
+  # sum of controlnets
+  for i in range(1, count):
+    cnet_outs_list[0] += cnet_outs_list[i]
+
+  # U-Net
+  return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
+"""
+
+
+def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
+  # copy from UNet2DConditionModel
+  default_overall_up_factor = 2**unet.num_upsamplers
+
+  forward_upsample_size = False
+  upsample_size = None
+
+  if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+    print("Forward upsample size to force interpolation output size.")
+    forward_upsample_size = True
+
+  # 0. center input if necessary
+  if unet.config.center_input_sample:
+    sample = 2 * sample - 1.0
+
+  # 1. time
+  timesteps = timestep
+  if not torch.is_tensor(timesteps):
+    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+    # This would be a good case for the `match` statement (Python 3.10+)
+    is_mps = sample.device.type == "mps"
+    if isinstance(timestep, float):
+      dtype = torch.float32 if is_mps else torch.float64
+    else:
+      dtype = torch.int32 if is_mps else torch.int64
+    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+  elif len(timesteps.shape) == 0:
+    timesteps = timesteps[None].to(sample.device)
+
+  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+  timesteps = timesteps.expand(sample.shape[0])
+
+  t_emb = unet.time_proj(timesteps)
+
+  # timesteps does not contain any weights and will always return f32 tensors
+  # but time_embedding might actually be running in fp16. so we need to cast here.
+  # there might be better ways to encapsulate this.
+  t_emb = t_emb.to(dtype=unet.dtype)
+  emb = unet.time_embedding(t_emb)
+
+  outs = []  # output of ControlNet
+  zc_idx = 0
+
+  # 2. pre-process
+  sample = unet.conv_in(sample)
+  if is_control_net:
+    sample += guided_hint
+    outs.append(control_net.control_model.zero_convs[zc_idx][0](sample))  # , emb, encoder_hidden_states))
+    zc_idx += 1
+
+  # 3. down
+  down_block_res_samples = (sample,)
+  for downsample_block in unet.down_blocks:
+    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+      sample, res_samples = downsample_block(
+          hidden_states=sample,
+          temb=emb,
+          encoder_hidden_states=encoder_hidden_states,
+      )
+    else:
+      sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+    if is_control_net:
+      for rs in res_samples:
+        outs.append(control_net.control_model.zero_convs[zc_idx][0](rs))  # , emb, encoder_hidden_states))
+        zc_idx += 1
+
+    down_block_res_samples += res_samples
+
+  # 4. mid
+  sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+  if is_control_net:
+    outs.append(control_net.control_model.middle_block_out[0](sample))
+    return outs
+
+  if not is_control_net:
+    sample += ctrl_outs.pop()
+
+  # 5. up
+  for i, upsample_block in enumerate(unet.up_blocks):
+    is_final_block = i == len(unet.up_blocks) - 1
+
+    res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+    if not is_control_net and len(ctrl_outs) > 0:
+      res_samples = list(res_samples)
+      apply_ctrl_outs = ctrl_outs[-len(res_samples):]
+      ctrl_outs = ctrl_outs[:-len(res_samples)]
+      for j in range(len(res_samples)):
+        res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
+      res_samples = tuple(res_samples)
+
+    # if we have not reached the final block and need to forward the
+    # upsample size, we do it here
+    if not is_final_block and forward_upsample_size:
+      upsample_size = down_block_res_samples[-1].shape[2:]
+
+    if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+      sample = upsample_block(
+          hidden_states=sample,
+          temb=emb,
+          res_hidden_states_tuple=res_samples,
+          encoder_hidden_states=encoder_hidden_states,
+          upsample_size=upsample_size,
+      )
+    else:
+      sample = upsample_block(
+          hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+      )
+  # 6. post-process
+  sample = unet.conv_norm_out(sample)
+  sample = unet.conv_act(sample)
+  sample = unet.conv_out(sample)
+
+  return UNet2DConditionOutput(sample=sample)
diff --git a/train_network.py b/train_network.py
index d90aa19..0ba290a 100644
--- a/train_network.py
+++ b/train_network.py
@@ -36,8 +36,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
     logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be same to textencoder
     if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
-      logs["lr/d*lr-textencoder"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
-      logs["lr/d*lr-unet"] = lr_scheduler.optimizers[-1].param_groups[1]['d']*lr_scheduler.optimizers[-1].param_groups[1]['lr']
+      logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']

   return logs
@@ -276,9 +275,11 @@ def train(args):
       "ss_shuffle_caption": bool(args.shuffle_caption),
       "ss_cache_latents": bool(args.cache_latents),
       "ss_enable_bucket": bool(train_dataset.enable_bucket),
+      "ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale),
       "ss_min_bucket_reso": train_dataset.min_bucket_reso,
       "ss_max_bucket_reso": train_dataset.max_bucket_reso,
       "ss_seed": args.seed,
+      "ss_lowram": args.lowram,
       "ss_keep_tokens": args.keep_tokens,
       "ss_noise_offset": args.noise_offset,
       "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
@@ -287,7 +288,13 @@ def train(args):
       "ss_bucket_info": json.dumps(train_dataset.bucket_info),
       "ss_training_comment": args.training_comment,  # will not be updated after training
       "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
-      "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "")
+      "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
+      "ss_max_grad_norm": args.max_grad_norm,
+      "ss_caption_dropout_rate": args.caption_dropout_rate,
+      "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
+      "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
+      "ss_face_crop_aug_range": args.face_crop_aug_range,
+      "ss_prior_loss_weight": args.prior_loss_weight,
   }

   # uncomment if another network is added
@@ -362,7 +369,7 @@ def train(args):
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

         # Predict the noise residual
-        with autocast():
+        with accelerator.autocast():
          noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

         if args.v_parameterization:
@@ -423,6 +430,7 @@ def train(args):
      def save_func():
        ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)
+        metadata["ss_training_finished_at"] = str(time.time())
        print(f"saving checkpoint: {ckpt_file}")
        unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
@@ -440,6 +448,7 @@ def train(args):

   # end of epoch
   metadata["ss_epoch"] = str(num_train_epochs)
+  metadata["ss_training_finished_at"] = str(time.time())

   is_main_process = accelerator.is_main_process
   if is_main_process:
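Putting the pieces together, gen_img_diffusers.py wires the ControlNet support from tools/original_control_net.py roughly as sketched below. This assumes an already-loaded Diffusers `unet` and a `PipelineLike` instance `pipe`; the checkpoint file name is hypothetical, and the prep string follows the "canny_<th1>_<th2>" format parsed by load_preprocess:

import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo

# build the control U-Net and the zero-conv/hint network from a ControlNet checkpoint (v2=False for SD1.x)
ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, "control_canny.safetensors")
prep = original_control_net.load_preprocess("canny_63_191")  # or None / "none" to use the guide image without preprocessing

# weight scales the injected residuals; ratio limits guidance to the first fraction of the sampling steps
pipe.set_control_nets([ControlNetInfo(ctrl_unet, ctrl_net, prep, 1.0, 1.0)])

Guide images supplied via --guide_image_path are then resized to the output resolution and converted into hints by get_guided_hints() before call_unet_and_control_net() mixes the ControlNet outputs into the U-Net forward pass.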