KohyaSS/finetune/hypernetwork_nai.py

97 lines
3.5 KiB
Python
Raw Normal View History

# NAI compatible
import torch
class HypernetworkModule(torch.nn.Module):
    """Residual hypernetwork layer computing ``x + MLP(x) * multiplier``.

    The MLP is ``Linear(dim, 2 * dim)`` followed directly by
    ``Linear(2 * dim, dim)`` with no nonlinearity in between. Weights are drawn
    from N(0, 0.01) and biases start at zero, so the residual term is small
    right after construction.

    Args:
        dim: Width of the input/output features (last dimension of ``x``).
        multiplier: Scale applied to the MLP output before adding it to ``x``.
    """

    def __init__(self, dim, multiplier=1.0):
        super().__init__()
        expand = torch.nn.Linear(dim, dim * 2)
        project = torch.nn.Linear(dim * 2, dim)
        for layer in (expand, project):
            layer.weight.data.normal_(mean=0.0, std=0.01)
            layer.bias.data.zero_()
        # The attribute name ``linear`` and the Sequential indices 0/1 define the
        # saved state-dict keys (``linear.0.weight``, ``linear.1.bias``, ...).
        self.linear = torch.nn.Sequential(expand, project)
        self.multiplier = multiplier

    def forward(self, x):
        """Return ``x`` plus the scaled MLP residual."""
        residual = self.linear(x)
        return x + residual * self.multiplier
class Hypernetwork(torch.nn.Module):
    """NAI-compatible hypernetwork: one pair of ``HypernetworkModule`` per context width.

    ``apply_to_stable_diffusion`` / ``apply_to_diffusers`` store this object as
    ``attn.hypernetwork`` on each transformer attention layer whose context width
    is in ``enable_sizes``, and set it to ``None`` on the others. The attention
    implementation lives outside this file and presumably calls :meth:`forward`
    to transform its context; verify against the caller.
    """

    # Context widths (last dimension of ``context`` in ``forward``) that get a
    # module pair. The list order is the index order of ``self.modules``.
    enable_sizes = [320, 640, 768, 1280]

    # Legacy per-module state-dict keys (old save format) -> current keys of
    # ``HypernetworkModule.linear`` (a two-layer Sequential).
    _LEGACY_KEY_CHANGES = {
        'linear1.bias': 'linear.0.bias',
        'linear1.weight': 'linear.0.weight',
        'linear2.bias': 'linear.1.bias',
        'linear2.weight': 'linear.1.weight',
    }

    def __init__(self, multiplier=1.0) -> None:
        """Create a module pair for every width in ``enable_sizes``.

        Args:
            multiplier: Residual scale forwarded to every ``HypernetworkModule``.
        """
        super().__init__()
        # NOTE(review): this plain-list attribute shadows torch.nn.Module.modules(),
        # so ``self.modules()`` cannot be called on an instance. It is kept as-is
        # because callers may index ``hypernetwork.modules`` directly.
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            pair = (HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))
            self.modules.append(pair)
            # Modules held only in a Python list are invisible to parameters(),
            # state_dict() and .to(); register each one explicitly.
            self.register_module(f"{size}_0", pair[0])
            self.register_module(f"{size}_1", pair[1])

    def _attach(self, attn, size):
        """Set ``attn.hypernetwork`` to self when ``size`` is supported, else None."""
        # NOTE(review): when ``attn`` is a torch.nn.Module, assigning a Module
        # attribute registers this hypernetwork as a child module of ``attn``.
        attn.hypernetwork = self if size in Hypernetwork.enable_sizes else None

    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        """Attach to every ``SpatialTransformer`` attention layer of an ldm-style UNet.

        ``text_encoder`` and ``vae`` are unused. The context width is taken from
        ``attn.context_dim`` when that attribute exists, otherwise from
        ``attn.to_k.in_features`` (the source ``apply_to_diffusers`` uses).

        Returns:
            True, matching ``apply_to_diffusers``.
        """
        # list() so the concatenation works whether the containers are
        # torch.nn.ModuleList (no ``+ list`` on older torch) or plain lists.
        blocks = list(unet.input_blocks) + [unet.middle_block] + list(unet.output_blocks)
        for block in blocks:
            for subblk in block:
                if 'SpatialTransformer' not in str(type(subblk)):
                    continue
                for tf_block in subblk.transformer_blocks:
                    for attn in (tf_block.attn1, tf_block.attn2):
                        if hasattr(attn, 'context_dim'):
                            size = attn.context_dim
                        else:
                            size = attn.to_k.in_features
                        self._attach(attn, size)
        return True

    def apply_to_diffusers(self, text_encoder, vae, unet):
        """Attach to the transformer attention layers of a diffusers UNet.

        ``text_encoder`` and ``vae`` are unused. Blocks without an ``attentions``
        attribute (including a ``None`` mid block) are skipped. The context width
        is read from ``attn.to_k.in_features``.

        Returns:
            True (no error checking is performed yet).
        """
        # list() so the concatenation works whether the containers are
        # torch.nn.ModuleList or plain lists.
        blocks = list(unet.down_blocks) + [unet.mid_block] + list(unet.up_blocks)
        for block in blocks:
            if not hasattr(block, 'attentions'):
                continue
            for subblk in block.attentions:
                # diffusers 0.6.0 names this class SpatialTransformer; 0.7+ uses
                # Transformer2DModel.
                subblk_type = str(type(subblk))
                if 'SpatialTransformer' not in subblk_type and 'Transformer2DModel' not in subblk_type:
                    continue
                for tf_block in subblk.transformer_blocks:
                    for attn in (tf_block.attn1, tf_block.attn2):
                        self._attach(attn, attn.to_k.in_features)
        return True  # TODO error checking

    def forward(self, x, context):
        """Return the outputs of the module pair selected by ``context``'s width.

        Args:
            x: Unused; accepted to match the caller's call signature.
            context: Tensor whose last dimension must be in ``enable_sizes``.

        Returns:
            ``(modules[i][0](context), modules[i][1](context))`` for the matching pair.

        Raises:
            AssertionError: if the last dimension of ``context`` is unsupported.
        """
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes, f"unsupported hypernetwork context size: {size}"
        pair = self.modules[Hypernetwork.enable_sizes.index(size)]
        return pair[0].forward(context), pair[1].forward(context)

    @staticmethod
    def _upgrade_module_state_dict(module_sd):
        """Return ``module_sd`` with legacy ``linear1``/``linear2`` keys renamed.

        The input is never mutated. When no legacy key is present, it is returned
        unchanged so its metadata is preserved.
        """
        if not any(key in module_sd for key in Hypernetwork._LEGACY_KEY_CHANGES):
            return module_sd
        upgraded = dict(module_sd)
        for key_from, key_to in Hypernetwork._LEGACY_KEY_CHANGES.items():
            if key_from in upgraded:
                upgraded[key_to] = upgraded.pop(key_from)
        return upgraded

    def load_from_state_dict(self, state_dict):
        """Load weights in the ``{size: [sd0, sd1]}`` format of ``get_state_dict``.

        Only int keys are used; other entries (e.g. string metadata) are ignored.
        Legacy per-module key names are converted before loading. The caller's
        dict is not modified.

        Returns:
            True.

        Raises:
            ValueError: an int key is not in ``enable_sizes``.
            RuntimeError: from ``load_state_dict`` on missing/unexpected keys or
                shape mismatches.
        """
        for size, pair_sd in state_dict.items():
            if not isinstance(size, int):
                continue
            pair = self.modules[Hypernetwork.enable_sizes.index(size)]
            for module, module_sd in ((pair[0], pair_sd[0]), (pair[1], pair_sd[1])):
                module.load_state_dict(self._upgrade_module_state_dict(module_sd), strict=True)
        return True

    def get_state_dict(self):
        """Return ``{size: [state_dict_0, state_dict_1]}`` for every enabled width."""
        return {
            size: [pair[0].state_dict(), pair[1].state_dict()]
            for size, pair in zip(Hypernetwork.enable_sizes, self.modules)
        }