Paper: Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention (arXiv)
Authors: Jingyang Yuan, Huazuo Gao, Damai Dai, Junyu Luo, Liang Zhao, Zhengyan Zhang, Zhenda Xie, Y. X. Wei, Lean Wang, Zhiping Xiao, Yuqing Wang, Chong Ruan, Ming Zhang, Wenfeng Liang, Wangding Zeng
Date of Publication: 16th February 2025
Overview
It's a paper from the whale bros (DeepSeek).
Attention is quadratic in sequence length, so it's very expensive to compute. The KV-cache grows linearly with sequence length, so it's very expensive to store. We need to fix both issues.
In this paper, the authors introduce NSA (Native Sparse Attention). It uses three attention branches and merges all three to compute the final attention output. Instead of using all the keys and values to calculate attention, they derive a much smaller set of remapped keys and values and calculate attention using those. There is no drop in performance when compared to regular full attention.
Preliminaries
Attention
Let $q_t$ be the query for the token at position $t$ in the sequence. Let $k_{:t}$, $v_{:t}$ be the keys and values of that token and all the tokens preceding it.
The attention output is defined as:
$$o_t = \mathrm{Attn}(q_t, k_{:t}, v_{:t}) = \sum_{i=1}^{t} \frac{\alpha_{t,i}\, v_i}{\sum_{j=1}^{t} \alpha_{t,j}}, \qquad \alpha_{t,i} = e^{\frac{q_t^{\top} k_i}{\sqrt{d_k}}}$$
Here $q_t, k_i \in \mathbb{R}^{d_k}$ and $v_i \in \mathbb{R}^{d_v}$, so $\alpha_{t,i}$ is a scalar. Intuitively, $\alpha_{t,i}$ (also known as the attention score) tells us "how relevant is token $i$ to token $t$". Based on the attention scores, we take a weighted average of all the values to get the attention output $o_t$. The denominator is just a normalizing constant.
As we can clearly see from this formulation, we need keys and values of all the previous tokens to calculate attention output during inference.
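To make the formula concrete, here is a minimal sketch in plain PyTorch (my own toy helper and shapes, not the paper's code) of attention for a single query over the cached keys and values:

```python
import torch

def single_query_attention(q_t, K, V):
    """Attention output for one query q_t over the cached keys/values.

    q_t: (d_k,)   query of the current token
    K:   (t, d_k) keys of the current token and all previous tokens
    V:   (t, d_v) values of the current token and all previous tokens
    """
    d_k = q_t.shape[-1]
    scores = K @ q_t / d_k ** 0.5           # alpha_{t,i} before normalization
    alphas = torch.softmax(scores, dim=-1)  # softmax = exp(.) / normalizing constant
    return alphas @ V                       # weighted average of the values -> o_t

# toy usage: decoding token t = 10 with head dimension 64
q_t = torch.randn(64)
K, V = torch.randn(10, 64), torch.randn(10, 64)
o_t = single_query_attention(q_t, K, V)     # shape (64,)
```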
During training, we have access to all tokens at once, so the attention calculation is compute-bound. We don't decode tokens in an auto-regressive fashion; everything happens in parallel, bounded only by compute.
During inference, we need to store all the previous keys and values to decode the next token. We must move the entire cached keys and values between global GPU memory (HBM) and on-chip shared memory (SRAM), since compute only operates on data that has been loaded into shared memory. Moving this cache around is very slow, which makes attention memory-bound during inference.
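A quick back-of-the-envelope calculation shows why this hurts. All the model dimensions below are assumptions I picked for illustration, not numbers from the paper:

```python
# Hypothetical decoder config (illustrative assumptions, not from the paper)
n_layers   = 32       # transformer layers
n_kv_heads = 8        # key/value heads (GQA)
d_head     = 128      # head dimension
seq_len    = 64_000   # context length
bytes_each = 2        # bf16

# keys + values, for every layer, head, and cached token
kv_cache_bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_each
print(f"KV cache: {kv_cache_bytes / 1e9:.1f} GB per sequence")  # ~8.4 GB

# roughly this much data has to be streamed from HBM into SRAM for every
# single decoded token, which is why decoding becomes memory-bound
```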
Native Sparse Attention
Overall Framework
Instead of computing attention over the entire KV cache $k_{:t}, v_{:t}$, we can compute it over alternate, densely packed $\tilde{K}_t, \tilde{V}_t$ with far fewer entries.
$\tilde{K}_t = f_K(q_t, k_{:t}, v_{:t})$ and $\tilde{V}_t = f_V(q_t, k_{:t}, v_{:t})$ are dynamically constructed based on the query, key and value vectors, and attention is computed exactly as described above: $o_t^{*} = \mathrm{Attn}(q_t, \tilde{K}_t, \tilde{V}_t)$.
Instead of having a single strategy to reduce the number of keys/values, we can have multiple strategies, obtain multiple $\tilde{K}_t^{c}, \tilde{V}_t^{c}$, and combine them via:
$$o_t^{*} = \sum_{c \in \mathcal{C}} g_t^{c} \cdot \mathrm{Attn}\!\left(q_t, \tilde{K}_t^{c}, \tilde{V}_t^{c}\right)$$
In this paper, the authors explore three mapping strategies $\mathcal{C} = \{\mathrm{cmp}, \mathrm{slc}, \mathrm{win}\}$, representing compression, selection, and sliding window for keys and values. $g_t^{c} \in [0, 1]$ is the gate score for the corresponding strategy $c$, derived from the input features via an MLP and a sigmoid.
Let $N_t$ denote the total number of remapped keys/values:
$$N_t = \sum_{c \in \mathcal{C}} \mathrm{size}\!\left[\tilde{K}_t^{c}\right]$$
The authors ensure $N_t \ll t$, i.e. a high sparsity ratio is maintained.
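As a sketch of the combination step, reusing the toy `single_query_attention` from above (the gate MLP is omitted here and the gate scores are simply passed in):

```python
import torch

def gated_sparse_attention(q_t, branches, gates):
    """Combine the attention outputs of several remapping strategies.

    branches: dict name -> (K_tilde, V_tilde), the remapped keys/values
    gates:    dict name -> scalar in [0, 1] (in NSA: sigmoid(MLP(features)))
    """
    o_t = 0.0
    for name, (K_tilde, V_tilde) in branches.items():
        o_t = o_t + gates[name] * single_query_attention(q_t, K_tilde, V_tilde)
    return o_t

# toy usage with two made-up strategies
q_t = torch.randn(64)
branches = {"cmp": (torch.randn(4, 64), torch.randn(4, 64)),
            "win": (torch.randn(8, 64), torch.randn(8, 64))}
gates = {"cmp": torch.tensor(0.7), "win": torch.tensor(0.3)}
o_star = gated_sparse_attention(q_t, branches, gates)
```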
Token Compression
Let $l$ be the block length. Let $d$ be the sliding stride between two adjacent blocks.
$\varphi$ is a learnable MLP with intra-block positional encoding. It takes in $l$ consecutive keys (or values) and compresses them into a single key.
Here is an animation that I made which should clear it up:
In the above animation, each block of $l$ consecutive keys is compressed into a single key, and the window slides forward by $d$ tokens. As long as $d \le l$, we don't miss out on any information. It is very similar to a convolution.
The compressed keys are
$$\tilde{K}_t^{\mathrm{cmp}} = f_K^{\mathrm{cmp}}(k_{:t}) = \left\{ \varphi\!\left(k_{id+1:\,id+l}\right) \,\middle|\, 0 \le i \le \left\lfloor \tfrac{t-l}{d} \right\rfloor \right\}$$
We have massively reduced the number of keys: from $t$ down to roughly $t/d$.
We do the exact same thing for values too.
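Here is a minimal sketch of the compression step. The real $\varphi$ is an MLP with intra-block positional encoding; the toy version below just flattens each block and applies a linear layer, and the block length and stride are values I chose for illustration:

```python
import torch
import torch.nn as nn

class BlockCompressor(nn.Module):
    """Toy stand-in for phi: compresses each block of l keys into one key."""

    def __init__(self, l: int, d_k: int):
        super().__init__()
        self.l = l
        self.proj = nn.Linear(l * d_k, d_k)   # real phi: MLP + intra-block pos. enc.

    def forward(self, K: torch.Tensor, d: int) -> torch.Tensor:
        # K: (t, d_k). Slide a window of length l with stride d over the keys.
        blocks = K.unfold(0, self.l, d)               # (num_blocks, d_k, l)
        blocks = blocks.transpose(1, 2).reshape(blocks.shape[0], -1)
        return self.proj(blocks)                      # (num_blocks, d_k)

# toy usage: 1024 raw keys -> 63 compressed keys with l = 32, d = 16
K = torch.randn(1024, 64)
K_cmp = BlockCompressor(l=32, d_k=64)(K, d=16)        # shape (63, 64)
```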
Token Selection
Using only the compressed keys and values, we may lose some fine-grained information. So we also explicitly select a few important blocks of tokens whose keys and values we preserve as-is, without any compression.
Now how do we identify these important tokens?
Importance Score Computation
We can get an importance score for each compressed key block essentially for free, by reusing the attention scores that the compression branch already computes:
$$p_t^{\mathrm{cmp}} = \mathrm{Softmax}\!\left(q_t^{\top} \tilde{K}_t^{\mathrm{cmp}}\right)$$
When the selection blocks use the same blocking as the compression blocks, these scores serve directly as the block importance scores $p_t^{\mathrm{slc}}$.
If Grouped-Query Attention or Multi-Query Attention is being used, the selected blocks must be shared across all query heads in a group, so the scores are summed over the heads:
$$p_t^{\mathrm{slc}'} = \sum_{h=1}^{H} p_t^{\mathrm{slc},(h)}$$
where $(h)$ in the superscript denotes the head index, and $H$ is the number of query heads in each group.
After obtaining the importance scores, we select the top-$n$ most important blocks and simply use all the original, uncompressed keys and values corresponding to those blocks. Let $\tilde{K}_t^{\mathrm{slc}}, \tilde{V}_t^{\mathrm{slc}}$ be the selected keys and values. The selected keys and values participate in the attention computation with $q_t$ as defined above.
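A toy sketch of the selection step for a single query head. For simplicity I assume the selection blocks coincide with non-overlapping compression blocks of length $l$ (i.e. $d = l$), and I leave out the GQA aggregation over heads:

```python
import torch

def select_blocks(q_t, K_cmp, K, V, l: int, n: int):
    """Pick the top-n blocks by compressed attention score and return their
    original, uncompressed keys and values.

    Assumes block i of K_cmp summarizes tokens [i*l, (i+1)*l) of K and V.
    """
    d_k = q_t.shape[-1]
    p = torch.softmax(K_cmp @ q_t / d_k ** 0.5, dim=-1)     # importance per block
    top = torch.topk(p, k=n).indices                        # indices of the n best blocks
    token_idx = torch.cat([torch.arange(i * l, (i + 1) * l) for i in top.tolist()])
    return K[token_idx], V[token_idx]                       # uncompressed selection

# toy usage: 256 tokens, blocks of 32, keep the 2 most important blocks
q_t, K, V = torch.randn(64), torch.randn(256, 64), torch.randn(256, 64)
K_cmp = K.reshape(8, 32, 64).mean(dim=1)                    # stand-in compressed keys
K_slc, V_slc = select_blocks(q_t, K_cmp, K, V, l=32, n=2)   # each has shape (64, 64)
```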
Sliding Window
In the attention mechanism, local patterns tend to dominate the learning process: they are easier for the model to pick up, and they have a larger impact on the model's output. They can potentially prevent the model from learning to effectively compress and select tokens.
To address this issue, the authors introduce a local sliding window branch, where the keys and values of the most recent $w$ tokens are preserved as-is: $\tilde{K}_t^{\mathrm{win}} = k_{t-w:t}$, $\tilde{V}_t^{\mathrm{win}} = v_{t-w:t}$.
Pretty straightforward.
Note that once we have obtained the sparse representation of keys and values (using the 3 remapping strategies), each decoding step only needs to load the compressed tokens, the selected blocks, and the local window instead of the entire cache, which is where the big savings in memory traffic come from. (The original keys and values are still kept around so that future queries can select blocks from them.)
After combining all three categories of keys and values $(\tilde{K}_t^{\mathrm{cmp}}, \tilde{V}_t^{\mathrm{cmp}})$, $(\tilde{K}_t^{\mathrm{slc}}, \tilde{V}_t^{\mathrm{slc}})$, $(\tilde{K}_t^{\mathrm{win}}, \tilde{V}_t^{\mathrm{win}})$, we compute the attention as:
$$o_t^{*} = \sum_{c \in \{\mathrm{cmp},\, \mathrm{slc},\, \mathrm{win}\}} g_t^{c} \cdot \mathrm{Attn}\!\left(q_t, \tilde{K}_t^{c}, \tilde{V}_t^{c}\right)$$
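Putting the three branches together in one toy decode step (reusing the toy helpers from the earlier sketches; here the stride $d$ equals the block length $l$ so the selection blocks line up with the compression blocks, whereas the paper also allows $d < l$, and the gate MLP is a plain linear layer on $q_t$):

```python
import torch
import torch.nn as nn

def nsa_decode_step(q_t, K, V, compress_k, compress_v, gate_mlp, l=32, n=2, w=64):
    """One decode step of the three-branch scheme, in simplified form."""
    K_cmp, V_cmp = compress_k(K, d=l), compress_v(V, d=l)   # compression branch
    K_slc, V_slc = select_blocks(q_t, K_cmp, K, V, l, n)    # selection branch
    K_win, V_win = K[-w:], V[-w:]                           # sliding-window branch

    g = torch.sigmoid(gate_mlp(q_t))                        # (3,) gate scores g^c
    branches = [(K_cmp, V_cmp), (K_slc, V_slc), (K_win, V_win)]
    return sum(g[c] * single_query_attention(q_t, Kb, Vb)
               for c, (Kb, Vb) in enumerate(branches))

# toy usage
K, V, q_t = torch.randn(256, 64), torch.randn(256, 64), torch.randn(64)
compress_k, compress_v = BlockCompressor(32, 64), BlockCompressor(32, 64)
o_star = nsa_decode_step(q_t, K, V, compress_k, compress_v, nn.Linear(64, 3))
```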
Kernel Design
An algorithm is only as efficient as its implementation, so the authors write an efficient Triton kernel implementing NSA.
The main novelty of their kernel is how it handles grouped-query attention (GQA), where multiple query heads share the same key and value heads. The sparse keys and values (that we have just discussed above) are computed as usual.
So in their kernel, they first load all the query heads of a GQA group into shared memory together.
Then, when computing attention, they load contiguous blocks of sparse keys and values into shared memory based on the selected block indices (which map the keys and values to the corresponding queries), compute attention against them, and free the shared memory used by those keys and values before fetching the next block. Finally, the attention output is moved back to global memory (HBM).
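Here is a runnable, deliberately non-Triton reference of that loop structure in plain PyTorch, just to show the data-movement pattern for one GQA group; the real kernel keeps a FlashAttention-style running softmax in SRAM instead of concatenating, and the index/load/store bookkeeping lives in Triton:

```python
import torch

def gqa_sparse_attention_reference(Q_group, block_starts, K, V, block_size):
    """Reference loop mirroring the kernel's structure (no real SRAM here).

    Q_group:      (H, d_k) all query heads of one GQA group; the kernel loads
                  them together because they share the same selected KV blocks.
    block_starts: starting token index of each selected sparse KV block.
    K, V:         (t, d_k) keys/values sitting in global memory (HBM).
    """
    d_k = Q_group.shape[-1]
    scores, values = [], []
    for s in block_starts:                      # one SRAM-sized block at a time
        Kb = K[s:s + block_size]                # "load" a contiguous KV block
        Vb = V[s:s + block_size]
        scores.append(Q_group @ Kb.T / d_k ** 0.5)   # (H, block_size)
        values.append(Vb)
        # in the real kernel Kb/Vb are freed here and an online softmax
        # accumulator is updated instead of keeping everything around
    alphas = torch.softmax(torch.cat(scores, dim=-1), dim=-1)
    return alphas @ torch.cat(values, dim=0)    # (H, d_v), written back to HBM

# toy usage: 4 query heads sharing 2 selected blocks of 32 tokens each
Q = torch.randn(4, 64)
K, V = torch.randn(256, 64), torch.randn(256, 64)
out = gqa_sparse_attention_reference(Q, [0, 128], K, V, block_size=32)  # (4, 64)
```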
Results
NSA achieves 100% Needle-in-a-Haystack retrieval accuracy up to 64k context length.