Fix upcast attention dtype error.

Without this fix, enabling the "Upcast cross attention layer to float32" option while also using `--opt-sdp-attention` breaks generation with an error:

```
  File "/ext3/automatic1111/stable-diffusion-webui/modules/sd_hijack_optimizations.py", line 612, in sdp_attnblock_forward
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::Half instead.
```

The fix is to make sure to upcast the value tensor too.
This commit is contained in:
Alexander Ljungberg 2023-06-06 21:45:30 +01:00 committed by GitHub
parent baf6946e06
commit d9cc0910c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -605,7 +605,7 @@ def sdp_attnblock_forward(self, x):
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype dtype = q.dtype
if shared.opts.upcast_attn: if shared.opts.upcast_attn:
q, k = q.float(), k.float() q, k, v = q.float(), k.float(), v.float()
q = q.contiguous() q = q.contiguous()
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()