From 25d6e252d33127c88baf19b8aeec1f67cacbe656 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 8 Mar 2023 07:30:14 -0500 Subject: [PATCH] Save prompt file in sample directory --- dreambooth_gui.py | 1 + finetune_gui.py | 1 + library/sampler_gui.py | 17 ++-- library/train_util.py | 72 +++++++++------- lora_gui.py | 13 ++- networks/lora.py | 179 ++++++++++++++++++++++++++++++++++----- textual_inversion_gui.py | 1 + train_network.py | 8 +- 8 files changed, 228 insertions(+), 64 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 90fedd2..df40784 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -473,6 +473,7 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, + output_dir, ) print(run_cmd) diff --git a/finetune_gui.py b/finetune_gui.py index 543f50c..1daa2d3 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -465,6 +465,7 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, + output_dir, ) print(run_cmd) diff --git a/library/sampler_gui.py b/library/sampler_gui.py index 7e94a15..7a7734f 100644 --- a/library/sampler_gui.py +++ b/library/sampler_gui.py @@ -1,4 +1,5 @@ import tempfile +import os import gradio as gr from easygui import msgbox @@ -71,19 +72,23 @@ def run_cmd_sample( sample_every_n_epochs, sample_sampler, sample_prompts, + output_dir, ): + output_dir = os.path.join(output_dir, "sample") + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + run_cmd = '' if sample_every_n_epochs == 0 and sample_every_n_steps == 0: return run_cmd - # Create a temporary file and get its path - with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file: - # Write the contents of the variable to the file - temp_file.write(sample_prompts) + # Create the prompt file and get its path + sample_prompts_path = os.path.join(output_dir, "prompt.txt") - # Get the path of the temporary file - sample_prompts_path = temp_file.name + with open(sample_prompts_path, 'w') as f: + f.write(sample_prompts) run_cmd += f' --sample_sampler={sample_sampler}' run_cmd += f' --sample_prompts="{sample_prompts_path}"' diff --git a/library/train_util.py b/library/train_util.py index 75176e1..351d222 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -7,13 +7,13 @@ import re import shutil import time from typing import ( - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union, + Dict, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, ) from accelerate import Accelerator import glob @@ -214,24 +214,24 @@ class AugHelper: def __init__(self): # prepare all possible augmentators color_aug_method = albu.OneOf([ - albu.HueSaturationValue(8, 0, 0, p=.5), - albu.RandomGamma((95, 105), p=.5), + albu.HueSaturationValue(8, 0, 0, p=.5), + albu.RandomGamma((95, 105), p=.5), ], p=.33) flip_aug_method = albu.HorizontalFlip(p=0.5) # key: (use_color_aug, use_flip_aug) self.augmentors = { - (True, True): albu.Compose([ - color_aug_method, - flip_aug_method, - ], p=1.), - (True, False): albu.Compose([ - color_aug_method, - ], p=1.), - (False, True): albu.Compose([ - flip_aug_method, - ], p=1.), - (False, False): None + (True, True): albu.Compose([ + color_aug_method, + flip_aug_method, + ], p=1.), + (True, False): albu.Compose([ + color_aug_method, + ], p=1.), + (False, True): albu.Compose([ + flip_aug_method, + ], p=1.), + (False, False): None } def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: @@ -260,7 +260,7 @@ class DreamBoothSubset(BaseSubset): assert image_dir is 
not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, - face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) self.is_reg = is_reg self.class_tokens = class_tokens @@ -271,12 +271,13 @@ class DreamBoothSubset(BaseSubset): return NotImplemented return self.image_dir == other.image_dir + class FineTuningSubset(BaseSubset): def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, - face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) self.metadata_file = metadata_file @@ -285,6 +286,7 @@ class FineTuningSubset(BaseSubset): return NotImplemented return self.metadata_file == other.metadata_file + class BaseDataset(torch.utils.data.Dataset): def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None: super().__init__() @@ -804,7 +806,7 @@ class DreamBoothDataset(BaseDataset): captions.append("") else: captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) - + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 return img_paths, captions @@ -815,11 +817,13 @@ class DreamBoothDataset(BaseDataset): reg_infos: List[ImageInfo] = [] for subset in subsets: if subset.num_repeats < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") + print( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") continue if subset in self.subsets: - print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") + print( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") continue img_paths, captions = load_dreambooth_dir(subset) @@ -881,11 +885,13 @@ class FineTuningDataset(BaseDataset): for subset in subsets: if subset.num_repeats < 1: - print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") + print( + f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") continue if subset in self.subsets: - print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") + print( + f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") continue # メタデータを読み込む @@ -937,7 +943,7 @@ class FineTuningDataset(BaseDataset): self.subsets.append(subset) # check existence 
of all npz files - use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets]) + use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) if use_npz_latents: flip_aug_in_subset = False npz_any = False @@ -2209,8 +2215,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return - # ここでCUDAのキャッシュクリアとかしたほうがいいのか…… - org_vae_device = vae.device # CPUにいるはず vae.to(device) @@ -2346,7 +2350,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - + image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) @@ -2356,6 +2360,10 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v image.save(os.path.join(save_dir, img_filename)) + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + torch.set_rng_state(rng_state) torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) @@ -2386,4 +2394,4 @@ class ImageLoadingDataset(torch.utils.data.Dataset): return (tensor_pil, img_path) -# endregion +# endregion \ No newline at end of file diff --git a/lora_gui.py b/lora_gui.py index 80122a7..6aa0d96 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -483,7 +483,12 @@ def train_model( run_cmd += ( f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"' ) - else: + if LoRA_type == 'Kohya LoCon': + run_cmd += f' --network_module=networks.lora' + run_cmd += ( + f' --network_args "conv_lora_dim={conv_dim}" "conv_alpha={conv_alpha}"' + ) + if LoRA_type == 'Standard': run_cmd += f' --network_module=networks.lora' if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): @@ -563,6 +568,7 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, + output_dir, ) print(run_cmd) @@ -687,8 +693,9 @@ def lora_tab( LoRA_type = gr.Dropdown( label='LoRA type', choices=[ - 'Standard', + 'Kohya LoCon', 'LoCon', + 'Standard', ], value='Standard', ) @@ -774,7 +781,7 @@ def lora_tab( # Show of hide LoCon conv settings depending on LoRA type selection def LoRA_type_change(LoRA_type): print('LoRA type changed...') - if LoRA_type == 'LoCon': + if LoRA_type == 'LoCon' or LoRA_type == 'Kohya LoCon': return gr.Group.update(visible=True) else: return gr.Group.update(visible=False) diff --git a/networks/lora.py b/networks/lora.py index 24b107b..7179baf 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -6,6 +6,7 @@ import math import os from typing import List +import numpy as np import torch from library import train_util @@ -25,8 +26,16 @@ class LoRAModule(torch.nn.Module): if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels out_dim = org_module.out_channels - self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) - self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) + + self.lora_dim = min(self.lora_dim, in_dim, out_dim) + if self.lora_dim != lora_dim: + print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, 
padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: in_dim = org_module.in_features out_dim = org_module.out_features @@ -45,20 +54,94 @@ class LoRAModule(torch.nn.Module): self.multiplier = multiplier self.org_module = org_module # remove in applying + self.region = None + self.region_mask = None def apply_to(self): self.org_forward = self.org_module.forward self.org_module.forward = self.forward del self.org_module + def set_region(self, region): + self.region = region + self.region_mask = None + def forward(self, x): - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + if self.region is None: + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + # reginal LoRA + if x.size()[1] % 77 == 0: + # print(f"LoRA for context: {self.lora_name}") + self.region = None + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + # calculate region mask first time + if self.region_mask is None: + if len(x.size()) == 4: + h, w = x.size()[2:4] + else: + seq_len = x.size()[1] + ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len) + h = int(self.region.size()[0] / ratio + .5) + w = seq_len // h + + r = self.region.to(x.device) + if r.dtype == torch.bfloat16: + r = r.to(torch.float) + r = r.unsqueeze(0).unsqueeze(1) + # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w) + r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear') + r = r.to(x.dtype) + + if len(x.size()) == 3: + r = torch.reshape(r, (1, x.size()[1], -1)) + + self.region_mask = r + + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): if network_dim is None: network_dim = 4 # default - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) + + # extract dim/alpha for conv2d, and block dim + conv_dim = int(kwargs.get('conv_dim', network_dim)) + conv_alpha = kwargs.get('conv_alpha', network_alpha) + if conv_alpha is not None: + conv_alpha = float(conv_alpha) + + """ + block_dims = kwargs.get("block_dims") + block_alphas = None + + if block_dims is not None: + block_dims = [int(d) for d in block_dims.split(',')] + assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" + block_alphas = kwargs.get("block_alphas") + if block_alphas is None: + block_alphas = [1] * len(block_dims) + else: + block_alphas = [int(a) for a in block_alphas(',')] + assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" + + conv_block_dims = kwargs.get("conv_block_dims") + conv_block_alphas = None + + if conv_block_dims is not None: + conv_block_dims = [int(d) for d in conv_block_dims.split(',')] + assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" + conv_block_alphas = kwargs.get("conv_block_alphas") + if conv_block_alphas is None: + conv_block_alphas = [1] * len(conv_block_dims) + else: + conv_block_alphas = [int(a) for a in conv_block_alphas(',')] + assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" + """ + + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, + alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha) 
return network @@ -69,45 +152,88 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa else: weights_sd = torch.load(file, map_location='cpu') - # get dim (rank) - network_alpha = None - network_dim = None + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} for key, value in weights_sd.items(): - if network_alpha is None and 'alpha' in key: - network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: - network_dim = value.size()[0] + if '.' not in key: + continue - if network_alpha is None: - network_alpha = network_dim + lora_name = key.split('.')[0] + if 'alpha' in key: + modules_alpha[lora_name] = value + elif 'lora_down' in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + print(lora_name, value.size(), dim) - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha = modules_dim[key] + + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) network.weights_sd = weights_sd return network class LoRANetwork(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + # is it possible to apply conv_in and conv_out? + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' - def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None: + def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None: super().__init__() self.multiplier = multiplier + self.lora_dim = lora_dim self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + + if modules_dim is not None: + print(f"create LoRA network from weights") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + + self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None + if self.apply_to_conv2d_3x3: + if self.conv_alpha is None: + self.conv_alpha = self.alpha + print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: loras = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: + # TODO get block index here for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + if is_linear or is_conv2d: lora_name = prefix + '.' + name + '.' 
+ child_name lora_name = lora_name.replace('.', '_') - lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha) + + if modules_dim is not None: + if lora_name not in modules_dim: + continue # no LoRA module in this weights file + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.apply_to_conv2d_3x3: + dim = self.conv_lora_dim + alpha = self.conv_alpha + else: + continue + + lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) loras.append(lora) return loras @@ -130,7 +256,7 @@ class LoRANetwork(torch.nn.Module): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier - + def load_weights(self, file): if os.path.splitext(file)[1] == '.safetensors': from safetensors.torch import load_file, safe_open @@ -240,3 +366,18 @@ class LoRANetwork(torch.nn.Module): save_file(state_dict, file, metadata) else: torch.save(state_dict, file) + + @staticmethod + def set_regions(networks, image): + image = image.astype(np.float32) / 255.0 + for i, network in enumerate(networks[:3]): + # NOTE: consider averaging overwrapping area + region = image[:, :, i] + if region.max() == 0: + continue + region = torch.tensor(region) + network.set_region(region) + + def set_region(self, region): + for lora in self.unet_loras: + lora.set_region(region) \ No newline at end of file diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index c2fc740..fa8f53e 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -515,6 +515,7 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, + output_dir, ) print(run_cmd) diff --git a/train_network.py b/train_network.py index 387590b..4d5ffd3 100644 --- a/train_network.py +++ b/train_network.py @@ -427,9 +427,9 @@ def train(args): "ss_bucket_info": json.dumps(dataset.bucket_info), }) - # uncomment if another network is added - # for key, value in net_kwargs.items(): - # metadata["ss_arg_" + key] = value + if args.network_args: + for key, value in net_kwargs.items(): + metadata["ss_arg_" + key] = value if args.pretrained_model_name_or_path is not None: sd_model_name = args.pretrained_model_name_or_path @@ -639,4 +639,4 @@ if __name__ == '__main__': help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() - train(args) + train(args) \ No newline at end of file
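For reference, the prompt-file handling that run_cmd_sample performs after this change can be reduced to the following sketch (Python). It paraphrases the hunk in library/sampler_gui.py; the GUI wiring and any sampler flags appended outside the shown hunk are omitted, and the helper name and argument order here are illustrative.

import os

def build_sample_cmd(sample_every_n_steps, sample_every_n_epochs,
                     sample_sampler, sample_prompts, output_dir):
    # prompts now go to <output_dir>/sample/prompt.txt instead of a temp file,
    # so the prompt file is kept next to the generated sample images
    sample_dir = os.path.join(output_dir, 'sample')
    os.makedirs(sample_dir, exist_ok=True)

    if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
        return ''

    sample_prompts_path = os.path.join(sample_dir, 'prompt.txt')
    with open(sample_prompts_path, 'w') as f:
        f.write(sample_prompts)

    run_cmd = f' --sample_sampler={sample_sampler}'
    run_cmd += f' --sample_prompts="{sample_prompts_path}"'
    return run_cmd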
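The LoRAModule change extends LoRA from Linear and 1x1 Conv2d layers to 3x3 convolutions: the down projection reuses the original kernel size, stride and padding, the up projection stays 1x1, and the rank is clamped to min(rank, in_channels, out_channels). A minimal sketch of that factorization in isolation (Python; the helper name is illustrative, the construction follows the hunk above):

import torch

def make_conv_lora(org_module: torch.nn.Conv2d, lora_dim: int):
    in_dim, out_dim = org_module.in_channels, org_module.out_channels
    # clamp the rank so the low-rank factors never exceed the channel counts
    lora_dim = min(lora_dim, in_dim, out_dim)

    # down projection keeps the spatial footprint of the original conv,
    # up projection maps back to the output channels with a 1x1 kernel
    lora_down = torch.nn.Conv2d(in_dim, lora_dim, org_module.kernel_size,
                                org_module.stride, org_module.padding, bias=False)
    lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), (1, 1), bias=False)
    return lora_down, lora_up

# the module's forward then adds the low-rank path on top of the frozen original:
#   y = org_forward(x) + lora_up(lora_down(x)) * multiplier * scale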
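The new set_regions/set_region helpers implement regional LoRA: set_regions expects an H x W x 3 uint8 image and assigns its R, G and B channels as masks to up to three networks, and each LoRAModule then scales its output by the mask resized to the current feature map. A hypothetical usage sketch (the mask file name and the three network variables are placeholders for networks created elsewhere, e.g. via create_network_from_weights):

import numpy as np
from PIL import Image
from networks.lora import LoRANetwork

# red, green and blue areas of the mask mark where each network should act
mask = np.array(Image.open('region_mask.png').convert('RGB'))  # H x W x 3, uint8
LoRANetwork.set_regions([network_red, network_green, network_blue], mask)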