Paper: arXiv
Authors: Lucas Prieto, Melih Barsbey, Pedro A.M. Mediano, Tolga Birdal
If you are unfamiliar with Grokking, please check out this YouTube video for a brief overview, or read any paper on Grokking.
Introduction
Prior research has suggested that some form of regularization is needed for Grokking. Grokking was previously understood as follows: at first, the model learns a complicated solution that essentially memorises the entire training set; as training continues, weight decay pushes the model towards a simpler solution. Under the implicit assumption that simpler solutions generalise better, the model eventually gets better on the test set, and that delayed improvement is the grokking phenomenon.
The authors of this work argue that we do not need regularization to achieve Grokking. Without regularization, grokking is prevented by absorption errors in the softmax, which lead to Softmax Collapse (a term coined by the authors). They further argue that Softmax Collapse is caused by Naive Loss Minimization, where the gradient simply scales up the weights to reduce the loss without changing the decision boundary. The authors provide workarounds for both and indeed achieve grokking without regularization. Let’s look at what all these terms mean.
Softmax Collapse
Absorption errors in floating-point arithmetic lead to Softmax Collapse.
Absorption error
Absorption error in floating-point arithmetic occurs when adding a number $b$ to a much larger number $a$ results in $a$ itself, i.e. $a + b = a$. This happens because, during floating-point addition, the exponents are aligned. If the exponent difference $e_a - e_b$ is greater than or equal to the precision $p$, the significand of $b$ is shifted right by at least $p$ positions to match the exponent of $a$. This shift causes the significand of $b$ to lose all its significant bits within the available precision. Consequently, $b$ effectively becomes zero in the computation, and the result of $a + b$ is indistinguishable from $a$ due to the limited precision.
Intuitively, absorption errors can occur during floating-point addition when the operands have significantly different magnitudes. For float32 the base $\beta$ is 2 and $p = 24$ bits, meaning that adding any number smaller than $2^{-24}$ to 1 will leave 1 unchanged.
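Absorption is easy to observe directly. The sketch below assumes numpy and uses its float32 type (p = 24); the specific values are arbitrary choices for illustration, and `np.frexp` is used only to read off the binary exponents of the two operands.

```python
import numpy as np

one = np.float32(1.0)

# 2**-24 is exactly half of float32's spacing at 1.0 (2**-23); round-to-nearest-even
# sends the tie back to 1.0, so the small operand is absorbed.
tiny = np.float32(2.0 ** -24)
print(one + tiny == one)   # True: absorbed

# 2**-22 is twice the spacing at 1.0, so it survives the addition.
small = np.float32(2.0 ** -22)
print(one + small == one)  # False: not absorbed

# float64 has p = 53, so the same 2**-24 is kept.
print(np.float64(1.0) + np.float64(2.0 ** -24) == 1.0)  # False

# Exponent difference between the operands (frexp returns a mantissa in [0.5, 1)).
_, e_a = np.frexp(one)
_, e_b = np.frexp(tiny)
print(e_a - e_b)           # 24, equal to float32's precision p
```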
Absorption error in Softmax
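Softmax Cross-Entropy (SCE) loss is defined as:
$$\mathcal{L}_{\text{SCE}}(f(x), y) = -\log \frac{e^{f(x)_y}}{\sum_{c=1}^{n} e^{f(x)_c}}$$
where $f$ is the neural network, $x$ is the data point, $n$ is the number of classes, and $f(x)_y$ is the logit corresponding to the true label $y$.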
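Essentially, late in training, once the model achieves perfect training accuracy, the true logit $f(x)_y$ is already much larger than the logits of the other classes. So $e^{f(x)_y} \gg e^{f(x)_c}$ for every $c \neq y$, causing an absorption error in the denominator, which in floating point evaluates to exactly the true-class term:
$$\sum_{c=1}^{n} e^{f(x)_c} = e^{f(x)_y}$$
So the loss would be:
$$\mathcal{L}_{\text{SCE}}(f(x), y) = -\log \frac{e^{f(x)_y}}{e^{f(x)_y}} = -\log 1 = 0$$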
Yep. The loss would be zero, so there’s no signal for the neural network to learn from. The model does achieve 100% training accuracy but stops learning, and so never generalises to the test set.
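The zero loss can be reproduced on a single example. This sketch assumes numpy and a numerically stable cross-entropy using the usual log-sum-exp shift (how the authors’ code computes the loss may differ); the logits, with the true class 30 above nine other classes, are hypothetical. In float32 the sum of exponentials is absorbed into the true-class term, so both the loss and the gradient on the true-class logit come out as exactly zero, while float64 still registers a tiny positive loss.

```python
import numpy as np

def cross_entropy(logits, y):
    """Softmax cross-entropy and its gradient for one example, in the dtype of `logits`."""
    z = logits - logits.max()      # log-sum-exp shift: the largest logit becomes 0
    denom = np.exp(z).sum()        # 1 + sum of exp(z_c) over the other classes
    loss = np.log(denom) - z[y]
    grad = np.exp(z) / denom       # softmax probabilities
    grad[y] -= 1                   # d loss / d logits = softmax - onehot
    return loss, grad

# Hypothetical logits: true class 0 sits 30 above the 9 other classes.
logits = np.array([30.0] + [0.0] * 9)

loss32, grad32 = cross_entropy(logits.astype(np.float32), 0)
loss64, grad64 = cross_entropy(logits.astype(np.float64), 0)

print(loss32, grad32[0])  # 0.0 0.0: no learning signal on the true class
print(loss64)             # small but nonzero (about 8.4e-13)
```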
There’s a lot in these graphs. (Left) When the training set is small, it’s extremely easy to overfit and achieve 100% training accuracy, so Softmax Collapse occurs and there is literally no generalisation: 0% on the test set. (Middle/Right) As the training set grows, it’s harder to reach perfect training accuracy, Softmax Collapse does not set in right away, and we see some generalisation. The dotted lines mark where generalisation stops (where SC starts to occur, allegedly). As floating-point precision increases, Softmax Collapse happens much later in training, supporting the authors’ claims.
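The effect of precision can be mimicked on a toy example: grow the margin between the true logit and the others until the stable cross-entropy evaluates to exactly zero, and compare dtypes. The helper `collapse_margin`, the 10-class setup and the 0.25 margin grid are hypothetical choices for this illustration, not the authors’ experiment; numpy is assumed.

```python
import numpy as np

def ce_loss(logits, y):
    z = logits - logits.max()                  # log-sum-exp shift
    return np.log(np.exp(z).sum()) - z[y]

def collapse_margin(dtype, n_classes=10, step=0.25, max_margin=100.0):
    """Smallest margin on a `step` grid at which the loss is exactly zero in `dtype`."""
    margin = 0.0
    while margin <= max_margin:
        logits = np.zeros(n_classes, dtype=dtype)
        logits[0] = margin                     # true class 0 leads by `margin`
        if ce_loss(logits, 0) == 0:
            return margin
        margin += step
    return None                                # no collapse within the search range

for dt in (np.float16, np.float32, np.float64):
    print(dt.__name__, collapse_margin(dt))
```

With more significand bits, a larger margin is needed before the sum is absorbed, so float16 collapses first and float64 last, the same ordering described for the figure.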
An alternative to Softmax
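To mitigate Softmax Collapse, the authors replace the exponential inside the softmax with a function $s$, giving what they call StableMax, and train with the cross-entropy computed on StableMax instead of softmax:
$$s(x) = \begin{cases} x + 1 & \text{if } x \ge 0 \\ \dfrac{1}{1 - x} & \text{if } x < 0 \end{cases} \qquad \text{StableMax}(x_i) = \frac{s(x_i)}{\sum_{j} s(x_j)}$$
For negative inputs, $s$ approaches zero more slowly than the exponential (like $1/(1-x)$ rather than $e^{x}$), and for $x \ge 0$ it grows linearly rather than exponentially, so the terms being summed stay relatively small and close in magnitude. This reduces the risk of absorption errors.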
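A minimal numpy sketch of StableMax, assuming the piecewise definition s(x) = x + 1 for x ≥ 0 and s(x) = 1/(1 − x) for x < 0, normalized as s(x_i) / Σ_j s(x_j). The function names are my own. With a margin of 30 in float32, the softmax cross-entropy collapses to exactly zero, while the StableMax cross-entropy is still positive, because s(30) = 31 and s(0) = 1 are close enough in magnitude to be added without absorption.

```python
import numpy as np

def s(x):
    # x + 1 for x >= 0 (linear growth), 1 / (1 - x) for x < 0 (slow decay towards 0).
    # np.minimum keeps the unused branch finite, so np.where never divides by zero.
    return np.where(x >= 0, x + 1, 1 / (1 - np.minimum(x, 0)))

def stablemax(x):
    sx = s(x)
    return sx / sx.sum()

def stablemax_ce(logits, y):
    return -np.log(stablemax(logits)[y])

def softmax_ce(logits, y):
    z = logits - logits.max()  # log-sum-exp shift
    return np.log(np.exp(z).sum()) - z[y]

logits = np.array([30.0, 0.0], dtype=np.float32)  # true class 0 leads by 30
print(softmax_ce(logits, 0))    # 0.0: softmax collapse
print(stablemax_ce(logits, 0))  # about 0.0317, i.e. -log(31/32)
```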

Naive Loss Minimization
But why does Softmax Collapse happen in the first place? Once train accuracy is 100%, there may still be better solutions that generalise and that the model hasn’t learnt. Instead of grokking them, the model simply reduces the loss by making its logits bigger, essentially like “reward hacking”. The bigger the logits, the smaller the loss, and the bigger the logits, the higher the risk of absorption errors. The key point: the decision boundary doesn’t change; the logits just get bigger and the loss gets smaller.
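The “bigger logits, smaller loss” effect can be checked on one hypothetical, correctly classified example: multiplying its logits by a constant c > 1 leaves the prediction unchanged while the cross-entropy keeps dropping. numpy and the log-sum-exp form of the loss are assumed.

```python
import numpy as np

def ce_loss(logits, y):
    z = logits - logits.max()  # log-sum-exp shift
    return np.log(np.exp(z).sum()) - z[y]

# Hypothetical logits for one correctly classified example (true class 0).
logits = np.array([2.0, 1.0, -0.5])

for c in (1.0, 2.0, 5.0, 10.0):
    scaled = c * logits
    # The prediction (argmax) never changes, but the loss keeps shrinking.
    print(c, scaled.argmax(), ce_loss(scaled, 0))
```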
The authors call this Naive Loss Minimisation. More formally:
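A direction $v$ in the parameter space $\theta$ is a direction of naive loss minimization if, for every $\alpha > 0$, there is a $c > 1$ such that:
$$f(x; \theta + \alpha v) = c\, f(x; \theta) \;\; \forall x \qquad \text{and} \qquad \mathcal{L}(\theta + \alpha v) < \mathcal{L}(\theta)$$
The loss is reduced but the decision boundary is unchanged (the logits are just scaled), so the model doesn’t generalise any better to the test set. The model just “reward hacked” the loss.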
For ReLU MLPs and Transformers without bias terms, this direction of Naive Loss Minimisation is simply the direction of their weights. That is, gradient descent can just scale up the current weights and reduce the loss without learning anything useful. With weight decay, however, gradient descent can no longer cheaply reduce the loss by scaling up the weights, because the penalty grows with the weight norm. In the process of learning a simpler solution instead, the model undergoes grokking.
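For a bias-free ReLU network, the claim that the weights themselves are a direction of naive loss minimization follows from homogeneity: scaling every weight matrix of a depth-2 network by (1 + α) multiplies all logits by (1 + α)², because ReLU(a·x) = a·ReLU(x) for a > 0. The sketch below checks this numerically for a hypothetical 2-layer MLP with random weights, assuming numpy; labelling the data with the model’s own predictions (so training accuracy is 100%) and α = 0.5 are assumptions made for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical bias-free 2-layer ReLU MLP: f(x) = W2 @ relu(W1 @ x)
W1 = rng.normal(size=(16, 4))
W2 = rng.normal(size=(3, 16))
X = rng.normal(size=(32, 4))

def forward(W1, W2, X):
    return np.maximum(X @ W1.T, 0) @ W2.T  # logits, shape (32, 3)

def mean_ce(logits, y):
    z = logits - logits.max(axis=1, keepdims=True)  # log-sum-exp shift per row
    lse = np.log(np.exp(z).sum(axis=1))
    return np.mean(lse - z[np.arange(len(y)), y])

logits = forward(W1, W2, X)
y = logits.argmax(axis=1)  # label each point with the model's own prediction

alpha = 0.5
scaled = forward((1 + alpha) * W1, (1 + alpha) * W2, X)  # move along v = theta

# Every logit is multiplied by c = (1 + alpha)**2 ...
print(np.allclose(scaled, (1 + alpha) ** 2 * logits))
# ... so the predictions, and hence the decision boundary, are unchanged ...
print(np.array_equal(scaled.argmax(axis=1), y))
# ... yet the training loss goes down.
print(mean_ce(scaled, y) < mean_ce(logits, y))
```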
To check this hypothesis, the authors measure the cosine similarity between the gradient and the weight vectors. The results are very convincing. I’m convinced.
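The cosine-similarity diagnostic can be sketched on a toy model: flatten the gradient and the weights into vectors and compare their directions. The model here is a hypothetical bias-free linear softmax classifier on random data, not the authors’ setup, and numpy is assumed. When every training point is correctly classified, ⟨∇L, W⟩ = mean_i(Σ_c p_ic z_ic − z_iy) is negative because the true logit z_iy is the largest, so the cosine similarity is negative: the negative gradient has a positive component along W, meaning part of each descent step just scales W up.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical bias-free linear classifier f(x) = W @ x on random data,
# labelled with its own predictions so that training accuracy is 100%.
W = rng.normal(size=(5, 8))
X = rng.normal(size=(64, 8))
y = (X @ W.T).argmax(axis=1)

def ce_grad(W, X, y):
    """Gradient of the mean softmax cross-entropy with respect to W."""
    logits = X @ W.T
    z = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    p[np.arange(len(y)), y] -= 1  # softmax - onehot
    return p.T @ X / len(y)       # shape (5, 8), same as W

def cosine_similarity(a, b):
    a, b = a.ravel(), b.ravel()   # flatten all parameters into one vector
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

g = ce_grad(W, X, y)
print(cosine_similarity(g, W))    # negative
```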

Mitigating Naive Loss Minimization
So how do we reduce this Naive Loss Minimization? Pretty simple. We know that the weights are being scaled up, so when updating the weights during training, let’s only use the part of the (negative) gradient that is orthogonal to the current weights. This way, gradient descent can’t just scale up the weights to reduce the loss, the logits are not pushed into the regime where absorption errors occur, and the model should generalise.
To test this, the authors do exactly that, and voila!
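The projection can be written in a few lines. This is a minimal numpy sketch assuming the update uses g⊥ = g − (g·w / w·w) w over the flattened weights; whether the authors apply it per layer or to all parameters at once, and whether they rescale the result, is not shown here. The gradient below is hypothetical, built to point mostly along −W so that a plain step would mostly scale the weights up.

```python
import numpy as np

def orthogonal_gradient(grad, weights):
    """Remove the component of `grad` that is parallel to `weights`."""
    g, w = grad.ravel(), weights.ravel()
    g_perp = g - (g @ w) / (w @ w) * w
    return g_perp.reshape(grad.shape)

rng = np.random.default_rng(2)
W = rng.normal(size=(5, 8))
# Hypothetical gradient: mostly "scale W up" (i.e. -W) plus a little noise.
grad = -W + 0.1 * rng.normal(size=W.shape)

g_perp = orthogonal_gradient(grad, W)
print(np.dot(g_perp.ravel(), W.ravel()))  # ~0: no component along W

lr = 0.1
W_plain = W - lr * grad    # plain step: mostly rescales W by about 1 + lr
W_perp = W - lr * g_perp   # orthogonal step: W's norm barely changes
print(np.linalg.norm(W), np.linalg.norm(W_plain), np.linalg.norm(W_perp))
```

Because g⊥ is orthogonal to W, the squared norm after a step is ‖W‖² + lr²‖g⊥‖², so the weights can only grow at second order in the learning rate, whereas the plain step here grows them by roughly a factor of 1 + lr.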

Thoughts
One of the rare grokking papers that is very convincing. Everything just falls into place. A very very cool paper. I suggest reading it.