Save prompt file in sample directory

commit 25d6e252d3 (parent 7ed8f7c3c5)
@@ -473,6 +473,7 @@ def train_model(
         sample_every_n_epochs,
         sample_sampler,
         sample_prompts,
+        output_dir,
     )
 
     print(run_cmd)
@@ -465,6 +465,7 @@ def train_model(
         sample_every_n_epochs,
         sample_sampler,
         sample_prompts,
+        output_dir,
     )
 
     print(run_cmd)
@@ -1,4 +1,5 @@
 import tempfile
+import os
 import gradio as gr
 from easygui import msgbox
 
@@ -71,19 +72,23 @@ def run_cmd_sample(
     sample_every_n_epochs,
     sample_sampler,
     sample_prompts,
+    output_dir,
 ):
+    output_dir = os.path.join(output_dir, "sample")
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
     run_cmd = ''
 
     if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
         return run_cmd
 
-    # Create a temporary file and get its path
-    with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file:
-        # Write the contents of the variable to the file
-        temp_file.write(sample_prompts)
+    # Create the prompt file and get its path
+    sample_prompts_path = os.path.join(output_dir, "prompt.txt")
 
-    # Get the path of the temporary file
-    sample_prompts_path = temp_file.name
+    with open(sample_prompts_path, 'w') as f:
+        f.write(sample_prompts)
 
     run_cmd += f' --sample_sampler={sample_sampler}'
     run_cmd += f' --sample_prompts="{sample_prompts_path}"'
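Note: the hunk above replaces the throwaway tempfile.NamedTemporaryFile with a persistent prompt.txt stored next to the generated samples, so the prompts used for sampling survive the run. A minimal sketch of the resulting behavior (the body is trimmed to the lines this commit touches; argument values are illustrative):

# Sketch of run_cmd_sample after this commit (illustrative, not the full function).
import os

def run_cmd_sample_sketch(sample_every_n_steps, sample_every_n_epochs,
                          sample_sampler, sample_prompts, output_dir):
    output_dir = os.path.join(output_dir, 'sample')
    os.makedirs(output_dir, exist_ok=True)  # same effect as the exists()/makedirs() pair

    if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
        return ''  # sampling disabled: no flags, but the sample dir now already exists

    sample_prompts_path = os.path.join(output_dir, 'prompt.txt')
    with open(sample_prompts_path, 'w') as f:
        f.write(sample_prompts)

    return (f' --sample_sampler={sample_sampler}'
            f' --sample_prompts="{sample_prompts_path}"')

# run_cmd_sample_sketch(0, 1, 'euler_a', 'a prompt', './model_out') returns
# ' --sample_sampler=euler_a --sample_prompts="./model_out/sample/prompt.txt"' on POSIX.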
@@ -7,13 +7,13 @@ import re
 import shutil
 import time
 from typing import (
     Dict,
     List,
     NamedTuple,
     Optional,
     Sequence,
     Tuple,
     Union,
 )
 from accelerate import Accelerator
 import glob
@@ -214,24 +214,24 @@ class AugHelper:
     def __init__(self):
         # prepare all possible augmentators
         color_aug_method = albu.OneOf([
             albu.HueSaturationValue(8, 0, 0, p=.5),
             albu.RandomGamma((95, 105), p=.5),
         ], p=.33)
         flip_aug_method = albu.HorizontalFlip(p=0.5)
 
         # key: (use_color_aug, use_flip_aug)
         self.augmentors = {
             (True, True): albu.Compose([
                 color_aug_method,
                 flip_aug_method,
             ], p=1.),
             (True, False): albu.Compose([
                 color_aug_method,
             ], p=1.),
             (False, True): albu.Compose([
                 flip_aug_method,
             ], p=1.),
             (False, False): None
         }
 
     def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
@@ -260,7 +260,7 @@ class DreamBoothSubset(BaseSubset):
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
                          face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
 
         self.is_reg = is_reg
         self.class_tokens = class_tokens
@@ -271,12 +271,13 @@ class DreamBoothSubset(BaseSubset):
             return NotImplemented
         return self.image_dir == other.image_dir
 
+
 class FineTuningSubset(BaseSubset):
     def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
         super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
                          face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
 
         self.metadata_file = metadata_file
 
@@ -285,6 +286,7 @@ class FineTuningSubset(BaseSubset):
             return NotImplemented
         return self.metadata_file == other.metadata_file
 
+
 class BaseDataset(torch.utils.data.Dataset):
     def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
         super().__init__()
@@ -804,7 +806,7 @@ class DreamBoothDataset(BaseDataset):
                     captions.append("")
                 else:
                     captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
 
             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
 
             return img_paths, captions
@@ -815,11 +817,13 @@ class DreamBoothDataset(BaseDataset):
         reg_infos: List[ImageInfo] = []
         for subset in subsets:
             if subset.num_repeats < 1:
-                print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
+                print(
+                    f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
                 continue
 
             if subset in self.subsets:
-                print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
+                print(
+                    f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
                 continue
 
             img_paths, captions = load_dreambooth_dir(subset)
@@ -881,11 +885,13 @@ class FineTuningDataset(BaseDataset):
 
         for subset in subsets:
             if subset.num_repeats < 1:
-                print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
+                print(
+                    f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
                 continue
 
             if subset in self.subsets:
-                print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
+                print(
+                    f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
                 continue
 
             # メタデータを読み込む
@@ -937,7 +943,7 @@ class FineTuningDataset(BaseDataset):
             self.subsets.append(subset)
 
         # check existence of all npz files
-        use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
+        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
         if use_npz_latents:
             flip_aug_in_subset = False
             npz_any = False
@@ -2209,8 +2215,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
         print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
         return
 
-    # ここでCUDAのキャッシュクリアとかしたほうがいいのか……
-
     org_vae_device = vae.device  # CPUにいるはず
     vae.to(device)
 
@@ -2346,7 +2350,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
                 prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
                 if negative_prompt is not None:
                     negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
 
             image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
 
             ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
@@ -2356,6 +2360,10 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
 
             image.save(os.path.join(save_dir, img_filename))
 
+    # clear pipeline and cache to reduce vram usage
+    del pipeline
+    torch.cuda.empty_cache()
+
     torch.set_rng_state(rng_state)
     torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)
@@ -2386,4 +2394,4 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
         return (tensor_pil, img_path)
 
 
 # endregion
lora_gui.py (13 lines changed)
@@ -483,7 +483,12 @@ def train_model(
         run_cmd += (
             f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
         )
-    else:
+    if LoRA_type == 'Kohya LoCon':
+        run_cmd += f' --network_module=networks.lora'
+        run_cmd += (
+            f' --network_args "conv_lora_dim={conv_dim}" "conv_alpha={conv_alpha}"'
+        )
+    if LoRA_type == 'Standard':
         run_cmd += f' --network_module=networks.lora'
 
     if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
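Note: with this branch, 'Kohya LoCon' drives the built-in networks.lora module and emits conv_lora_dim/conv_alpha through --network_args, while the plain 'LoCon' type (whose tail is the context at the top of the hunk) emits conv_dim/conv_alpha. As the networks/lora.py hunks below show, create_network reads the kwarg named 'conv_dim', so the 'conv_lora_dim' spelling emitted here would not be picked up and the conv rank would fall back to network_dim. A simplified sketch of how "key=value" pairs from --network_args reach create_network as kwargs (the parsing loop is paraphrased, not quoted from this commit):

# Sketch: "key=value" strings from --network_args become create_network kwargs.
def parse_network_args(pairs):
    net_kwargs = {}
    for pair in pairs:               # e.g. ['conv_lora_dim=8', 'conv_alpha=1']
        key, value = pair.split('=', 1)
        net_kwargs[key] = value
    return net_kwargs

kwargs = parse_network_args(['conv_lora_dim=8', 'conv_alpha=1'])
print(kwargs.get('conv_dim'))        # None: does not match what create_network reads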
@@ -563,6 +568,7 @@ def train_model(
         sample_every_n_epochs,
         sample_sampler,
         sample_prompts,
+        output_dir,
     )
 
     print(run_cmd)
@@ -687,8 +693,9 @@ def lora_tab(
     LoRA_type = gr.Dropdown(
         label='LoRA type',
         choices=[
-            'Standard',
+            'Kohya LoCon',
             'LoCon',
+            'Standard',
         ],
         value='Standard',
     )
@@ -774,7 +781,7 @@ def lora_tab(
     # Show of hide LoCon conv settings depending on LoRA type selection
     def LoRA_type_change(LoRA_type):
         print('LoRA type changed...')
-        if LoRA_type == 'LoCon':
+        if LoRA_type == 'LoCon' or LoRA_type == 'Kohya LoCon':
             return gr.Group.update(visible=True)
         else:
             return gr.Group.update(visible=False)
networks/lora.py (179 lines changed)
@@ -6,6 +6,7 @@
 import math
 import os
 from typing import List
+import numpy as np
 import torch
 
 from library import train_util
@@ -25,8 +26,16 @@ class LoRAModule(torch.nn.Module):
         if org_module.__class__.__name__ == 'Conv2d':
             in_dim = org_module.in_channels
             out_dim = org_module.out_channels
-            self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
-            self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
+
+            self.lora_dim = min(self.lora_dim, in_dim, out_dim)
+            if self.lora_dim != lora_dim:
+                print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+
+            kernel_size = org_module.kernel_size
+            stride = org_module.stride
+            padding = org_module.padding
+            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
         else:
             in_dim = org_module.in_features
             out_dim = org_module.out_features
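Note: the Conv2d branch now forms a low-rank factorization for any conv layer, not just 1x1 convs: lora_down reuses the wrapped layer's kernel size, stride, and padding while projecting to self.lora_dim channels (with the rank clamped to min(rank, in_dim, out_dim)), and lora_up is a 1x1 conv back to out_dim, so the LoRA path matches the original layer's output shape exactly. A quick shape check (sizes are illustrative):

# Shape check for the conv decomposition introduced above.
import torch

org = torch.nn.Conv2d(320, 320, kernel_size=3, stride=1, padding=1)
rank = 8
lora_down = torch.nn.Conv2d(320, rank, org.kernel_size, org.stride, org.padding, bias=False)
lora_up = torch.nn.Conv2d(rank, 320, (1, 1), (1, 1), bias=False)

x = torch.randn(1, 320, 64, 64)
assert lora_up(lora_down(x)).shape == org(x).shape  # both (1, 320, 64, 64)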
@@ -45,20 +54,94 @@ class LoRAModule(torch.nn.Module):
 
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
+        self.region = None
+        self.region_mask = None
 
     def apply_to(self):
         self.org_forward = self.org_module.forward
         self.org_module.forward = self.forward
         del self.org_module
 
+    def set_region(self, region):
+        self.region = region
+        self.region_mask = None
+
     def forward(self, x):
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        if self.region is None:
+            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+        # reginal LoRA
+        if x.size()[1] % 77 == 0:
+            # print(f"LoRA for context: {self.lora_name}")
+            self.region = None
+            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+        # calculate region mask first time
+        if self.region_mask is None:
+            if len(x.size()) == 4:
+                h, w = x.size()[2:4]
+            else:
+                seq_len = x.size()[1]
+                ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
+                h = int(self.region.size()[0] / ratio + .5)
+                w = seq_len // h
+
+            r = self.region.to(x.device)
+            if r.dtype == torch.bfloat16:
+                r = r.to(torch.float)
+            r = r.unsqueeze(0).unsqueeze(1)
+            # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
+            r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
+            r = r.to(x.dtype)
+
+            if len(x.size()) == 3:
+                r = torch.reshape(r, (1, x.size()[1], -1))
+
+            self.region_mask = r
+
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
 
 
 def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
-    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+
+    # extract dim/alpha for conv2d, and block dim
+    conv_dim = int(kwargs.get('conv_dim', network_dim))
+    conv_alpha = kwargs.get('conv_alpha', network_alpha)
+    if conv_alpha is not None:
+        conv_alpha = float(conv_alpha)
+
+    """
+    block_dims = kwargs.get("block_dims")
+    block_alphas = None
+
+    if block_dims is not None:
+        block_dims = [int(d) for d in block_dims.split(',')]
+        assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
+        block_alphas = kwargs.get("block_alphas")
+        if block_alphas is None:
+            block_alphas = [1] * len(block_dims)
+        else:
+            block_alphas = [int(a) for a in block_alphas(',')]
+            assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
+
+    conv_block_dims = kwargs.get("conv_block_dims")
+    conv_block_alphas = None
+
+    if conv_block_dims is not None:
+        conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
+        assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
+        conv_block_alphas = kwargs.get("conv_block_alphas")
+        if conv_block_alphas is None:
+            conv_block_alphas = [1] * len(conv_block_dims)
+        else:
+            conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
+            assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
+    """
+
+    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
                          alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
     return network
 
 
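Note: the new forward applies the LoRA delta through a per-position region mask. Activations whose sequence length is a multiple of 77 are treated as text-encoder context and bypass the mask; for 4D feature maps h and w are read straight off the tensor, and for 3D attention activations they are inferred from the mask's aspect ratio before the mask is bilinearly resized to match. A sketch of that inference step, which is the subtle part (sizes are illustrative):

# Sketch: inferring the (h, w) grid that a flattened activation of length
# seq_len corresponds to, from a full-resolution (H, W) region mask.
import math
import torch

def infer_hw(region, seq_len):
    ratio = math.sqrt((region.size()[0] * region.size()[1]) / seq_len)
    h = int(region.size()[0] / ratio + .5)
    w = seq_len // h
    return h, w

region = torch.zeros(512, 512)    # mask drawn at image resolution
print(infer_hw(region, 4096))     # -> (64, 64): the 64x64 latent grid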
@@ -69,45 +152,88 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
     else:
         weights_sd = torch.load(file, map_location='cpu')
 
-    # get dim (rank)
-    network_alpha = None
-    network_dim = None
+    # get dim/alpha mapping
+    modules_dim = {}
+    modules_alpha = {}
     for key, value in weights_sd.items():
-        if network_alpha is None and 'alpha' in key:
-            network_alpha = value
-        if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
-            network_dim = value.size()[0]
+        if '.' not in key:
+            continue
 
-    if network_alpha is None:
-        network_alpha = network_dim
+        lora_name = key.split('.')[0]
+        if 'alpha' in key:
+            modules_alpha[lora_name] = value
+        elif 'lora_down' in key:
+            dim = value.size()[0]
+            modules_dim[lora_name] = dim
+            print(lora_name, value.size(), dim)
 
-    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+    # support old LoRA without alpha
+    for key in modules_dim.keys():
+        if key not in modules_alpha:
+            modules_alpha = modules_dim[key]
+
+    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
     network.weights_sd = weights_sd
     return network
 
 
 class LoRANetwork(torch.nn.Module):
-    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    # is it possible to apply conv_in and conv_out?
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = 'lora_unet'
     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
 
-    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
+    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
         super().__init__()
         self.multiplier = multiplier
+
         self.lora_dim = lora_dim
         self.alpha = alpha
+        self.conv_lora_dim = conv_lora_dim
+        self.conv_alpha = conv_alpha
+
+        if modules_dim is not None:
+            print(f"create LoRA network from weights")
+        else:
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+
+        self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
+        if self.apply_to_conv2d_3x3:
+            if self.conv_alpha is None:
+                self.conv_alpha = self.alpha
+            print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
         # create module instances
         def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
             loras = []
             for name, module in root_module.named_modules():
                 if module.__class__.__name__ in target_replace_modules:
+                    # TODO get block index here
                     for child_name, child_module in module.named_modules():
-                        if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
+                        is_linear = child_module.__class__.__name__ == "Linear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+                        if is_linear or is_conv2d:
                             lora_name = prefix + '.' + name + '.' + child_name
                             lora_name = lora_name.replace('.', '_')
-                            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
+
+                            if modules_dim is not None:
+                                if lora_name not in modules_dim:
+                                    continue  # no LoRA module in this weights file
+                                dim = modules_dim[lora_name]
+                                alpha = modules_alpha[lora_name]
+                            else:
+                                if is_linear or is_conv2d_1x1:
+                                    dim = self.lora_dim
+                                    alpha = self.alpha
+                                elif self.apply_to_conv2d_3x3:
+                                    dim = self.conv_lora_dim
+                                    alpha = self.conv_alpha
+                                else:
+                                    continue
+
+                            lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
                             loras.append(lora)
             return loras
 
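Note: loading now records a dim and alpha per LoRA module instead of one network-wide pair, keyed by the module-name prefix of each state-dict key, and create_modules skips any layer with no entry in the weights file. One caveat as committed: the old-LoRA fallback reads modules_alpha = modules_dim[key], which replaces the whole dict rather than setting one entry; the intent is presumably modules_alpha[key] = modules_dim[key]. A toy reconstruction of the mapping loop (key names are made up for illustration):

# Toy state dict: rank is the row count of each lora_down weight.
import torch

weights_sd = {
    'lora_unet_example.alpha': torch.tensor(4.0),
    'lora_unet_example.lora_down.weight': torch.zeros(4, 320),
}

modules_dim, modules_alpha = {}, {}
for key, value in weights_sd.items():
    if '.' not in key:
        continue
    lora_name = key.split('.')[0]
    if 'alpha' in key:
        modules_alpha[lora_name] = value
    elif 'lora_down' in key:
        modules_dim[lora_name] = value.size()[0]

print(modules_dim)    # {'lora_unet_example': 4}
print(modules_alpha)  # {'lora_unet_example': tensor(4.)}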
@@ -130,7 +256,7 @@ class LoRANetwork(torch.nn.Module):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
             lora.multiplier = self.multiplier
 
     def load_weights(self, file):
         if os.path.splitext(file)[1] == '.safetensors':
             from safetensors.torch import load_file, safe_open
@@ -240,3 +366,18 @@ class LoRANetwork(torch.nn.Module):
             save_file(state_dict, file, metadata)
         else:
             torch.save(state_dict, file)
+
+    @staticmethod
+    def set_regions(networks, image):
+        image = image.astype(np.float32) / 255.0
+        for i, network in enumerate(networks[:3]):
+            # NOTE: consider averaging overwrapping area
+            region = image[:, :, i]
+            if region.max() == 0:
+                continue
+            region = torch.tensor(region)
+            network.set_region(region)
+
+    def set_region(self, region):
+        for lora in self.unet_loras:
+            lora.set_region(region)
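Note: set_regions maps up to three networks onto the R, G, and B channels of a single uint8 mask image; a channel whose maximum is 0 is skipped, leaving that network unmasked. A usage sketch (the mask layout and network names are hypothetical):

# Usage sketch: left half of the image to one LoRA, right half to another.
import numpy as np

mask = np.zeros((512, 512, 3), dtype=np.uint8)
mask[:, :256, 0] = 255   # red channel -> first network
mask[:, 256:, 1] = 255   # green channel -> second network

# networks = [network_a, network_b]         # LoRANetwork instances (hypothetical)
# LoRANetwork.set_regions(networks, mask)   # each network receives its channel / 255.0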
@@ -515,6 +515,7 @@ def train_model(
         sample_every_n_epochs,
         sample_sampler,
         sample_prompts,
+        output_dir,
     )
 
     print(run_cmd)
@@ -427,9 +427,9 @@ def train(args):
             "ss_bucket_info": json.dumps(dataset.bucket_info),
         })
 
-    # uncomment if another network is added
-    # for key, value in net_kwargs.items():
-    #     metadata["ss_arg_" + key] = value
+    if args.network_args:
+        for key, value in net_kwargs.items():
+            metadata["ss_arg_" + key] = value
 
     if args.pretrained_model_name_or_path is not None:
         sd_model_name = args.pretrained_model_name_or_path
@@ -639,4 +639,4 @@ if __name__ == '__main__':
                         help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
 
     args = parser.parse_args()
     train(args)