Kushal Thaman, Stanford University
Max Sobol Mark, Stanford University
Michael Byun, Stanford University
How are AI chatbots like ChatGPT built? You may have heard that large language models (LLMs), including ChatGPT, are like “autocomplete on steroids”—they generate text by repeatedly predicting the next word. They become good at this prediction task by learning from vast reams of internet text.
This “next-word prediction” style of training forms the backbone of how LLMs are made, but it’s not the whole story. An AI model that has only been trained to naively predict the next word will sometimes generate text that is incoherent or unhelpful. For example, let’s try prompting such a model with a short phrase:
The model (in this case, GPT-3) seems to have decided that the most likely completion for our phrase “live laugh love” was a long list of Etsy-esque keywords! That’s a reasonable guess for the most likely text completion based on internet text, but it probably isn’t what we wanted from the model. If instead we give the same prompt to a chat AI model, we get a much more coherent answer:
This is clearly a much better response. So what’s the difference between these two models? Why do they answer the same prompt so differently?
Part of the answer lies in a technique called reinforcement learning from human feedback, or RLHF. RLHF teaches the “superpowered autocomplete” models to give better responses. Here's how it works:
Initial Training: As before, the AI starts by learning from a massive dataset of internet text. This is where it gets good at predicting the next word in a sentence.
Collecting Feedback: Next, the AI is given a variety of prompts and tries to respond to them. Human reviewers then look at these responses and provide feedback, often by comparing two responses to the same prompt and indicating which one they prefer. They don't just correct mistakes; they also signal what kinds of answers are more helpful or appropriate.
Learning from Feedback: Now we use this feedback to build a reward model, which is a second AI that can give feedback to the original AI much faster than a human can. The original AI model is then fine-tuned with reinforcement learning to earn higher scores from the reward model, almost like it’s learning from experience: if the reward model says, "This is a better way to answer this question," the AI adapts and remembers that for next time. Over time, it gets better at understanding not just what we say, but what we mean and want to know.
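We don't spell out above exactly how the reward model learns from human feedback. A common recipe, assumed here, is to show reviewers two responses to the same prompt and train the reward model so that the preferred ("chosen") response scores higher than the other ("rejected") one, using a pairwise loss of the form −log σ(r_chosen − r_rejected). The sketch below computes that loss for plain scalar scores; in a real system the scores come from a neural network, and the function names and example numbers are illustrative only.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_reward_loss(reward_chosen: float, reward_rejected: float) -> float:
    """Pairwise (Bradley-Terry style) loss: -log(sigmoid(r_chosen - r_rejected)).

    Small when the reward model already scores the human-preferred response
    higher; large when it prefers the rejected response.
    """
    margin = reward_chosen - reward_rejected
    # -log(sigmoid(m)) == log(1 + exp(-m)) == softplus(-m)
    return softplus(-margin)


def batch_loss(pairs: list[tuple[float, float]]) -> float:
    """Average loss over (reward_chosen, reward_rejected) pairs."""
    return sum(pairwise_reward_loss(c, r) for c, r in pairs) / len(pairs)


# Hypothetical scores for three comparisons made by human reviewers.
print(batch_loss([(2.0, 0.5), (1.0, 1.0), (-0.5, 1.5)]))
```

Minimizing this loss pushes chosen scores above rejected ones; once trained, the reward model can stand in for human reviewers during fine-tuning.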
The result is an AI that's not just repeating what it's seen on the internet, but actually engaging with you in a way that's helpful, coherent, and contextually appropriate. RLHF is a large part of why chat models like ChatGPT can give such humanlike answers: they're not just predicting words, they're learning how to communicate with us effectively.
Now let’s introduce another important concept from machine learning: overfitting. This isn’t going to sound related at first, so bear with me!
Imagine you're trying to teach a dog to recognize different kinds of fruit. You show it lots of pictures of apples and oranges, and it gets pretty good at telling them apart. But if you only show it red apples, it might get the wrong idea and think all apples are red. This is essentially what overfitting is in the world of machine learning.
When an AI model like a language model is being trained, it learns from examples (just like the dog). Overfitting happens when the model learns too much from the specific details of the training data and not enough about the general rules. It's like memorizing the answers to a test instead of understanding the subject. The model becomes great at dealing with the data it was trained on but struggles when faced with new, unseen situations.
Consider the plot above: as we train our AI (going from left to right), our error on the training data (the blue line) continues to decrease. That’s generally a good thing—it means that our model is learning the training data better. At first, the error on the unseen test data (red line) also decreases—which is good, because it means that our model is successfully generalizing to new examples, and isn’t overfitting to the training data. But after the dotted line ⚠️, the test error starts increasing. Our model, while improving on the training data, is now doing worse on the test data—suggesting that it’s overfitting to the training data.
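One standard response to the pattern in this plot is early stopping: keep the checkpoint where error on held-out data was lowest, and stop training once that error has failed to improve for a while. Here is a minimal sketch, assuming error is measured at regular checkpoints; the `patience` parameter and the example curve are made up for illustration.

```python
def early_stopping_checkpoint(val_errors: list[float], patience: int = 3) -> int:
    """Return the index of the checkpoint with the lowest held-out error,
    scanning in training order and stopping once the error has failed to
    improve for `patience` consecutive checkpoints."""
    if not val_errors:
        raise ValueError("need at least one checkpoint")
    best_step, best_error = 0, val_errors[0]
    checkpoints_without_improvement = 0
    for step in range(1, len(val_errors)):
        if val_errors[step] < best_error:
            best_step, best_error = step, val_errors[step]
            checkpoints_without_improvement = 0
        else:
            checkpoints_without_improvement += 1
            if checkpoints_without_improvement >= patience:
                break  # error has stopped improving: likely overfitting
    return best_step


# Made-up curve: test error falls, then rises again (like the red line).
errors = [1.0, 0.8, 0.6, 0.5, 0.55, 0.6, 0.7, 0.8]
print(early_stopping_checkpoint(errors))  # → 3
```

A larger `patience` tolerates noisy curves, at the cost of training (and possibly overfitting) for longer before stopping.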
In the context of AI chatbots, overfitting can lead to some quirky behaviors. For instance, the model might become really good at continuing certain types of sentences or discussions that it saw a lot during training, but it gets confused or gives nonsensical responses when faced with something new or different. It's like having a conversation with someone who's really knowledgeable about a few topics but gets lost as soon as you change the subject.
This is why overfitting is a big deal in AI development. We want our AI not just to repeat what it's seen before, but to understand and adapt to a wide range of topics and questions, just like a well-rounded, knowledgeable human would. To achieve this, developers use various techniques to reduce overfitting and keep the AI flexible and versatile in its responses. One of the most reliable is simply to add more training data.
So how is this related to RLHF? In RLHF, overly optimizing the AI against a specific reward model is like a student who studies only to impress one teacher's style of questioning and then struggles in exams set by other teachers. Similarly, the AI becomes great at earning high scores from the reward model it was trained against but may falter in broader, real-world scenarios. This is overfitting in the context of RLHF: the AI becomes over-tuned to the quirks of the reward model and its training data and loses its ability to generalize. This can lead to responses that score well according to the reward model but miss the mark in natural conversations or in understanding diverse perspectives. It's a tricky balance to maintain: we want our AI to be responsive and accurate, but not at the cost of losing its adaptability and broader understanding.
So how can we mitigate overfitting in RLHF? There are a number of strategies, like early stopping, ensembling, and data augmentation. We’ll come back to that later.
That brings us to the main topic of this post: direct preference optimization (DPO). DPO is a newer approach to RLHF. Unlike traditional RLHF, which involves a two-stage process of first training a reward model and then fine-tuning the AI against it, DPO simplifies this into a single step. Imagine teaching someone to cook: instead of first writing a cookbook and then having them learn by trying to replicate the recipes, you directly provide feedback on the dishes they make. This direct approach is what DPO does: it fine-tunes the AI on human preferences without the intermediate reward model, which is simpler and more efficient.
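To make the "single step" concrete, the standard DPO objective turns each human comparison into a classification-style loss on the model being trained. It asks how much more likely the trained model makes the preferred response than a frozen reference copy of the starting model does, and compares that with the same quantity for the rejected response. A minimal sketch, assuming the per-response log-probabilities (sums of token log-probabilities) have already been computed; `beta`, which sets how strongly the model is kept close to the reference, is 0.1 here purely as an illustrative value, and the example numbers are hypothetical.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """DPO loss for one preference pair.

    Each argument is log p(response | prompt) under either the model being
    trained ("policy") or the frozen reference model. No separate reward
    model is trained: the log-probability ratios play that role.
    """
    chosen_log_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_log_ratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_log_ratio - rejected_log_ratio)
    # -log(sigmoid(logits)), written in a numerically stable way
    return softplus(-logits)


# Hypothetical log-probabilities for one prompt with a chosen and a rejected answer.
before = dpo_loss(-12.0, -12.0, -12.0, -12.0)  # policy identical to reference
after = dpo_loss(-10.0, -14.0, -12.0, -12.0)   # policy now favors the chosen answer
print(before, after)
```

Minimizing this loss with ordinary gradient descent raises the chosen response's probability relative to the rejected one, which is why DPO needs neither a separate reward model nor a reinforcement learning loop.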
The advantages of DPO are significant: compared with traditional RLHF, it's both more stable and more computationally efficient. However, there’s ongoing debate about whether it actually performs better in practice; some researchers argue that having an explicit reward model is important. In any case, DPO is an exciting new algorithm that is getting a lot of attention, so it’s worth probing its limitations.
A natural question about this new technique is its susceptibility to overfitting. Can DPO lead to overfitting? If so, how severe is it? How can it be mitigated? These are the key questions that our paper investigates.
We do this by measuring an AI-generated “gold-standard” reward during training, as a proxy for real human preferences. Intuitively, the gold-standard reward stands in for the complex, real-world preferences that we want our AI to generalize to. We expect that as we train our AI, our gold-standard reward will go up. If overoptimization (the RLHF version of overfitting) occurs, then even as DPO continues to train, our gold-standard reward will plateau or even decrease.
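Because the gold-standard reward is measured at discrete checkpoints and can be noisy, one reasonable way to decide whether it has plateaued is to smooth it first and then ask whether recent checkpoints still beat earlier ones. This is a minimal illustration, assuming one gold reward value per checkpoint; the window size, lookback, tolerance, and example numbers are hypothetical choices, not the analysis used in our paper.

```python
def moving_average(values: list[float], window: int) -> list[float]:
    """Trailing moving average; output has len(values) - window + 1 entries."""
    if not 1 <= window <= len(values):
        raise ValueError("window must be between 1 and len(values)")
    return [
        sum(values[i - window + 1 : i + 1]) / window
        for i in range(window - 1, len(values))
    ]


def gold_reward_stalled(
    gold_rewards: list[float], window: int = 3, lookback: int = 2, tolerance: float = 0.0
) -> bool:
    """True if the smoothed gold reward over the last `lookback` checkpoints
    has not beaten the best earlier smoothed value by more than `tolerance`."""
    smoothed = moving_average(gold_rewards, window)
    if len(smoothed) <= lookback:
        return False  # not enough history to judge yet
    recent_best = max(smoothed[-lookback:])
    earlier_best = max(smoothed[:-lookback])
    return recent_best <= earlier_best + tolerance


still_improving = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
overoptimized = [0.1, 0.3, 0.5, 0.6, 0.5, 0.4, 0.3]
print(gold_reward_stalled(still_improving), gold_reward_stalled(overoptimized))  # → False True
```

A wider window smooths out more noise but reacts more slowly when the reward really does start to fall.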
We show that overfitting does happen with DPO, as expected, for multiple sizes of AI models. Notice how the data trends toward a plateau on the right side of the graphs (though there’s some noise):
What does this mean in practice? We see that after 250 steps of training with DPO, our model produces relatively coherent outputs:
But this changes with further training, and we observe at 1,600 training steps that it produces gibberish:
This matches our expectations of what overfitting to a reward model looks like! Despite training for longer, our model is clearly not giving more helpful answers than it was before.
We’ve just shown that overfitting occurs when using DPO. Overfitting isn’t desirable, so how can we mitigate it? As we noted before, one of the main ways to reduce overfitting is by adding more examples to the training dataset. However, collecting data on human preferences is difficult, time-consuming, and expensive. It would be ideal if we could somehow extend our dataset without actually having to collect additional empirical data. This idea is called data augmentation.
In particular, we use a family of data augmentation techniques called mixup. The idea is simple: we can interpolate between existing data points to generate new, “synthetic” examples. For example, consider two training data points from a hypothetical dataset. One is green with the label “4”, and the other is blue with the label “5”:
We can interpolate between these two data points to create a synthetic data point which is blue-green and has the label “4.5”:
Of course, we don’t know if our synthetic data point is realistic, but it’s probably a decent guess since it’s based on real data. Repeat this process, and you end up with extra training examples that are roughly representative of the original data.
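The green/blue example above can be written out directly. Representing each color as an (R, G, B) triple is an assumption made for illustration. Standard mixup interpolates both the input and the label with the same weight λ: mixing with λ = 0.5 reproduces the blue-green "4.5" example, and the classic recipe draws a fresh λ for each synthetic example from a Beta(α, α) distribution (α = 0.4 here is an illustrative choice, and `mixup_augment` is a hypothetical helper).

```python
import random
from typing import Optional

Color = tuple[float, float, float]


def mixup(x1: Color, y1: float, x2: Color, y2: float, lam: float) -> tuple[Color, float]:
    """Interpolate two labeled examples with weight lam in [0, 1]."""
    x = tuple(lam * a + (1 - lam) * b for a, b in zip(x1, x2))
    y = lam * y1 + (1 - lam) * y2
    return x, y


def mixup_augment(
    examples: list[tuple[Color, float]],
    n_new: int,
    alpha: float = 0.4,
    seed: Optional[int] = None,
) -> list[tuple[Color, float]]:
    """Create n_new synthetic examples by mixing random pairs of real ones."""
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        (x1, y1), (x2, y2) = rng.sample(examples, 2)  # two distinct real examples
        lam = rng.betavariate(alpha, alpha)  # mixing weight drawn from Beta(alpha, alpha)
        synthetic.append(mixup(x1, y1, x2, y2, lam))
    return synthetic


green, blue = (0.0, 255.0, 0.0), (0.0, 0.0, 255.0)
print(mixup(green, 4.0, blue, 5.0, lam=0.5))  # → ((0.0, 127.5, 127.5), 4.5)
```

With a small α, most sampled λ values land near 0 or 1, so synthetic examples stay close to real ones; larger α produces more evenly blended examples.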
We used mixup to augment our dataset to reduce overfitting in DPO. Our technique is slightly more complex (and novel), since there’s not an obvious way to interpolate between sentences, unlike numbers or colors. We end up doing the interpolations inside of the AI, using its numerical representations of concepts instead of the raw text inputs and outputs. (See our paper for more technical details!)
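To give a feel for why interpolating inside the model helps, here is a toy sketch: two sentences that cannot be averaged as text can still be averaged once each is turned into a vector. The tiny embedding table, the mean-pooling `encode` function, and the linear `score_head` are hypothetical stand-ins for a real language model's layers; this is not the method described in our paper, which works with an actual model's internal representations.

```python
# Hypothetical 2-D "embeddings"; a real model learns thousands of dimensions.
EMBED: dict[str, list[float]] = {
    "the": [0.1, 0.0],
    "cat": [0.9, 0.2],
    "sat": [0.3, 0.7],
    "dog": [0.8, 0.4],
    "ran": [0.2, 0.9],
}


def encode(tokens: list[str]) -> list[float]:
    """Toy 'hidden representation': the mean of the token embeddings."""
    vectors = [EMBED[t] for t in tokens]
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]


def interpolate(h1: list[float], h2: list[float], lam: float) -> list[float]:
    """Mix two hidden vectors: lam * h1 + (1 - lam) * h2."""
    return [lam * a + (1 - lam) * b for a, b in zip(h1, h2)]


def score_head(h: list[float], weights: tuple[float, ...] = (1.0, -1.0)) -> float:
    """Hypothetical final layer that maps a hidden vector to a score."""
    return sum(w * x for w, x in zip(weights, h))


h_a = encode(["the", "cat", "sat"])
h_b = encode(["the", "dog", "ran"])
h_mix = interpolate(h_a, h_b, lam=0.5)  # a "sentence" that exists only as a vector
print(h_mix, score_head(h_mix))
```

In hidden-state mixup generally, the mixed vector continues through the remaining layers like any other input, and its training target is mixed with the same weight λ.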
So how does mixup fare for reducing overfitting? Reasonably well! In particular, mixup seems to work (i.e., improve the reward accuracy) for the largest, most capable AI models we tested (which are still relatively small compared to models like the ones behind ChatGPT):
As we've seen, overfitting is a challenge in both traditional RLHF and direct preference optimization (DPO), and our research sheds light on this issue. We’ve identified that DPO, while efficient and effective, isn’t immune to overfitting. However, the good news is that we're also exploring novel solutions like mixup, which show promise in addressing these challenges.
As we continue to explore and refine methods like DPO, we're moving closer to making AI more reliable in understanding and responding to diverse human inputs. In practical terms, this could lead to AI that understands and aligns with human preferences, yet remains versatile and adaptable—all while being simpler to implement and train. This balance is key to creating AI that is not just smart, but also more in tune with what we, as humans, find helpful and relevant.