Save prompt file in sample directory

This commit is contained in:
bmaltais 2023-03-08 07:30:14 -05:00
parent 7ed8f7c3c5
commit 25d6e252d3
8 changed files with 228 additions and 64 deletions

View File

@ -473,6 +473,7 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
output_dir,
) )
print(run_cmd) print(run_cmd)

View File

@ -465,6 +465,7 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
output_dir,
) )
print(run_cmd) print(run_cmd)

View File

@ -1,4 +1,5 @@
import tempfile import tempfile
import os
import gradio as gr import gradio as gr
from easygui import msgbox from easygui import msgbox
@ -71,19 +72,23 @@ def run_cmd_sample(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
output_dir,
): ):
output_dir = os.path.join(output_dir, "sample")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
run_cmd = '' run_cmd = ''
if sample_every_n_epochs == 0 and sample_every_n_steps == 0: if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
return run_cmd return run_cmd
# Create a temporary file and get its path # Create the prompt file and get its path
with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file: sample_prompts_path = os.path.join(output_dir, "prompt.txt")
# Write the contents of the variable to the file
temp_file.write(sample_prompts)
# Get the path of the temporary file with open(sample_prompts_path, 'w') as f:
sample_prompts_path = temp_file.name f.write(sample_prompts)
run_cmd += f' --sample_sampler={sample_sampler}' run_cmd += f' --sample_sampler={sample_sampler}'
run_cmd += f' --sample_prompts="{sample_prompts_path}"' run_cmd += f' --sample_prompts="{sample_prompts_path}"'

View File

@ -7,13 +7,13 @@ import re
import shutil import shutil
import time import time
from typing import ( from typing import (
Dict, Dict,
List, List,
NamedTuple, NamedTuple,
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
Union, Union,
) )
from accelerate import Accelerator from accelerate import Accelerator
import glob import glob
@ -214,24 +214,24 @@ class AugHelper:
def __init__(self): def __init__(self):
# prepare all possible augmentators # prepare all possible augmentators
color_aug_method = albu.OneOf([ color_aug_method = albu.OneOf([
albu.HueSaturationValue(8, 0, 0, p=.5), albu.HueSaturationValue(8, 0, 0, p=.5),
albu.RandomGamma((95, 105), p=.5), albu.RandomGamma((95, 105), p=.5),
], p=.33) ], p=.33)
flip_aug_method = albu.HorizontalFlip(p=0.5) flip_aug_method = albu.HorizontalFlip(p=0.5)
# key: (use_color_aug, use_flip_aug) # key: (use_color_aug, use_flip_aug)
self.augmentors = { self.augmentors = {
(True, True): albu.Compose([ (True, True): albu.Compose([
color_aug_method, color_aug_method,
flip_aug_method, flip_aug_method,
], p=1.), ], p=1.),
(True, False): albu.Compose([ (True, False): albu.Compose([
color_aug_method, color_aug_method,
], p=1.), ], p=1.),
(False, True): albu.Compose([ (False, True): albu.Compose([
flip_aug_method, flip_aug_method,
], p=1.), ], p=1.),
(False, False): None (False, False): None
} }
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
@ -260,7 +260,7 @@ class DreamBoothSubset(BaseSubset):
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
self.is_reg = is_reg self.is_reg = is_reg
self.class_tokens = class_tokens self.class_tokens = class_tokens
@ -271,12 +271,13 @@ class DreamBoothSubset(BaseSubset):
return NotImplemented return NotImplemented
return self.image_dir == other.image_dir return self.image_dir == other.image_dir
class FineTuningSubset(BaseSubset): class FineTuningSubset(BaseSubset):
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
self.metadata_file = metadata_file self.metadata_file = metadata_file
@ -285,6 +286,7 @@ class FineTuningSubset(BaseSubset):
return NotImplemented return NotImplemented
return self.metadata_file == other.metadata_file return self.metadata_file == other.metadata_file
class BaseDataset(torch.utils.data.Dataset): class BaseDataset(torch.utils.data.Dataset):
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None: def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
super().__init__() super().__init__()
@ -815,11 +817,13 @@ class DreamBoothDataset(BaseDataset):
reg_infos: List[ImageInfo] = [] reg_infos: List[ImageInfo] = []
for subset in subsets: for subset in subsets:
if subset.num_repeats < 1: if subset.num_repeats < 1:
print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") print(
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
continue continue
if subset in self.subsets: if subset in self.subsets:
print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") print(
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
continue continue
img_paths, captions = load_dreambooth_dir(subset) img_paths, captions = load_dreambooth_dir(subset)
@ -881,11 +885,13 @@ class FineTuningDataset(BaseDataset):
for subset in subsets: for subset in subsets:
if subset.num_repeats < 1: if subset.num_repeats < 1:
print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") print(
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
continue continue
if subset in self.subsets: if subset in self.subsets:
print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") print(
f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
continue continue
# メタデータを読み込む # メタデータを読み込む
@ -937,7 +943,7 @@ class FineTuningDataset(BaseDataset):
self.subsets.append(subset) self.subsets.append(subset)
# check existence of all npz files # check existence of all npz files
use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets]) use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
if use_npz_latents: if use_npz_latents:
flip_aug_in_subset = False flip_aug_in_subset = False
npz_any = False npz_any = False
@ -2209,8 +2215,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return return
# ここでCUDAのキャッシュクリアとかしたほうがいいのか……
org_vae_device = vae.device # CPUにいるはず org_vae_device = vae.device # CPUにいるはず
vae.to(device) vae.to(device)
@ -2356,6 +2360,10 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
image.save(os.path.join(save_dir, img_filename)) image.save(os.path.join(save_dir, img_filename))
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
torch.cuda.set_rng_state(cuda_rng_state) torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device) vae.to(org_vae_device)

View File

@ -483,7 +483,12 @@ def train_model(
run_cmd += ( run_cmd += (
f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"' f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
) )
else: if LoRA_type == 'Kohya LoCon':
run_cmd += f' --network_module=networks.lora'
run_cmd += (
f' --network_args "conv_lora_dim={conv_dim}" "conv_alpha={conv_alpha}"'
)
if LoRA_type == 'Standard':
run_cmd += f' --network_module=networks.lora' run_cmd += f' --network_module=networks.lora'
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
@ -563,6 +568,7 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
output_dir,
) )
print(run_cmd) print(run_cmd)
@ -687,8 +693,9 @@ def lora_tab(
LoRA_type = gr.Dropdown( LoRA_type = gr.Dropdown(
label='LoRA type', label='LoRA type',
choices=[ choices=[
'Standard', 'Kohya LoCon',
'LoCon', 'LoCon',
'Standard',
], ],
value='Standard', value='Standard',
) )
@ -774,7 +781,7 @@ def lora_tab(
# Show of hide LoCon conv settings depending on LoRA type selection # Show of hide LoCon conv settings depending on LoRA type selection
def LoRA_type_change(LoRA_type): def LoRA_type_change(LoRA_type):
print('LoRA type changed...') print('LoRA type changed...')
if LoRA_type == 'LoCon': if LoRA_type == 'LoCon' or LoRA_type == 'Kohya LoCon':
return gr.Group.update(visible=True) return gr.Group.update(visible=True)
else: else:
return gr.Group.update(visible=False) return gr.Group.update(visible=False)

View File

@ -6,6 +6,7 @@
import math import math
import os import os
from typing import List from typing import List
import numpy as np
import torch import torch
from library import train_util from library import train_util
@ -25,8 +26,16 @@ class LoRAModule(torch.nn.Module):
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels in_dim = org_module.in_channels
out_dim = org_module.out_channels out_dim = org_module.out_channels
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) self.lora_dim = min(self.lora_dim, in_dim, out_dim)
if self.lora_dim != lora_dim:
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else: else:
in_dim = org_module.in_features in_dim = org_module.in_features
out_dim = org_module.out_features out_dim = org_module.out_features
@ -45,20 +54,94 @@ class LoRAModule(torch.nn.Module):
self.multiplier = multiplier self.multiplier = multiplier
self.org_module = org_module # remove in applying self.org_module = org_module # remove in applying
self.region = None
self.region_mask = None
def apply_to(self): def apply_to(self):
self.org_forward = self.org_module.forward self.org_forward = self.org_module.forward
self.org_module.forward = self.forward self.org_module.forward = self.forward
del self.org_module del self.org_module
def set_region(self, region):
self.region = region
self.region_mask = None
def forward(self, x): def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale if self.region is None:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
# reginal LoRA
if x.size()[1] % 77 == 0:
# print(f"LoRA for context: {self.lora_name}")
self.region = None
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
# calculate region mask first time
if self.region_mask is None:
if len(x.size()) == 4:
h, w = x.size()[2:4]
else:
seq_len = x.size()[1]
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
h = int(self.region.size()[0] / ratio + .5)
w = seq_len // h
r = self.region.to(x.device)
if r.dtype == torch.bfloat16:
r = r.to(torch.float)
r = r.unsqueeze(0).unsqueeze(1)
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
r = r.to(x.dtype)
if len(x.size()) == 3:
r = torch.reshape(r, (1, x.size()[1], -1))
self.region_mask = r
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
# extract dim/alpha for conv2d, and block dim
conv_dim = int(kwargs.get('conv_dim', network_dim))
conv_alpha = kwargs.get('conv_alpha', network_alpha)
if conv_alpha is not None:
conv_alpha = float(conv_alpha)
"""
block_dims = kwargs.get("block_dims")
block_alphas = None
if block_dims is not None:
block_dims = [int(d) for d in block_dims.split(',')]
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
block_alphas = kwargs.get("block_alphas")
if block_alphas is None:
block_alphas = [1] * len(block_dims)
else:
block_alphas = [int(a) for a in block_alphas(',')]
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
conv_block_dims = kwargs.get("conv_block_dims")
conv_block_alphas = None
if conv_block_dims is not None:
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
conv_block_alphas = kwargs.get("conv_block_alphas")
if conv_block_alphas is None:
conv_block_alphas = [1] * len(conv_block_dims)
else:
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
"""
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
return network return network
@ -69,45 +152,88 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
else: else:
weights_sd = torch.load(file, map_location='cpu') weights_sd = torch.load(file, map_location='cpu')
# get dim (rank) # get dim/alpha mapping
network_alpha = None modules_dim = {}
network_dim = None modules_alpha = {}
for key, value in weights_sd.items(): for key, value in weights_sd.items():
if network_alpha is None and 'alpha' in key: if '.' not in key:
network_alpha = value continue
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is None: lora_name = key.split('.')[0]
network_alpha = network_dim if 'alpha' in key:
modules_alpha[lora_name] = value
elif 'lora_down' in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
print(lora_name, value.size(), dim)
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) # support old LoRA without alpha
for key in modules_dim.keys():
if key not in modules_alpha:
modules_alpha = modules_dim[key]
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
network.weights_sd = weights_sd network.weights_sd = weights_sd
return network return network
class LoRANetwork(torch.nn.Module): class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] # is it possible to apply conv_in and conv_out?
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = 'lora_te'
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None: def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
super().__init__() super().__init__()
self.multiplier = multiplier self.multiplier = multiplier
self.lora_dim = lora_dim self.lora_dim = lora_dim
self.alpha = alpha self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
if modules_dim is not None:
print(f"create LoRA network from weights")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
if self.apply_to_conv2d_3x3:
if self.conv_alpha is None:
self.conv_alpha = self.alpha
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances # create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
loras = [] loras = []
for name, module in root_module.named_modules(): for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules: if module.__class__.__name__ in target_replace_modules:
# TODO get block index here
for child_name, child_module in module.named_modules(): for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + '.' + name + '.' + child_name lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_') lora_name = lora_name.replace('.', '_')
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
if modules_dim is not None:
if lora_name not in modules_dim:
continue # no LoRA module in this weights file
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
if is_linear or is_conv2d_1x1:
dim = self.lora_dim
alpha = self.alpha
elif self.apply_to_conv2d_3x3:
dim = self.conv_lora_dim
alpha = self.conv_alpha
else:
continue
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
loras.append(lora) loras.append(lora)
return loras return loras
@ -240,3 +366,18 @@ class LoRANetwork(torch.nn.Module):
save_file(state_dict, file, metadata) save_file(state_dict, file, metadata)
else: else:
torch.save(state_dict, file) torch.save(state_dict, file)
@staticmethod
def set_regions(networks, image):
image = image.astype(np.float32) / 255.0
for i, network in enumerate(networks[:3]):
# NOTE: consider averaging overwrapping area
region = image[:, :, i]
if region.max() == 0:
continue
region = torch.tensor(region)
network.set_region(region)
def set_region(self, region):
for lora in self.unet_loras:
lora.set_region(region)

View File

@ -515,6 +515,7 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
output_dir,
) )
print(run_cmd) print(run_cmd)

View File

@ -427,9 +427,9 @@ def train(args):
"ss_bucket_info": json.dumps(dataset.bucket_info), "ss_bucket_info": json.dumps(dataset.bucket_info),
}) })
# uncomment if another network is added if args.network_args:
# for key, value in net_kwargs.items(): for key, value in net_kwargs.items():
# metadata["ss_arg_" + key] = value metadata["ss_arg_" + key] = value
if args.pretrained_model_name_or_path is not None: if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path sd_model_name = args.pretrained_model_name_or_path