Merge pull request #10266 from nero-dv/dev
Update sub_quadratic_attention.py
This commit is contained in:
commit
c9e5b92106
@ -202,13 +202,22 @@ def efficient_dot_product_attention(
|
|||||||
value=value,
|
value=value,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
# slices of res tensor are mutable, modifications made
|
||||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
# to the slices will affect the original tensor.
|
||||||
res = torch.cat([
|
# if output of compute_query_chunk_attn function has same number of
|
||||||
compute_query_chunk_attn(
|
# dimensions as input query tensor, we initialize tensor like this:
|
||||||
|
num_query_chunks = int(np.ceil(q_tokens / query_chunk_size))
|
||||||
|
query_shape = get_query_chunk(0).shape
|
||||||
|
res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:])
|
||||||
|
res_dtype = get_query_chunk(0).dtype
|
||||||
|
res = torch.zeros(res_shape, dtype=res_dtype)
|
||||||
|
|
||||||
|
for i in range(num_query_chunks):
|
||||||
|
attn_scores = compute_query_chunk_attn(
|
||||||
query=get_query_chunk(i * query_chunk_size),
|
query=get_query_chunk(i * query_chunk_size),
|
||||||
key=key,
|
key=key,
|
||||||
value=value,
|
value=value,
|
||||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
)
|
||||||
], dim=1)
|
res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
Loading…
x
Reference in New Issue
Block a user