Paper: SFT Memorizes, RL Generalizes: A Comparative Study of Foundation Model Post-training (arXiv)
Authors: Tianzhe Chu, Yuexiang Zhai, Jihan Yang, Shengbang Tong, Saining Xie, Dale Schuurmans, Quoc V. Le, Sergey Levine, Yi Ma
Overview
In this paper, the authors perform an extensive comparison between supervised finetuning (SFT) and reinforcement learning (RL) to determine how well each lets a model generalize past its training distribution. They select two tasks, each with two variants: (i) a text-only version and (ii) a version that requires vision capabilities. They choose Llama-3.2-Vision-11B as the base model and finetune it separately on each task variant; this finetuned model becomes the initial (init) model. Performing further SFT on the init model naturally overfits the training distribution, and the resulting model performs poorly on out-of-distribution samples (SFT memorizes).
To test RL, they take the same init model and run RL with simple PPO, using an outcome-based reward computed by a verifier. RL performs very well on the out-of-distribution test samples too (RL generalizes). They verify that this result holds both for the text-only tasks and for the tasks requiring vision capabilities.
An interesting ablation: they take the base Llama-3.2-Vision-11B (without any finetuning) and apply RL directly on it, and it performs very poorly. So, based on these empirical experiments, they claim that SFT is needed before RL.
(Figure from the paper: out-of-distribution performance across the two tasks and two modalities; GP-L and V-IRL-L are the text-only variants, GP-VL and V-IRL-VL are the variants requiring vision capabilities.)
Task Setup
The two tasks that they test are:
1) GeneralPoints (GP)
This is similar to the LeetCode 3Sum problem. The model is given a list of 4 poker cards (as text in the text-only variant, or as images of the cards in the vision variant). Each card represents a natural number. The model must combine the four card values with arithmetic operations, using each card exactly once, and return an expression that evaluates to 24.
For example, if the four cards have values 5, 4, 10, and 7, a correct answer would be ((7-5)*10)+4 = 24.
They further generate two variations of this task: one variation is used to train the model (for both SFT and RL), and the other is used only at test time, acting as out-of-distribution samples (a small verifier sketch follows the variant lists below).
The two variations of the text task are:
(i) (In distribution, i.e. used during training) The system prompt instructs the model that the cards 'J', 'Q', 'K' all count as 10.
(ii) (Out of distribution, i.e. used during testing) The system prompt instructs the model that the cards 'J', 'Q', 'K' count as 11, 12, 13 respectively.
The two variations of the vision task are:
(i) (In distribution, i.e. used during training) The poker cards in the task are of black suits.
(ii) (Out of distribution, i.e. used during testing) The poker cards in the task are of red suits.
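To make the outcome-based reward concrete, here is a minimal sketch of how a GeneralPoints verifier could work, with the face-card rule passed in so the same check covers both the in-distribution and OOD variants. The names (`verify`, `card_values`, `IN_DIST_RULE`, `OOD_RULE`) are illustrative, not the paper's actual code, and the paper's exact reward shaping may differ.

```python
import re

# Face-card rules from the GeneralPoints setup: in-distribution counts all face
# cards as 10; the OOD variant changes their values (illustrative constants).
IN_DIST_RULE = {"J": 10, "Q": 10, "K": 10, "A": 1}
OOD_RULE     = {"J": 11, "Q": 12, "K": 13, "A": 1}

def card_values(cards, rule):
    """Convert card symbols like ['5', '4', 'K', '7'] to sorted numeric values."""
    return sorted(rule[c] if c in rule else int(c) for c in cards)

def verify(expression: str, cards, rule) -> float:
    """Binary outcome reward: 1.0 if the expression uses each card value exactly
    once and evaluates to 24, else 0.0."""
    try:
        # The numbers appearing in the expression must match the card values.
        used = sorted(int(tok) for tok in re.findall(r"\d+", expression))
        if used != card_values(cards, rule):
            return 0.0
        # Only allow digits, parentheses, and basic arithmetic operators.
        if not re.fullmatch(r"[\d+\-*/() ]+", expression):
            return 0.0
        return 1.0 if abs(eval(expression) - 24) < 1e-6 else 0.0
    except (SyntaxError, ZeroDivisionError):
        return 0.0

# In distribution, 'K' counts as 10, so this answer is accepted:
print(verify("((7-5)*10)+4", ["5", "4", "K", "7"], IN_DIST_RULE))  # 1.0
# Under the OOD rule 'K' is 13, so the same answer is rejected:
print(verify("((7-5)*10)+4", ["5", "4", "K", "7"], OOD_RULE))      # 0.0
```

The only thing that changes between the two regimes is the rule dict; the reward logic stays the same, which is what makes this a clean test of rule-level generalization.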
2) V-IRL
In this task, the goal is to navigate to a target location by following a set of instructions that contain spatial information. Essentially, the model acts as an agent and needs to find its way from a start position to a destination.
The model is provided with instructions in its prompt like:
First, turn slightly right towards the northeast and walk a short distance until you reach the next intersection,
where you’ll see The Dutch on your right. Next, make a sharp left turn to head northwest. Continue for a while
until you reach the next intersection, where Lola Taverna will be on your right. Finally, turn slightly right to face
northeast and walk a short distance until you reach your destination, Shuka, which will be on your right.
Based on the instructions, the model is expected to navigate through the environment. After every action the model takes, it receives an observation describing the current state.
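That interaction loop can be sketched roughly as follows. The environment and model interfaces here (`env.reset`, `env.step`, the `model(...)` call signature) are hypothetical placeholders for illustration, not the actual V-IRL API.

```python
from dataclasses import dataclass

@dataclass
class Step:
    action: str       # e.g. a turn or move command emitted by the model
    observation: str  # description of the state reached after the action

def navigate(env, instructions: str, model, max_steps: int = 20):
    """Roll out one navigation episode: at each step the model sees the route
    instructions, the history of (action, observation) pairs, and the current
    observation, then emits the next action."""
    history: list[Step] = []
    obs = env.reset()                                  # initial observation
    for _ in range(max_steps):
        action = model(instructions, history, obs)     # model picks an action
        obs, done = env.step(action)                   # environment responds
        history.append(Step(action, obs))
        if done:                                       # destination reached (or episode ended)
            break
    return history
```

In the vision variant the observation would include street-view imagery; in the text variant it is purely textual.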
This same task is represented in text form too.
The two variations of the text task are:
(i) (In distribution) All the directions in the prompt are with respect to cardinal directions (north, south, east, west)
(ii) (Out of Distribution) All the directions in the prompt are specified with respect to relative orientation (left, right, slightly left, slightly right)
The two variations of the vision task are:
(i) (In distribution) Since the model can see, training is done on actual street-view images from one city.
(ii) (Out of distribution) Testing is done in a completely different city, so the visual appearance changes.
Reinforcement Learning Framework
They perform simple PPO with an outcome-based reward computed by a verifier, using a sequential revision formulation. At turn t=0, the context contains just the prompt; the model produces an answer and receives feedback (and a reward) from the verifier. At turn t=1, the context contains the initial prompt, the model's previous answer, and the verifier's feedback, and the model tries again. At turn t=2, the context is everything from t=1 concatenated with the new answer and the new feedback. This process continues.
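Here is a minimal sketch of that rollout loop, assuming placeholder `policy` and `verifier` callables; the actual PPO update on top of these rollouts is omitted.

```python
def collect_rollout(prompt: str, policy, verifier, max_turns: int = 5):
    """Sequential revision: at turn t the context is the prompt plus all previous
    (answer, feedback) pairs. Returns (context, answer, reward) tuples that a PPO
    trainer could consume; the episode stops once the verifier accepts an answer."""
    context = prompt
    trajectory = []
    for t in range(max_turns):
        answer = policy(context)             # model's attempt at turn t
        reward, feedback = verifier(answer)  # outcome-based reward + textual feedback
        trajectory.append((context, answer, reward))
        if reward > 0:                       # verified correct -> stop revising
            break
        # Grow the context with this turn's attempt and the verifier's feedback,
        # so the next attempt conditions on the full revision history.
        context = f"{context}\n[Attempt {t}]: {answer}\n[Verifier]: {feedback}"
    return trajectory
```

The key detail is simply that the context grows with every (attempt, feedback) pair, so later attempts can exploit the verifier's signal.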
Experiments
They finetune a fresh copy of the base Llama-3.2-Vision-11B on each of the four task variants. The finetuned model is called the init model. This init model is then used as the starting point for both further SFT and RL.
Naturally, performing further SFT on the already-finetuned init model overfits the training distribution and generalizes poorly. Applying RL on top of such an overfit model doesn't recover the lost performance either.
Performing RL on the base model before any finetuning doesn't yield generalization either.
Simple PPO-based RL from the init model (i.e. after a moderate amount of SFT) does improve generalization across all tasks.
There are several other small ablation studies; check out the paper for those.
Thoughts
Nothing hugely surprising here: further SFT on an already-finetuned model was always going to overfit. Still, it's pretty cool to see RL generalize to OOD samples. The paper doesn't really try to explain why that happens, though; it would be nice to see a paper focused solely on RL that explains why RL generalizes.