Merge pull request #11066 from aljungberg/patch-1
Fix upcast attention dtype error.
This commit is contained in:
commit
806ea639e6
@ -602,7 +602,7 @@ def sdp_attnblock_forward(self, x):
|
|||||||
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k, v = q.float(), k.float(), v.float()
|
||||||
q = q.contiguous()
|
q = q.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
|
Loading…
Reference in New Issue
Block a user