Paper: arXiv
Author: Ken M. Nakanishi
Date of Publication: 31st January 2025

Overview

Softmax, as used in scaled dot-product attention, causes attention scores to flatten as the context length increases. This reduces the model's ability to prioritize key information in the context and to generalise to longer contexts not seen during training. To fix this, the author modifies Softmax and introduces Scalable-Softmax (SSMax), defined as:

$$z_i \mapsto \frac{n^{s z_i}}{\sum_{j=1}^{n} n^{s z_j}}$$

where $n$ is the input vector size (the context length) and $s$ is a learnable scaling parameter, learned separately for each layer and head.
Results show that SSMax performs much better at information retrieval tasks and also generalises better to longer contexts, even without being trained on them.

Problem with Softmax

Softmax transforms an input vector $z = (z_1, \dots, z_n)$ into a vector that can be interpreted as a probability distribution, where all elements are non-negative and sum to one:

$$z_i \mapsto \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}$$

In the attention layers of a transformer, the input vector size $n$ increases as the context length grows. Softmax plays a critical role in computing attention scores over all tokens in the context, determining how much 'attention' is allocated to each token. When $n$ grows, the denominator of Softmax, a sum of $n$ positive terms, increases while the numerator remains independent of $n$. As a result, the distribution becomes increasingly flat (attention fading).
This reduces the model's ability to focus on key tokens in the context and also its ability to generalise to longer contexts.
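To make attention fading concrete, here is a small pure-Python sketch under a hypothetical setup: one key token with logit 2.0 and $n - 1$ distractor tokens with logit 0.0. The function names are illustrative. The Softmax weight on the key token is $e^2/(e^2 + n - 1)$, so it shrinks towards zero as the context grows even though the logits never change.

```python
import math


def softmax(z):
    """Standard Softmax: exp(z_i) / sum_j exp(z_j), computed stably."""
    m = max(z)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in z]
    total = sum(exps)
    return [e / total for e in exps]


def max_weight_softmax(n, key_logit=2.0, other_logit=0.0):
    """Softmax weight on one 'key' token among n tokens (n >= 1).

    Hypothetical setup: the key token has logit key_logit and the
    other n - 1 tokens all share other_logit.
    """
    z = [key_logit] + [other_logit] * (n - 1)
    return softmax(z)[0]


for n in (10, 100, 1_000, 10_000):
    print(n, round(max_weight_softmax(n), 6))
```

The printed weight decays roughly like $1/n$ for large $n$, which is exactly the flattening described above.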

Scalable-Softmax (SSMax)

SSMax is defined as:

$$z_i \mapsto \frac{n^{s z_i}}{\sum_{j=1}^{n} n^{s z_j}} = \frac{e^{(s \log n) z_i}}{\sum_{j=1}^{n} e^{(s \log n) z_j}}$$

As is clear from the definition, SSMax also transforms the input vector into a probability distribution. The key difference is that the exponential base depends on the input vector size $n$. This design mitigates attention fading, so the resulting attention scores remain focused on the key tokens.
The author provides very nice justifications for this design.
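A minimal pure-Python sketch of SSMax for a single input vector given as a list of floats; the function name `ssmax` and the demo setup are illustrative, not the paper's code. It relies on the identity $n^{s z_i} = e^{(s \log n) z_i}$, so SSMax is simply Softmax applied to logits scaled by $s \log n$.

```python
import math


def ssmax(z, s=1.0):
    """Scalable-Softmax: n^(s*z_i) / sum_j n^(s*z_j), with n = len(z).

    Uses n^(s*z) = exp((s*log n) * z), i.e. Softmax on logits scaled by
    s*log(n). For n = 1, log(1) = 0 and the single output is 1.
    """
    n = len(z)
    scale = s * math.log(n)
    scaled = [scale * x for x in z]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


for n in (10, 100, 1_000, 10_000):
    z = [2.0] + [0.0] * (n - 1)  # hypothetical: one key token, gap 2.0
    print(n, round(ssmax(z, s=1.0)[0], 6))
```

With $s = 1$ the key-token weight is $n^2/(n^2 + n - 1)$, which rises towards 1 as $n$ grows: the opposite of the Softmax behaviour.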

Rationale Behind the Design of SSMax

To find out which form of Softmax works best, the author replaced Softmax at all layers and heads with the following function:

$$z_i \mapsto \frac{e^{(s p_n + b) z_i}}{\sum_{j=1}^{n} e^{(s p_n + b) z_j}}$$

where $s$ and $b$ are learnable parameters unique to each layer and head, and $p_n$ is a learnable parameter shared across all layers and heads that depends solely on the input vector size $n$, with one value for each $n$ up to the context length used during training. The author trains a model with this function replacing Softmax in scaled dot-product attention.
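The exploratory variant can be sketched as Softmax with a per-size inverse temperature $s p_n + b$. In this sketch `p_table` is a hypothetical lookup where `p_table[n - 1]` holds $p_n$; the function name and the table values are illustrative only.

```python
import math


def pn_softmax(z, s, b, p_table):
    """Softmax with inverse temperature (s * p_n + b), where n = len(z).

    p_table is a hypothetical lookup with p_table[n - 1] == p_n. In the
    described setup p_n is shared by all layers and heads, while s and b
    belong to a single layer and head.
    """
    n = len(z)
    beta = s * p_table[n - 1] + b  # multiplies every logit
    scaled = [beta * x for x in z]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical table for sizes 1..8 (training context length assumed to be 8).
p_table = [0.5] * 8
print(pn_softmax([0.2, 1.0, -0.5], s=1.2, b=0.1, p_table=p_table))
```

If the learned table happened to equal $\log n$ and $b = 0$, this would reduce exactly to SSMax.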

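Results of training with $p_n$

The learned $p_n$ followed a logarithmic relationship of the form $p_n \approx a_1 \log n + a_2$, where $a_1$ and $a_2$ are constants.
This finding suggested that Softmax in the attention mechanism could benefit from being reformulated as:

$$z_i \mapsto \frac{e^{(s \log n + b) z_i}}{\sum_{j=1}^{n} e^{(s \log n + b) z_j}}$$

where $s$ and $b$ are layer- and head-specific learnable parameters. (Substituting the fit gives $s p_n + b \approx (s a_1) \log n + (s a_2 + b)$, so the constants $a_1$ and $a_2$ can be absorbed into $s$ and $b$.) $b$ is referred to as the bias. Based on further evaluation (see Results), omitting $b$ turns out to work better, and thus the author arrived at SSMax.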
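The reparametrisation step can be checked directly. A minimal sketch with illustrative names: `log_bias_softmax` implements the $s \log n + b$ form (with `b=0` it is SSMax), and `absorb_log_fit` folds hypothetical fitted constants $a_1, a_2$ of $p_n \approx a_1 \log n + a_2$ into $s$ and $b$.

```python
import math


def log_bias_softmax(z, s, b=0.0):
    """Softmax with inverse temperature s*log(n) + b, where n = len(z).

    With b = 0 this is exactly SSMax: n^(s*z_i) / sum_j n^(s*z_j).
    """
    beta = s * math.log(len(z)) + b
    scaled = [beta * x for x in z]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def absorb_log_fit(s, b, a1, a2):
    """Fold the fit p_n ~ a1*log(n) + a2 into new (s, b).

    s*p_n + b ~ (s*a1)*log(n) + (s*a2 + b); a1 and a2 are hypothetical
    fitted constants.
    """
    return s * a1, s * a2 + b


s_new, b_new = absorb_log_fit(s=0.7, b=-0.1, a1=1.2, a2=0.3)
print(log_bias_softmax([0.4, -1.0, 2.2], s_new, b_new))
```

Because the fitted constants disappear into parameters that are learned anyway, nothing is lost by writing the temperature directly in terms of $\log n$.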

Justification for the Design of SSMax

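Let $z = (z_1, \dots, z_n)$ be an input vector of size $n$. Let $z_{\max}$, $z_{2nd}$ and $z_{\min}$ denote its maximum, second maximum and minimum elements, respectively. We study how the outputs behave as $n$ grows while the gaps $z_{\max} - z_{2nd}$ and $z_{\max} - z_{\min}$ stay fixed.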

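When $z$ is processed by Softmax, $z_{\max}$ is transformed as:

$$z_{\max} \mapsto \frac{e^{z_{\max}}}{\sum_{j=1}^{n} e^{z_j}}$$

Every $z_j \ge z_{\min}$, so replacing each of the $n - 1$ non-maximum terms of the denominator by $e^{z_{\min}}$ makes the denominator smaller and gives an upper bound:

$$\frac{e^{z_{\max}}}{\sum_{j=1}^{n} e^{z_j}} \le \frac{e^{z_{\max}}}{e^{z_{\max}} + (n-1)\,e^{z_{\min}}}$$

Multiplying the numerator and denominator of the RHS by $e^{-z_{\max}}$, we get

$$\frac{e^{z_{\max}}}{\sum_{j=1}^{n} e^{z_j}} \le \frac{1}{1 + (n-1)\,e^{-(z_{\max} - z_{\min})}}$$

With $z_{\max} - z_{\min}$ fixed, the term $(n-1)\,e^{-(z_{\max} - z_{\min})}$ grows without bound, so as $n \to \infty$ the maximum element of the output vector produced by Softmax approaches zero.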
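A quick numerical sanity check of this bound, as a sketch: `softmax_max` computes the largest Softmax output and `softmax_upper_bound` evaluates $1/(1 + (n-1)e^{-(z_{\max} - z_{\min})})$. Both names are illustrative, and the random vectors are arbitrary test inputs.

```python
import math
import random


def softmax_max(z):
    """Largest Softmax output: exp(z_max) / sum_j exp(z_j)."""
    m = max(z)
    return 1.0 / sum(math.exp(x - m) for x in z)


def softmax_upper_bound(z):
    """1 / (1 + (n - 1) * exp(-(z_max - z_min)))."""
    n = len(z)
    return 1.0 / (1.0 + (n - 1) * math.exp(-(max(z) - min(z))))


rng = random.Random(0)
for n in (2, 10, 1_000):
    z = [rng.uniform(-3.0, 3.0) for _ in range(n)]
    print(n, softmax_max(z), softmax_upper_bound(z))
```

The bound is tight when every non-maximum element equals $z_{\min}$, for example `[1.0] + [0.0] * (n - 1)`.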





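On the other hand, when $z$ is processed by SSMax, $z_{\max}$ is transformed as

$$z_{\max} \mapsto \frac{n^{s z_{\max}}}{\sum_{j=1}^{n} n^{s z_j}} = \frac{1}{1 + \sum_{j \neq j^{*}} n^{-s(z_{\max} - z_j)}}$$

where $j^{*}$ is the index of $z_{\max}$. Assuming $s > 0$ (and $n \ge 2$), each term $n^{-s(z_{\max} - z_j)}$ shrinks as $z_j$ decreases, so replacing every $z_j$ with $z_{\min}$ gives an upper bound for the RHS:

$$\frac{n^{s z_{\max}}}{\sum_{j=1}^{n} n^{s z_j}} \le \frac{1}{1 + (n-1)\,n^{-s(z_{\max} - z_{\min})}}$$

Similarly, replacing every $z_j$ with $z_{2nd}$ gives a lower bound:

$$\frac{n^{s z_{\max}}}{\sum_{j=1}^{n} n^{s z_j}} \ge \frac{1}{1 + (n-1)\,n^{-s(z_{\max} - z_{2nd})}}$$

Now we have,

$$\frac{1}{1 + (n-1)\,n^{-s(z_{\max} - z_{2nd})}} \le \frac{n^{s z_{\max}}}{\sum_{j=1}^{n} n^{s z_j}} \le \frac{1}{1 + (n-1)\,n^{-s(z_{\max} - z_{\min})}}$$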
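The two-sided bound can be checked numerically in the same way. In this sketch `ssmax_max` computes the largest SSMax output directly and `ssmax_bounds` returns $\frac{1}{1 + (n-1)n^{-s(z_{\max} - z_{2nd})}}$ and $\frac{1}{1 + (n-1)n^{-s(z_{\max} - z_{\min})}}$; the names and random inputs are illustrative, and $s > 0$, $n \ge 2$ are assumed as in the derivation.

```python
import math
import random


def ssmax_max(z, s):
    """Largest SSMax output: n^(s*z_max) / sum_j n^(s*z_j)."""
    n = len(z)
    m = max(z)
    return 1.0 / sum(n ** (-s * (m - x)) for x in z)


def ssmax_bounds(z, s):
    """(lower, upper) bounds on the largest output; assumes s > 0, n >= 2."""
    n = len(z)
    ordered = sorted(z, reverse=True)
    z_max, z_2nd, z_min = ordered[0], ordered[1], ordered[-1]
    lower = 1.0 / (1.0 + (n - 1) * n ** (-s * (z_max - z_2nd)))
    upper = 1.0 / (1.0 + (n - 1) * n ** (-s * (z_max - z_min)))
    return lower, upper


rng = random.Random(0)
z = [rng.uniform(-2.0, 2.0) for _ in range(1_000)]
lower, upper = ssmax_bounds(z, s=1.0)
print(lower, ssmax_max(z, s=1.0), upper)
```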



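For a fixed gap $\Delta$, the term $(n-1)\,n^{-s\Delta}$ behaves like $n^{1 - s\Delta}$ as $n \to \infty$: it tends to $0$ if $\Delta > 1/s$ and grows without bound if $\Delta < 1/s$. Hence the maximum element output by SSMax exhibits the following properties:

If $z_{\max} - z_{2nd} > 1/s$, the lower bound approaches $1$. So the output of SSMax for $z_{\max}$ approaches $1$, meaning attention is focused on the element with the highest value.
If $z_{\max} - z_{\min} < 1/s$, the upper bound approaches $0$. So the output of SSMax for $z_{\max}$ approaches $0$, meaning attention is distributed across all the elements.

Thus, SSMax ensures that attention is focused on elements whose values exceed the others by more than approximately $1/s$, while attention is distributed when all values lie within a range of approximately $1/s$.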
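Both regimes show up in a toy setting. As a hypothetical setup, one token's logit exceeds the other $n - 1$ (all equal) by `gap`; then $z_{2nd} = z_{\min}$, the two bounds coincide, and the SSMax weight on that token is exactly $1/(1 + (n-1)n^{-s \cdot \text{gap}})$. The function name and the gap values are illustrative.

```python
def ssmax_key_weight(n, gap, s):
    """SSMax weight on one token whose logit exceeds the other n-1 by gap.

    Hypothetical setup: the other n - 1 logits are all equal, so the
    weight is exactly 1 / (1 + (n - 1) * n**(-s*gap)).
    """
    return 1.0 / (1.0 + (n - 1) * n ** (-s * gap))


s = 1.0
for n in (10, 1_000, 100_000, 10_000_000):
    focused = ssmax_key_weight(n, gap=1.5, s=s)  # gap > 1/s: tends to 1
    spread = ssmax_key_weight(n, gap=0.5, s=s)   # gap < 1/s: tends to 0
    print(n, round(focused, 4), round(spread, 4))
```

The convergence is slow (polynomial in $n$), which matches the "approximately $1/s$" wording: the threshold is sharp only in the limit.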

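We can very easily convert Softmax attention to SSMax attention: since $n^{s z_i} = e^{(s \log n) z_i}$, it is enough to multiply the attention logits (equivalently, the query) by $s \log n$:

$$\mathrm{Softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V \;\to\; \mathrm{Softmax}\left(\frac{(s \log n)\,QK^\top}{\sqrt{d}}\right)V$$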
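A minimal single-head sketch in pure Python (lists instead of tensors). It assumes a causal mask where the query at 0-based position $i$ attends to $n = i + 1$ tokens; choosing $n$ per query this way is an assumption about how SSMax is applied under causal masking, and `causal_attention` is a hypothetical helper, not the paper's code.

```python
import math


def causal_attention(q, k, v, s=None):
    """Single-head causal scaled dot-product attention on lists of vectors.

    If s is None this is ordinary Softmax attention. Otherwise the query at
    position i is multiplied by s * log(n), with n = i + 1 the number of
    tokens it can attend to (assumed choice of n), which turns the Softmax
    over its logits into SSMax.
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        n = i + 1
        scale = 1.0 if s is None else s * math.log(n)
        qi = [scale * x for x in qi]  # the only change needed for SSMax
        logits = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d)
                  for j in range(n)]
        m = max(logits)  # subtract the max for numerical stability
        w = [math.exp(x - m) for x in logits]
        total = sum(w)
        w = [x / total for x in w]
        out.append([sum(w[j] * v[j][c] for j in range(n))
                    for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.5, 1.0], [0.0, 2.0]]
k = [[1.0, 1.0], [0.0, 1.0], [1.0, -1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(causal_attention(q, k, v))         # Softmax attention
print(causal_attention(q, k, v, s=0.5))  # SSMax attention
```

Because the change is a per-position rescaling of the query, it adds essentially no compute on top of standard attention.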

Results

Results of SSMax

The loss during training. The author reports that SSMax achieves lower loss than Softmax during regular training too.


Results of SSMax being used in extended context lengths

The grey dotted line marks the context length at which the models were trained. The models' context length was extended simply by increasing RoPE's base $\theta$, with no additional training.
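Assuming the parameter being increased is RoPE's base $\theta$, the sketch below computes the standard per-dimension-pair rotation frequencies $\theta^{-2i/d}$ and shows that a larger base lowers them (longer wavelengths), which is why raising the base is a common way to stretch RoPE to longer contexts. The helper name, head size and base values are illustrative.

```python
def rope_inverse_frequencies(head_dim, base):
    """RoPE rotation frequencies base**(-2i/head_dim), i = 0..head_dim/2 - 1."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


short = rope_inverse_frequencies(64, 10_000.0)
longer = rope_inverse_frequencies(64, 500_000.0)
# Every frequency after the first is lower with the larger base.
print(short[-1], longer[-1])
```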


Results of SSMax in Needle in a Haystack
Results when evaluated on Needle-in-a-haystack.

Overall, results look very good with barely any extra learnable parameters. Big if true.

Thoughts

I really liked the paper. The intuition behind SSMax makes sense. The results look pretty good too, with almost zero additional compute cost, and it is very easy to implement. Big if true.
Combining this with differential attention may improve Needle-in-a-Haystack results even further. It would be nice to try that out.