make it possible to use hypernetworks without opt split attention
This commit is contained in:
parent
97bc0b9504
commit
f7c787eb7c
@ -4,7 +4,12 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modules import devices
|
|
||||||
|
from ldm.util import default
|
||||||
|
from modules import devices, shared
|
||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
@ -48,15 +53,36 @@ def load_hypernetworks(path):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def apply(self, x, context=None, mask=None, original=None):
|
|
||||||
|
|
||||||
|
def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
|
q = self.to_q(x)
|
||||||
if context.shape[1] == 77 and CrossAttention.noise_cond:
|
context = default(context, x)
|
||||||
context = context + (torch.randn_like(context) * 0.1)
|
|
||||||
h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
|
hypernetwork = shared.selected_hypernetwork()
|
||||||
k = self.to_k(h_k(context))
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||||
v = self.to_v(h_v(context))
|
|
||||||
|
if hypernetwork_layers is not None:
|
||||||
|
k = self.to_k(hypernetwork_layers[0](context))
|
||||||
|
v = self.to_v(hypernetwork_layers[1](context))
|
||||||
else:
|
else:
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||||
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
return self.to_out(out)
|
||||||
|
@ -8,7 +8,7 @@ from torch import einsum
|
|||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At
|
|||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations():
|
||||||
|
undo_optimizations()
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
if cmd_opts.opt_split_attention_v1:
|
||||||
@ -30,7 +32,7 @@ def apply_optimizations():
|
|||||||
|
|
||||||
|
|
||||||
def undo_optimizations():
|
def undo_optimizations():
|
||||||
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
|
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user