# NAI (NovelAI) compatible hypernetwork
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
class HypernetworkModule(torch.nn.Module):
    """Residual two-layer MLP: ``x + mlp(x) * multiplier``.

    The MLP widens the last dimension from ``dim`` to ``dim * 2`` and projects it
    back to ``dim``. Its layers live in a ``torch.nn.Sequential`` named
    ``linear``, so the parameters appear in the state dict as ``linear.0.*`` and
    ``linear.1.*``.

    Args:
        dim: Size of the last dimension of the input.
        multiplier: Scale applied to the MLP output before the residual add.
    """

    def __init__(self, dim, multiplier=1.0):
        super().__init__()
        # Both layers are constructed before any re-initialisation, so the
        # global RNG is consumed in the same order as constructing them one
        # after the other and then overwriting their weights.
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 2),
            torch.nn.Linear(dim * 2, dim),
        )
        # Small random weights (std 0.01) and zero biases keep the MLP branch
        # small at initialisation, so the output starts close to the input.
        for layer in self.linear:
            layer.weight.data.normal_(mean=0.0, std=0.01)
            layer.bias.data.zero_()
        self.multiplier = multiplier

    def forward(self, x):
        """Return ``x`` plus the scaled MLP output."""
        delta = self.linear(x)
        return x + delta * self.multiplier
|
||
|
|
||
|
|
||
|
class Hypernetwork(torch.nn.Module):
    """NAI-compatible hypernetwork holding a pair of HypernetworkModule per context width.

    One pair of modules is built for every width in ``enable_sizes``.
    ``forward`` selects the pair by ``context.shape[-1]`` and returns the
    context transformed by each module of the pair. Weights are exchanged in
    the ``{size: [state_dict_0, state_dict_1]}`` format used by
    ``get_state_dict`` and ``load_from_state_dict``.
    """

    # Context widths (last dimension of ``context``) that get a module pair.
    enable_sizes = [320, 640, 768, 1280]

    # Old-version per-module parameter names mapped to the current names
    # (indices 0 and 1 of HypernetworkModule's ``linear`` Sequential).
    _OLD_KEY_MAP = {
        'linear1.bias': 'linear.0.bias',
        'linear1.weight': 'linear.0.weight',
        'linear2.bias': 'linear.1.bias',
        'linear2.weight': 'linear.1.weight',
    }

    def __init__(self, multiplier=1.0) -> None:
        """Build two HypernetworkModule per width in ``enable_sizes``.

        Args:
            multiplier: Residual scale passed to every HypernetworkModule.
        """
        super().__init__()
        # NOTE(review): this instance attribute shadows torch.nn.Module.modules(),
        # so ``hypernetwork.modules()`` cannot be called on this object. The name
        # is kept because callers may index ``self.modules`` directly.
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            pair = (HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))
            self.modules.append(pair)
            # A plain list is not tracked by nn.Module, so register each module
            # explicitly to make its parameters visible (parameters(), .to(), ...).
            self.register_module(f"{size}_0", pair[0])
            self.register_module(f"{size}_1", pair[1])

    def _attach(self, attn, size):
        """Set ``attn.hypernetwork`` to self if ``size`` is supported, otherwise to None."""
        attn.hypernetwork = self if size in Hypernetwork.enable_sizes else None

    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        """Attach this hypernetwork to the attention layers of an original-SD U-Net.

        For every ``SpatialTransformer`` found in ``unet.input_blocks``,
        ``unet.middle_block`` and ``unet.output_blocks``, each transformer
        block's ``attn1`` and ``attn2`` gets ``hypernetwork`` set to ``self``
        when its ``context_dim`` is in ``enable_sizes``, and to ``None``
        otherwise. ``text_encoder`` and ``vae`` are not used.

        Returns:
            True, matching ``apply_to_diffusers``.
        """
        # Unpacking does not depend on ModuleList supporting ``+``.
        blocks = [*unet.input_blocks, unet.middle_block, *unet.output_blocks]
        for block in blocks:
            for subblk in block:
                # Matched by type name, so the model classes need not be imported here.
                if 'SpatialTransformer' not in str(type(subblk)):
                    continue
                for tf_block in subblk.transformer_blocks:
                    for attn in (tf_block.attn1, tf_block.attn2):
                        self._attach(attn, attn.context_dim)
        return True

    def apply_to_diffusers(self, text_encoder, vae, unet):
        """Attach this hypernetwork to the attention layers of a diffusers U-Net.

        Blocks of ``unet.down_blocks``, ``unet.mid_block`` and ``unet.up_blocks``
        that have an ``attentions`` attribute are scanned. Each ``attn1`` and
        ``attn2`` of their transformer blocks gets ``hypernetwork`` set to
        ``self`` when ``attn.to_k.in_features`` is in ``enable_sizes``, and to
        ``None`` otherwise. ``text_encoder`` and ``vae`` are not used.

        Returns:
            True.
        """
        blocks = [*unet.down_blocks, unet.mid_block, *unet.up_blocks]
        for block in blocks:
            if not hasattr(block, 'attentions'):
                continue
            for subblk in block.attentions:
                # Class is 'SpatialTransformer' in diffusers 0.6.0 and
                # 'Transformer2DModel' from 0.7 on.
                type_name = str(type(subblk))
                if 'SpatialTransformer' not in type_name and 'Transformer2DModel' not in type_name:
                    continue
                for tf_block in subblk.transformer_blocks:
                    for attn in (tf_block.attn1, tf_block.attn2):
                        self._attach(attn, attn.to_k.in_features)
        return True  # TODO error checking

    def forward(self, x, context):
        """Return ``context`` transformed by both modules of the matching pair.

        ``x`` is not used. The pair is chosen by ``context.shape[-1]``, which
        must be one of ``enable_sizes``. The two outputs are presumably used for
        the key and value projections by the caller — confirm in the attention code.
        """
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes, f"unsupported hypernetwork context size: {size}"
        module_0, module_1 = self.modules[Hypernetwork.enable_sizes.index(size)]
        # Call the modules (not ``.forward``) so registered hooks run.
        return module_0(context), module_1(context)

    @staticmethod
    def _upgrade_module_state_dict(sd):
        """Return ``sd`` with old ``linear1``/``linear2`` keys renamed, without mutating it.

        A dict that contains no old keys is returned unchanged.
        """
        if not any(key in sd for key in Hypernetwork._OLD_KEY_MAP):
            return sd
        return {Hypernetwork._OLD_KEY_MAP.get(key, key): value for key, value in sd.items()}

    def load_from_state_dict(self, state_dict):
        """Load weights from the ``{size: [state_dict_0, state_dict_1]}`` format.

        Entries with non-int keys are skipped. Each per-module state dict that
        still uses the old ``linear1``/``linear2`` parameter names is migrated
        to the ``linear.0``/``linear.1`` names before loading. The caller's
        dicts are not modified.

        Returns:
            True.

        Raises:
            ValueError: if an int key is not one of ``enable_sizes``.
            RuntimeError: from ``load_state_dict(strict=True)`` on missing or
                unexpected parameter names or shape mismatches.
        """
        for size, sd in state_dict.items():
            if not isinstance(size, int):
                continue
            module_0, module_1 = self.modules[Hypernetwork.enable_sizes.index(size)]
            # The old-name migration must run on the per-module dicts: the
            # top-level keys are sizes, never parameter names.
            module_0.load_state_dict(self._upgrade_module_state_dict(sd[0]), strict=True)
            module_1.load_state_dict(self._upgrade_module_state_dict(sd[1]), strict=True)
        return True

    def get_state_dict(self):
        """Return all weights as ``{size: [state_dict_0, state_dict_1]}``."""
        return {
            size: [module_0.state_dict(), module_1.state_dict()]
            for size, (module_0, module_1) in zip(Hypernetwork.enable_sizes, self.modules)
        }
|