add --xformers-flash-attention option & impl
This commit is contained in:
parent
184e23eb89
commit
3262e825cc
@ -290,7 +290,19 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
|
||||||
|
if shared.cmd_opts.xformers_flash_attention:
|
||||||
|
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
||||||
|
fw, bw = op
|
||||||
|
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
|
||||||
|
# print('xformers_attention_forward', q.shape, k.shape, v.shape)
|
||||||
|
# Flash Attention is not available for the input arguments.
|
||||||
|
# Fall back to default xFormers' backend.
|
||||||
|
op = None
|
||||||
|
else:
|
||||||
|
op = None
|
||||||
|
|
||||||
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
|
||||||
|
|
||||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
@ -365,7 +377,17 @@ def xformers_attnblock_forward(self, x):
|
|||||||
q = q.contiguous()
|
q = q.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v)
|
if shared.cmd_opts.xformers_flash_attention:
|
||||||
|
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
||||||
|
fw, bw = op
|
||||||
|
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
|
||||||
|
# print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
|
||||||
|
# Flash Attention is not available for the input arguments.
|
||||||
|
# Fall back to default xFormers' backend.
|
||||||
|
op = None
|
||||||
|
else:
|
||||||
|
op = None
|
||||||
|
out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
|
||||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x + out
|
return x + out
|
||||||
|
@ -57,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
|
|||||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||||
|
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||||
|
Loading…
Reference in New Issue
Block a user