WebUI/modules/sub_quadratic_attention.py
Louis Del Valle c8732dfa6f
Update sub_quadratic_attention.py
1. Determine the number of query chunks.
2. Calculate the final shape of the res tensor.
3. Initialize the tensor with the calculated shape and dtype (usually the same dtype as the input tensors).

The tensor can be initialized as a zero-filled tensor with the correct shape and dtype; the attention scores are then computed for each query chunk and written into the corresponding slice of the tensor.
2023-05-10 22:05:18 -05:00

224 lines
7.6 KiB
Python

# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
# Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2
from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, List
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Return ``torch.narrow(input, dim, start, length)``, clipped to the tensor.

    If ``start + length`` would run past the end of dimension ``dim``, the
    slice is shortened so that it stops at the last element of that dimension.
    """
    dim_size = input.shape[dim]
    fits = dim_size >= start + length
    clipped_length = length if fits else dim_size - start
    return torch.narrow(input, dim, start, clipped_length)
# Partial attention result for one key/value chunk, as produced by _summarize_chunk:
#   exp_values      - exp(scores - max_score) multiplied by the value chunk (not yet normalized)
#   exp_weights_sum - sum over the chunk's keys of exp(scores - max_score)
#   max_score       - per-query maximum score used as the shift above
# The same container is also reused to hold these fields stacked across chunks.
AttnChunk = NamedTuple(
    "AttnChunk",
    [
        ("exp_values", Tensor),
        ("exp_weights_sum", Tensor),
        ("max_score", Tensor),
    ],
)
class SummarizeChunk:
    """Call signature for a function that summarizes attention over one key/value chunk.

    Serves as a type annotation for ``summarize_chunk`` callables, e.g.
    ``partial(_summarize_chunk, scale=...)``. ``__call__`` is a stub (``...``)
    that only declares the expected signature and return type.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...
class ComputeQueryChunkAttn:
    """Call signature for a function that computes attention output for one query chunk.

    Serves as a type annotation for ``compute_query_chunk_attn``. Callables of
    this shape receive a query chunk plus the full ``key``/``value`` tensors
    and return the attention output tensor for that chunk. ``__call__`` is a
    stub (``...``) that only declares the signature.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> Tensor: ...
def _summarize_chunk(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    """Summarize attention of ``query`` over a single key/value chunk.

    Returns an AttnChunk with:
      * the value chunk weighted by ``exp(scores - row_max)`` (unnormalized),
      * the per-row sum of those exponentiated weights,
      * the per-row max score (last dim squeezed away) used as the shift.
    """
    # baddbmm needs an ``input`` argument, but with beta=0 it is ignored, so a
    # 1x1x1 uninitialized placeholder is enough.
    placeholder = torch.empty(1, 1, 1, device=query.device, dtype=query.dtype)
    scores = torch.baddbmm(placeholder, query, key.transpose(1, 2), alpha=scale, beta=0)

    # Shift by the row max before exponentiating so exp() cannot overflow.
    # The shift cancels once values are divided by the weight sums, so it is
    # kept out of the autograd graph.
    row_max = torch.max(scores, -1, keepdim=True).values.detach()
    weights = torch.exp(scores - row_max)

    if query.device.type == 'mps':
        # NOTE(review): MPS path multiplies without casting value; presumably
        # a device-specific dtype workaround - confirm.
        weighted = torch.bmm(weights, value)
    else:
        weighted = torch.bmm(weights, value.to(weights.dtype)).to(value.dtype)

    return AttnChunk(weighted, weights.sum(dim=-1), row_max.squeeze(-1))
def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    """Compute attention for one query chunk, processing keys/values in chunks.

    ``key``/``value`` are split along dim 1 into pieces of at most
    ``kv_chunk_size`` tokens. Each piece is summarized by ``summarize_chunk``
    into an AttnChunk (exp_values, exp_weights_sum, max_score). The summaries
    are then rescaled to a shared maximum and combined into the normalized
    attention output.
    """
    k_tokens = key.shape[1]

    def chunk_scanner(chunk_start: int) -> AttnChunk:
        key_chunk = narrow_trunc(key, 1, chunk_start, kv_chunk_size)
        value_chunk = narrow_trunc(value, 1, chunk_start, kv_chunk_size)
        return summarize_chunk(query, key_chunk, value_chunk)

    # Plain Python ints are all narrow() needs for offsets; no index tensor required.
    chunks: List[AttnChunk] = [
        chunk_scanner(chunk_start) for chunk_start in range(0, k_tokens, kv_chunk_size)
    ]
    # Stack each AttnChunk field across chunks (chunk index becomes dim 0).
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    # Each chunk was shifted by its own max m_i. Multiplying by exp(m_i - M)
    # re-expresses it relative to the global max M:
    # exp(s - m_i) * exp(m_i - M) = exp(s - M).
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> Tensor:
    """Unchunked attention over all keys: ``softmax(scale * query @ key^T) @ value``.

    Used when every key fits in a single key/value chunk. Outside MPS the
    result is cast back to ``value``'s dtype.
    """
    # beta=0 makes baddbmm ignore this placeholder; it is only there because
    # the ``input`` argument is mandatory.
    placeholder = torch.empty(1, 1, 1, device=query.device, dtype=query.dtype)
    # The raw score tensor is a temporary here and is freed as soon as
    # softmax returns, before the value product is computed.
    probs = torch.baddbmm(
        placeholder, query, key.transpose(1, 2), alpha=scale, beta=0
    ).softmax(dim=-1)

    if query.device.type == 'mps':
        return torch.bmm(probs, value)
    return torch.bmm(probs, value.to(probs.dtype)).to(value.dtype)
# Pairs a chunk index with that chunk's AttnChunk summary.
# NOTE(review): not referenced by the code visible in this part of the file;
# confirm whether it is still used elsewhere.
ScannedChunk = NamedTuple(
    "ScannedChunk",
    [
        ("chunk_idx", int),
        ("attn_chunk", AttnChunk),
    ],
)
def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Computes efficient dot-product attention given query, key, and value.

    This is efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.

    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, query_tokens, channels_per_head]`.
        key: keys for calculating attention with shape of
            `[batch * num_heads, key_tokens, channels_per_head]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, key_tokens, value_channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
            (never more than key_tokens).
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None.
            changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes
            don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)

    Returns:
        Output of shape `[batch * num_heads, query_tokens, value_channels_per_head]`, on the
        same device and with the same dtype as the per-chunk attention results.
    """
    _, q_tokens, q_channels_per_head = query.shape
    _, k_tokens, _ = key.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(start: int) -> Tensor:
        # Rows [start, start + query_chunk_size) of the query, truncated at q_tokens.
        return narrow_trunc(query, 1, start, min(query_chunk_size, q_tokens))

    if k_tokens <= kv_chunk_size:
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
        )
    else:
        summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
        if use_checkpoint:
            summarize_chunk = partial(checkpoint, summarize_chunk)
        compute_query_chunk_attn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key=key,
            value=value,
        )

    # Fill a preallocated output slice by slice instead of keeping every chunk
    # result around until they are joined. The buffer is created from the first
    # chunk result so it inherits that result's device, dtype and channel count
    # (value channels, which need not equal query channels), and it has exactly
    # q_tokens rows, so no padding is introduced.
    res: Optional[Tensor] = None
    for start in range(0, q_tokens, query_chunk_size):
        chunk_out = compute_query_chunk_attn(
            query=get_query_chunk(start),
            key=key,
            value=value,
        )
        if res is None:
            res = chunk_out.new_empty((chunk_out.shape[0], q_tokens, chunk_out.shape[2]))
        # The final chunk may be shorter than query_chunk_size (narrow_trunc
        # truncated it), so size the destination slice from the chunk itself.
        res[:, start:start + chunk_out.shape[1], :] = chunk_out
    return res