---
language:
- en
license: cc-by-sa-4.0
datasets:
- euclaise/TinyCoT
- euclaise/reddit-instruct-curated
- sablo/oasst2_curated
---
ReMask: Improving autoregressive language models via regularized masking
Background
Self-Play Finetuning (SPIN) is a recent finetuning method that outperforms standard supervised finetuning (SFT). Instead of only performing next-token prediction, SPIN is an iterative method that contrasts generations from the previous iteration of the model with the ground-truth completions. Unlike reinforcement learning or ranking-loss methods, SPIN does not require preference data, which makes it attractive because preference data can be hard to gather. However, SPIN's popularity has been limited by the need to repeatedly generate sequences from the model: generation is much slower than training, so SPIN is much slower and more expensive than SFT.
With this problem in mind, I set out to create an alternative to SPIN that doesn't require generation.
Why does SPIN work?
SFT trains models to predict the next token given all of the ground-truth previous tokens. In generation, however, the model has no ground truth to condition on; it repeatedly predicts on top of its own earlier predictions. This mismatch creates what is known as "exposure bias": a model can often pick a reasonable next token on average, yet fail to sustain that over a full sequence, because predicting one plausible token is much easier than predicting a coherent sequence.
For instance, consider the following case:
The astronomer pointed his telescope at the distant star, hoping to see
The correct prediction here might be "signs of life". However, the model might predict "and" rather than "signs", since "and" is reasonable in the immediate context: it's grammatically correct, but it implies a strange ending to the sentence. As a result, the model might end up with something like "The astronomer pointed his telescope at the distant star, hoping to see and hear.", which makes little sense.
SPIN's advantage over SFT likely comes from partially mitigating exposure bias. SPIN doesn't only train the model to predict the next token accurately; it repeatedly trains the model to identify and fix discrepancies between its own generations and the ground truth. Since exposure bias is likely what causes many of these discrepancies, the model must implicitly learn to think ahead in order to fix them.
How can we simplify this?
Unfortunately, explicitly predicting ahead for many steps is very expensive, and considering full model generations requires a slow generation process.
A simpler option is to randomly corrupt tokens in the sequence. To predict the token after a corrupted one, the model must keep an internal estimate of what the corrupted token ought to be, which forces it to think ahead.
The most direct ways to do this are to replace random input tokens with a special [mask] token, or with other random tokens.
These approaches were tried in Masked Thought, albeit with somewhat different motivations.
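A minimal, dependency-free sketch of the two corruption schemes, assuming tokens are plain integer ids. `corrupt_tokens`, `mask_id`, and `vocab_size` are hypothetical names for illustration; this is not Masked Thought's actual code.

```python
import random


def corrupt_tokens(token_ids, p, mask_id, vocab_size, mode="mask", rng=None):
    """Return a copy of token_ids where each token is corrupted with probability p.

    mode="mask":   replace the token with the special mask_id.
    mode="random": replace the token with a uniformly random vocabulary id
                   (which can occasionally be the original token again).
    """
    if rng is None:
        rng = random.Random()
    corrupted = []
    for tok in token_ids:
        if rng.random() < p:
            if mode == "mask":
                corrupted.append(mask_id)
            elif mode == "random":
                corrupted.append(rng.randrange(vocab_size))
            else:
                raise ValueError(f"unknown mode: {mode!r}")
        else:
            corrupted.append(tok)
    return corrupted


# Hypothetical token ids; roughly 40% of them will be replaced by mask_id=0.
ids = [101, 7592, 2088, 102]
print(corrupt_tokens(ids, 0.4, mask_id=0, vocab_size=32000, rng=random.Random(0)))
```

Each position is corrupted independently here, which is one simple choice; the masking rate actually used is listed under Training details.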
However, these approaches have a problem: models can detect when a token is [mask] or highly unlikely, so they may learn to think ahead only when corruptions are present.
To avoid this issue, we can run the model twice: once on a masked sequence and once on the full sequence.
Then we penalize deviations between these two runs, which forces the model to behave the same whether or not the [mask] token is present.
This approach was initially introduced with R-TeaFor for abstractive summarization, but can be easily applied to standard generation tasks too.
ReMask and ReMask-CoT:
ReMask applies an approach similar to R-TeaFor to typical chat/instruction tuning.
Consider the following chat interaction:
User: What is 1+1?
Assistant: 1+1=2
User:
The model must predict only the assistant's reply (1+1=2) and the User: marker that ends the turn. So we randomly mask tokens from these predicted parts, and run the model once on the masked sequence and once on the full sequence.
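The input side of this step can be sketched as follows, assuming each token carries a flag marking whether it belongs to the predicted part of the conversation. `make_remask_pair` is a hypothetical helper, and whole words stand in for real subword tokens.

```python
import random


def make_remask_pair(token_ids, is_predicted, p, mask_id, rng=None):
    """Build the (full, masked) inputs for the two forward passes.

    Only positions flagged in is_predicted (the assistant's turn) are eligible
    for masking; the prompt is always left intact. Both passes share the same
    ground-truth labels.
    """
    if len(token_ids) != len(is_predicted):
        raise ValueError("token_ids and is_predicted must have the same length")
    if rng is None:
        rng = random.Random()
    full = list(token_ids)
    masked = [
        mask_id if pred and rng.random() < p else tok
        for tok, pred in zip(token_ids, is_predicted)
    ]
    return full, masked


# Word-level stand-in for the chat example above.
tokens = ["User:", "What", "is", "1+1?", "Assistant:", "1+1=2", "User:"]
predicted = [False, False, False, False, False, True, True]
full, masked = make_remask_pair(tokens, predicted, p=0.4, mask_id="[mask]", rng=random.Random(0))
print(masked)
```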
We then compute a divergence loss D(p_masked, p_full) between the two predictions. For this, I used the average of the forward and backward KL divergences between the predictions.
Finally, we average the standard cross-entropy language modeling losses from the two passes and add the divergence loss, scaled by a weight:
loss = 0.5*(CE(p_masked, labels) + CE(p_full, labels)) + weight*D(p_masked, p_full)
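A dependency-free sketch of this loss, assuming each prediction is given as a list of per-position probability distributions and that every term is averaged over the predicted positions. The helper names are hypothetical, and a real implementation would operate on batched logits in PyTorch; the card also doesn't specify whether either pass is detached from the gradient inside D.

```python
import math

EPS = 1e-12  # guards against log(0)


def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / max(qi, EPS)) for pi, qi in zip(p, q) if pi > 0)


def symmetric_kl(p, q):
    """D(p, q): the average of the forward and backward KL divergences."""
    return 0.5 * (kl(p, q) + kl(q, p))


def cross_entropy(p, label):
    """Negative log-probability assigned to the ground-truth token."""
    return -math.log(max(p[label], EPS))


def remask_loss(p_masked, p_full, labels, weight=0.1):
    """0.5 * (CE(p_masked) + CE(p_full)) + weight * D(p_masked, p_full),
    with each term averaged over the predicted positions."""
    if not labels or not (len(p_masked) == len(p_full) == len(labels)):
        raise ValueError("need one distribution per label in both passes")
    n = len(labels)
    ce_masked = sum(cross_entropy(p, y) for p, y in zip(p_masked, labels)) / n
    ce_full = sum(cross_entropy(p, y) for p, y in zip(p_full, labels)) / n
    div = sum(symmetric_kl(pm, pf) for pm, pf in zip(p_masked, p_full)) / n
    return 0.5 * (ce_masked + ce_full) + weight * div


# Two predicted positions over a toy 3-token vocabulary.
p_full = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
p_masked = [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2]]
print(remask_loss(p_masked, p_full, labels=[0, 1], weight=0.1))
```

When the two passes agree exactly, D is zero and the loss reduces to plain cross-entropy; any disagreement caused by the masking adds a penalty proportional to the weight.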
ReMask-CoT:
For CoT tasks where the reasoning is explicitly separated from the answer, we can add some further improvements.
First, note that CoT rationales are noisy -- there are many correct rationales which might lead to the same correct answer, and rationales are impacted by things like writing style which don't matter for the actual correctness of the reasoning.
Keeping this in mind:
- We also randomly mask a small portion of the labels of the rationale, but not the answer, such that an accurate answer is more important than a rationale that is word-for-word identical to the annotated rationale.
- The exact answer is always important and is always a few tokens. Hence, we do not mask the labels or input tokens for the answer value.
- Rarely, we ignore the rationale labels entirely, such that the model is only pushed to learn what leads to the best answer.
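Putting these three rules together, the inputs and labels for one rationale-plus-answer span might be built as below. This is a sketch under assumptions: the default probabilities are the ones listed under Training details, `remask_cot_targets` is a hypothetical helper, and dropped labels use -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`. Labels are aligned with inputs; the one-position shift for next-token prediction is left to the training code.

```python
import random

IGNORE_INDEX = -100  # label value skipped by PyTorch's CrossEntropyLoss by default


def remask_cot_targets(rationale_ids, answer_ids, mask_id, rng=None,
                       p_input=0.4, p_label=0.1, p_answer_only=0.1):
    """Return (masked_inputs, labels) for the rationale + answer span of one example.

    - Rationale input tokens are masked with probability p_input.
    - Rationale labels are dropped with probability p_label; with probability
      p_answer_only, every rationale label is dropped (answer-only training).
    - Answer tokens are never masked, either as inputs or as labels.
    """
    if rng is None:
        rng = random.Random()
    answer_only = rng.random() < p_answer_only
    inputs, labels = [], []
    for tok in rationale_ids:
        inputs.append(mask_id if rng.random() < p_input else tok)
        drop_label = answer_only or rng.random() < p_label
        labels.append(IGNORE_INDEX if drop_label else tok)
    # The answer is always kept intact on both sides.
    inputs.extend(answer_ids)
    labels.extend(answer_ids)
    return inputs, labels


inputs, labels = remask_cot_targets([11, 12, 13, 14], [42], mask_id=0, rng=random.Random(0))
print(inputs, labels)
```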
Results
I trained StableLM-3B-4e1t three times on https://huggingface.co/datasets/euclaise/TinyCoT, along with 1000 examples from reddit-instruct-curated and 1000 examples from oasst2-curated.
The three runs were: once with ReMask/ReMask-CoT, once without regularization to match Masked Thought (w/ partial label-masking for CoT), and once with plain SFT.
If my hypothesis about exposure bias is correct, ReMask should significantly improve generative benchmarks like GSM8K, but would not necessarily improve logprob-based benchmarks like ARC-c (as implemented by the evaluation harness).
Here are some benchmark results, computed using the LM Evaluation Harness with vllm:
Model | GSM8K (strict, 5-shot) | ARC-c (acc_norm, 25-shot) |
---|---|---|
SFT | 24.34% | 42.92% |
Masked Thought | 24.18% | 43.60% |
ReMask | 27.90% | 43.26% |
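As expected, ReMask improves GSM8K (27.90% vs. 24.34% for SFT, about 3.6 points), but does little for ARC-c, where all three runs fall within 0.7 points of each other.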
Training details
- Framework: PyTorch Lightning
- Optimizer: Lilith
- Training sequence length: 256
- Input masking probability: 40%
- Label masking probability: 10%
- Answer-only (full rationale label masking) probability: 10%
- Batch size: 16, accumulated to 256
- Epochs: 6
- Learning rate: 1e-5
- Learning rate schedule: One Cycle, cosine, no cycle_momentum
- Regularization weight: 0.1
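Collected into one config object, the settings above might look like this. The field names are hypothetical, and the derived `grad_accum_steps` assumes training on a single device. In PyTorch Lightning that value would presumably be passed as `Trainer(accumulate_grad_batches=...)`, and the schedule would presumably correspond to `torch.optim.lr_scheduler.OneCycleLR` with `anneal_strategy="cos"` and `cycle_momentum=False`; the Lilith optimizer is not part of PyTorch and is left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReMaskConfig:
    """Hyperparameters from the training details (field names are hypothetical)."""
    max_seq_len: int = 256
    p_input_mask: float = 0.40    # input masking probability
    p_label_mask: float = 0.10    # rationale label masking probability
    p_answer_only: float = 0.10   # probability of dropping all rationale labels
    micro_batch_size: int = 16
    effective_batch_size: int = 256
    epochs: int = 6
    learning_rate: float = 1e-5
    reg_weight: float = 0.1       # the weight on D(p_masked, p_full)

    @property
    def grad_accum_steps(self) -> int:
        # Assumes a single device: 256 / 16 = 16 micro-batches per optimizer step.
        if self.effective_batch_size % self.micro_batch_size:
            raise ValueError("effective batch size must be a multiple of the micro-batch size")
        return self.effective_batch_size // self.micro_batch_size


cfg = ReMaskConfig()
print(cfg.grad_accum_steps)
```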
Prompt format
The format for reddit-instruct and oasst2 was:
<|user|>
[insert instruction here]
<|assistant|>
[insert response here]
<|user|>
...
The format for TinyCoT was:
<|user|>
[insert instruction here]
<|rationale|>
[insert reasoning here]
<|answer|>
[insert direct answer here]
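A small sketch of both templates, assuming each marker sits on its own line exactly as shown and that nothing else (such as an EOS token or trailing newline) is added. `format_chat` and `format_cot` are hypothetical helpers.

```python
def format_chat(turns):
    """Render (role, text) pairs in the reddit-instruct/oasst2 format.

    role must be "user" or "assistant"; each marker sits on its own line.
    """
    lines = []
    for role, text in turns:
        if role not in ("user", "assistant"):
            raise ValueError(f"unexpected role: {role!r}")
        lines.append(f"<|{role}|>")
        lines.append(text)
    return "\n".join(lines)


def format_cot(instruction, rationale, answer):
    """Render one TinyCoT example as instruction, rationale, and answer segments."""
    return "\n".join([
        "<|user|>", instruction,
        "<|rationale|>", rationale,
        "<|answer|>", answer,
    ])


print(format_chat([("user", "What is 1+1?"), ("assistant", "1+1=2")]))
print(format_cot("What is 1+1?", "Adding one and one gives two.", "2"))
```

Keeping the rationale and answer behind separate markers is what makes it possible to apply the ReMask-CoT masking rules to the rationale while leaving the answer untouched.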