Paper: arXiv
Authors: Xiao-Wen Yang, Xuan-Yi Zhu, Wen-Da Wei, Ding-Chu Zhang, Jie-Jing Shao, Zhi Zhou, Lan-Zhe Guo, Yu-Feng Li
Date of Publication: 6th February 2025

Overview

In this paper, the authors train an LLM (together with a matching inference procedure) that can backtrack in its reasoning trace and continue exploring other branches. There isn’t much novelty in this paper. They choose a single task whose states are easily verifiable, so the state to backtrack to is easy to determine. They create a synthetic dataset in which they plant errors in the reasoning trace and immediately follow each error with a special backtracking token. They fine-tune an LLM on this dataset, and during inference, whenever the LLM outputs the backtracking token, they simply roll back one step and continue generation.
This works because in the task the authors chose, ‘steps’ or ‘states’ are very easily verifiable and separable, so rolling back to a previous step is trivial (unlike when truly reasoning in natural language).

The Dataset

The LLM is given 4 numbers and a target number. It must construct the target number from the 4 numbers using the basic arithmetic operations (addition, subtraction, multiplication, and division).
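A brute-force solver makes the task concrete. This is a minimal sketch that assumes all 4 numbers must be used and every intermediate result must be a non-negative integer (a common convention for this puzzle, not something stated in the post); the name `solve` and the step format `a op b = c` are hypothetical.

```python
from itertools import permutations
from typing import List, Optional


def solve(numbers: List[int], target: int) -> Optional[List[str]]:
    """Depth-first search for steps 'a op b = c' that combine all numbers into target."""
    if len(numbers) == 1:
        return [] if numbers[0] == target else None
    # Try every ordered pair (a, b), so both a - b and b - a (a / b and b / a) are covered.
    for i, j in permutations(range(len(numbers)), 2):
        a, b = numbers[i], numbers[j]
        rest = [x for k, x in enumerate(numbers) if k not in (i, j)]
        options = [("+", a + b), ("*", a * b)]
        if a >= b:  # assumption: intermediate results stay non-negative
            options.append(("-", a - b))
        if b != 0 and a % b == 0:  # assumption: only exact integer division
            options.append(("/", a // b))
        for op, value in options:
            tail = solve(rest + [value], target)
            if tail is not None:
                return [f"{a} {op} {b} = {value}"] + tail
    return None
```

With 4 numbers, any solution the search returns has exactly 3 steps, and the last step produces the target.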


Each step is clearly distinguishable and verifiable.
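Because every step has the form `a op b = c`, checking a trace is mechanical. The sketch below (hypothetical names and step format, same non-negative integer assumption) parses each step, confirms its operands are still available, rejects operators outside the four allowed ones, and checks the arithmetic. It catches computation errors and rule violations; a step that is valid but leads nowhere only shows up when the target is not reached.

```python
import re
from collections import Counter
from typing import List, Tuple

# Hypothetical step format: "a op b = c" with non-negative integers, e.g. "13 - 9 = 4".
STEP_RE = re.compile(r"^\s*(\d+)\s*(\S)\s*(\d+)\s*=\s*(\d+)\s*$")
OPS = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
    "/": lambda a, b: a // b if b != 0 and a % b == 0 else None,  # exact division only
}


def verify_step(step: str, pool: Counter) -> Tuple[bool, str, Counter]:
    """Check one step against the multiset of numbers still available."""
    m = STEP_RE.match(step)
    if m is None:
        return False, "unparseable step", pool
    a, op, b, c = int(m[1]), m[2], int(m[3]), int(m[4])
    if op not in OPS:
        return False, f"operator {op!r} is not allowed", pool
    used = Counter([a, b])
    if used - pool:  # some operand is not (or no longer) available
        return False, "uses a number that is not available", pool
    if OPS[op](a, b) != c:
        return False, "computation error", pool
    # Consume both operands and make the result available for later steps.
    return True, "ok", pool - used + Counter([c])


def verify_solution(numbers: List[int], target: int, steps: List[str]) -> Tuple[bool, str]:
    pool = Counter(numbers)
    for i, step in enumerate(steps, start=1):
        ok, reason, pool = verify_step(step, pool)
        if not ok:
            return False, f"step {i}: {reason}"
    if sum(pool.values()) == 1 and pool[target] == 1:
        return True, "ok"
    return False, "target not reached using all numbers"
```

For example, `verify_solution([4, 9, 10, 13], 24, ["13 - 9 = 4", "10 - 4 = 6", "6 * 4 = 24"])` returns `(True, "ok")`, while replacing the first step with `"13 - 9 = 5"` returns `(False, "step 1: computation error")`.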

The authors construct a synthetic dataset for this task and introduce 3 types of errors into a portion of it:

  1. Replacing a correct step with an erroneous step.
  2. Making a computation error.
  3. Violating the rules by using operations other than the 4 allowed ones.

Each introduced error is immediately followed by a special backtracking token.
The final dataset includes both correct samples and backtracking (incorrect) samples.
The authors fine-tune Llama-3.2-1B and Llama-3.2-3B on this task.
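A sketch of how such backtracking samples could be built from correct solutions. It covers only error types 2 and 3 (type 1 would need a way to generate a plausible but wrong step). The token string `<backtrack>`, the function names, and the `error_fraction` parameter are hypothetical, not taken from the paper.

```python
import random
from typing import List

BACKTRACK = "<backtrack>"  # hypothetical token string; the post does not name it


def inject_error(steps: List[str], rng: random.Random) -> List[str]:
    """Keep a correct prefix, replace one step with an error, then append the token."""
    k = rng.randrange(len(steps))
    a, op, b, _, c = steps[k].split()  # assumes steps look like "a op b = c"
    if rng.random() < 0.5:
        # Type 2: computation error (the stated result is off by 1 to 9).
        bad = f"{a} {op} {b} = {int(c) + rng.randint(1, 9)}"
    else:
        # Type 3: rule violation ('^', exponentiation, is not an allowed operation).
        bad = f"{a} ^ {b} = {int(a) ** int(b)}"
    return steps[:k] + [bad, BACKTRACK]


def build_dataset(solutions: List[List[str]], error_fraction: float, seed: int = 0) -> List[List[str]]:
    """Mix correct samples with backtracking samples."""
    rng = random.Random(seed)
    return [
        inject_error(s, rng) if rng.random() < error_fraction else list(s)
        for s in solutions
    ]
```

Truncating the trace right after the token mirrors the description above, where the token immediately follows the error; what the paper places after the token (if anything) is not covered here.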

Training

The authors train with a loss that is the sum of two terms: a standard next-token prediction loss and a backtracking loss.

The first term is the regular next-token prediction loss, where the target for a backtracking sample is its incomplete/incorrect solution.
The second term focuses specifically on the backtracking token: the model is conditioned on the query followed by the intermediate steps and the erroneous step, and is trained to predict the backtracking token next. This term is normalized by the number of erroneous samples in the batch.
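A minimal sketch of how the two terms combine, written over precomputed log-probabilities rather than a real model. The argument names are hypothetical, and normalizing the first term by batch size (summing token log-probs per sequence) is an assumption; the paper's exact normalization may differ.

```python
import math
from typing import List, Sequence


def sequence_nll(token_logprobs: Sequence[float]) -> float:
    """Negative log-likelihood of one target sequence from its per-token log-probs."""
    return -sum(token_logprobs)


def self_backtracking_loss(
    sft_token_logprobs: List[Sequence[float]],
    backtrack_logprobs: List[float],
) -> float:
    """Next-token prediction term plus backtracking term.

    sft_token_logprobs: per-token log-probs of each target in the batch
        (full solutions for correct samples, partial solutions for erroneous ones).
    backtrack_logprobs: for each erroneous sample, log p(backtrack token | query,
        intermediate steps, erroneous step).
    """
    l_sft = sum(sequence_nll(t) for t in sft_token_logprobs) / len(sft_token_logprobs)
    # Averaged over the number of erroneous samples in the batch (0 if there are none).
    l_back = -sum(backtrack_logprobs) / len(backtrack_logprobs) if backtrack_logprobs else 0.0
    return l_sft + l_back


loss = self_backtracking_loss([[math.log(0.5), math.log(0.5)], [math.log(0.25)]], [math.log(0.5)])
```

Here both SFT targets have NLL 2 ln 2 and the single backtracking term is ln 2, so `loss` equals 3 ln 2 (about 2.079).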

Inference with Backtracking

There are 3 phases in the inference:

  1. Expansion: Given the query, several predictions are sampled from the model and sorted into 2 sets, a candidate set and a discarded set. Predictions without the backtracking token go directly into the candidate set, while those containing it are processed further in the next phase.
  2. Backtracking: The algorithm selects the predictions containing the backtracking token, rolls each one back by 1 step, and expands it again.
  3. Selection: Finally, the authors score all candidate reasoning paths using negative perplexity as the metric and return the highest-scoring result.
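The three phases can be sketched as a small search loop. This is an illustration under stated assumptions, not the authors' implementation: the model is replaced by a hypothetical `sample(prefix, n)` callable, each step carries a single log-prob and perplexity is computed per step (a simplification of per-token perplexity), the token string `<backtrack>` is hypothetical, and every rolled-back prefix is re-expanded with `n` fresh samples.

```python
import math
from typing import Callable, List, Optional, Tuple

BACKTRACK = "<backtrack>"  # hypothetical token string
Step = Tuple[str, float]  # (step text, log-prob of that step), a simplified representation
Sampler = Callable[[List[str], int], List[List[Step]]]


def neg_perplexity(path: List[Step]) -> float:
    """Negative perplexity, treating each step as one unit (a simplification)."""
    if not path:
        return -math.inf
    mean_logp = sum(lp for _, lp in path) / len(path)
    return -math.exp(-mean_logp)


def self_backtracking_search(sample: Sampler, n: int, b: int) -> Optional[List[Step]]:
    candidates: List[List[Step]] = []
    frontier: List[List[Step]] = [[]]  # prefixes to expand; start from the bare query
    for round_idx in range(b + 1):
        next_frontier: List[List[Step]] = []
        for prefix in frontier:
            # Expansion: sample n continuations of this prefix.
            for continuation in sample([text for text, _ in prefix], n):
                path = prefix + continuation
                texts = [text for text, _ in path]
                if BACKTRACK not in texts:
                    candidates.append(path)
                elif round_idx < b:
                    # Backtracking: drop the token and the erroneous step just before it.
                    cut = texts.index(BACKTRACK)
                    next_frontier.append(path[:max(cut - 1, 0)])
                # Otherwise the backtracking budget is exhausted and the path is dropped.
        frontier = next_frontier
    # Selection: return the candidate with the highest negative perplexity.
    return max(candidates, key=neg_perplexity, default=None)


def toy_sampler(prefix: List[str], n: int) -> List[List[Step]]:
    """Hypothetical stand-in for the fine-tuned model (ignores n)."""
    if prefix == []:
        return [
            [("13 - 9 = 4", -0.1), ("10 * 4 = 41", -0.3), (BACKTRACK, -0.1)],
            [("13 + 9 = 22", -2.0), ("22 + 10 = 32", -2.0), ("32 - 4 = 28", -2.0)],
        ]
    if prefix == ["13 - 9 = 4"]:
        return [[("10 - 4 = 6", -0.2), ("6 * 4 = 24", -0.1)]]
    return []


best = self_backtracking_search(toy_sampler, n=2, b=1)
print([text for text, _ in best])  # ['13 - 9 = 4', '10 - 4 = 6', '6 * 4 = 24']
```

With `b=0` the first sample is simply dropped and the search returns the low-confidence path ending in `32 - 4 = 28`; allowing one backtrack lets it recover from the computation error and reach 24.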

Results


They compare the method with other fine-tuning and decoding approaches. N is the number of predictions used during exploration, and b is the maximum number of times the model may backtrack during exploration. Even without any backtracking (b = 0), it outperforms the other, regular methods.

There are a few other small experiments in the paper that I am not covering, as I didn’t find them significant.

Thoughts

I am a huge fan of search-based reasoning/MCTS methods. This approach seems simple, and it is likely to work only on toy datasets. If only every step were as easily verifiable and separable as it is here.
Still, it is a step towards MCTS that actually works. I can’t wait for a model that actually performs tree search at inference time while reasoning in natural language.