Help, my LLM keeps forgetting!

March 27, 2024 / By Nathan Brake

Recently I built an application to demonstrate some machine learning models that our team has been working on. Before beginning to write the code, I found it useful to have a conversation with a large language model (LLM) like ChatGPT to gather some insights about what technologies and designs would work best for my task. Although new code is being written all the time, many of the most successful software design principles have been around for a long time, and an LLM is well suited to help explain them to me.

I’ve noticed that if my chat is brief, the LLM does a great job recommending designs. However, as the conversation continues, the LLM starts to forget things that I said or clarified earlier. For instance, if in my fifth response in the chat I include something like, “My application needs to validate the input parameter,” but then I don’t ask it to do anything with that information until the 50th chat response, it may seem to have forgotten that I had this requirement. Why does this happen? It reflects a fundamental issue with current LLM designs: although they can technically “see” everything we send them, that doesn’t guarantee that they actually pay attention to all of it.

The “Transformer” architecture is the magic design behind many modern LLMs. Although the design is constantly being tweaked (an active area of research), a core component is the “self-attention” mechanism. The idea is that the model learns to represent each word by relating it to the other words in the text, weighting each of those words by how important it is for understanding the current one.
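To make the self-attention idea above a little more concrete, here is a minimal sketch of scaled dot-product self-attention in plain Python. It is a toy under stated assumptions, not how any production LLM is implemented: the 2-dimensional word vectors are made up for illustration, and real Transformers learn separate query, key, and value projection matrices, use many attention heads, and work with vectors that have hundreds or thousands of dimensions. Here the same vectors are reused as queries, keys, and values to keep it short.

```python
import math


def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    top = max(scores)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def self_attention(queries, keys, values):
    """Scaled dot-product self-attention over a sequence of vectors.

    For each query, score every key, normalize the scores with softmax,
    and return the weighted average of the value vectors.
    Returns (outputs, weights), where weights[i][j] is how much
    position i attends to position j.
    """
    scale = math.sqrt(len(keys[0]))
    outputs, weights = [], []
    for q in queries:
        row = softmax([dot(q, k) / scale for k in keys])
        mixed = [sum(w * v[d] for w, v in zip(row, values)) for d in range(len(values[0]))]
        outputs.append(mixed)
        weights.append(row)
    return outputs, weights


# Hypothetical 2-dimensional "embeddings"; real models learn much larger ones.
words = ["Mary", "had", "a", "little", "lamb"]
vectors = [[1.0, 0.0], [0.0, 1.0], [0.1, 0.1], [0.5, 0.5], [0.9, 0.4]]

# Simplification: reuse the embeddings as queries, keys, and values
# instead of applying learned projection matrices.
outputs, weights = self_attention(vectors, vectors, vectors)
for word, w in zip(words, weights[-1]):
    print(f"lamb -> {word}: {w:.2f}")
```

Each row of `weights` sums to 1, so the new vector for “lamb” is a blend of every word’s value vector, weighted by how relevant the model judges each word to be.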

So, if the LLM is trained on the text “Mary had a little lamb,” the model learns to represent the word “lamb” in terms of the words “Mary,” “had,” “a,” and “little.” That’s what makes the model powerful as a next-word predictor. With this method, the LLM can be trained to understand what the word “lamb” means in the sentence and to learn that “lamb” is a reasonable word to predict if the input starts with “Mary had a little.”
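The “Mary had a little lamb” example can be sketched as data. The snippet below is an illustrative simplification: it splits on whole words, whereas real LLMs operate on tokens, and the helper names are hypothetical. It shows the next-word training pairs a sentence yields and which words each position may look at in a left-to-right (causal) model, which only attends to a word and the words before it.

```python
def next_word_examples(words):
    """Pair every prefix of the sentence with the word that follows it."""
    return [(words[:i], words[i]) for i in range(1, len(words))]


def causal_visibility(words):
    """For each position, list the words a left-to-right model may attend to:
    the word itself and everything before it, never anything after it."""
    return [(words[i], words[: i + 1]) for i in range(len(words))]


sentence = "Mary had a little lamb".split()

for prefix, target in next_word_examples(sentence):
    print(" ".join(prefix), "->", target)

for word, visible in causal_visibility(sentence):
    print(word, "can attend to", visible)
```

The last training pair is `(['Mary', 'had', 'a', 'little'], 'lamb')`, which is exactly the “reasonable word to predict” case described above.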

With a simple sentence like this, it’s relatively easy for the model to understand what is going on. But if the text grows to 1,000 words, the model must learn to represent each word in relation to the 999 other words. Intuitively, then, this is a much harder task, and it’s more likely that the model won’t behave as expected. It’s difficult to distill this idea into a short paragraph, but this lecture from John Hewitt at Stanford is probably my favorite technical explanation of how self-attention and Transformers work. It’s a bit math-y, but if you have had some exposure to linear algebra or artificial intelligence (AI), it does an excellent job of explaining the intuition behind this building block at the core of generative AI.
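The scaling intuition in the previous paragraph comes down to simple arithmetic. The sketch below counts the attention scores one full (non-causal) self-attention layer computes and the average share of attention each word gets, since a word’s softmax weights always sum to 1. It only illustrates why the task gets harder as text grows; it does not prove that a model will fail, because trained models can learn to concentrate their weight on a few relevant words. Causal masking roughly halves the pair count but does not change the picture.

```python
def attention_pairs(n_words):
    """Number of (query, key) scores in one full self-attention layer:
    every word is scored against every word, including itself."""
    return n_words * n_words


def average_weight(n_words):
    """Softmax weights for one word sum to 1, so on average each of the
    n_words positions receives 1 / n_words of that word's attention."""
    return 1 / n_words


for n in (5, 1000):
    print(f"{n} words: {attention_pairs(n)} scores, average weight {average_weight(n)}")
```

Going from the five-word nursery rhyme to 1,000 words multiplies the number of scores by 40,000 and shrinks the average weight per word from 0.2 to 0.001.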

Backing up our experience of an LLM “forgetting,” this paper from Stanford, UC Berkeley, and Samaya AI (the aforementioned John Hewitt is also an author) runs experiments to gather concrete statistics about this phenomenon. The authors find that although modern LLMs technically support context lengths of 32k and above (that is, a combined input and output of roughly 32,000 tokens, where a token is a word or a piece of a word), this doesn’t necessarily mean that they are good at extracting and using all of the information provided. In practice, the models tend to excel at using information that comes at the beginning or at the end of the conversation but show a significant drop in performance when they must use information from the middle of the conversation.
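One way to observe this effect yourself, loosely modeled on the idea of moving a relevant passage through a long context, is to hide a single key fact among filler text at different positions and ask the same question each time. The helper below is a hypothetical sketch: it only builds the prompts, and the call to a model and the scoring of its answers are left out because they depend entirely on which LLM and interface you use.

```python
def build_probe_prompt(fact, fillers, position, question):
    """Hypothetical helper: hide one relevant fact among irrelevant filler
    passages at index `position`, number the passages, then append the question."""
    if not 0 <= position <= len(fillers):
        raise ValueError("position must be between 0 and len(fillers)")
    passages = fillers[:position] + [fact] + fillers[position:]
    numbered = [f"Passage {i + 1}: {p}" for i, p in enumerate(passages)]
    return "\n".join(numbered + [f"Question: {question}"])


fillers = [f"Filler note number {i}." for i in range(20)]
fact = "The application must validate the input parameter."
question = "What must the application do with the input parameter?"

# Probe the beginning, the middle, and the end of the context.
for position in (0, len(fillers) // 2, len(fillers)):
    prompt = build_probe_prompt(fact, fillers, position, question)
    # Send `prompt` to the model under test and score its answer here.
    print(position, prompt.splitlines()[position])
```

If the model answers correctly when the fact sits at position 0 or 20 but misses it more often at position 10, you are seeing the “lost in the middle” pattern.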

Until LLMs improve significantly on this front, there is an important takeaway: if you want to make sure the LLM pays attention to everything you need, tell it as much as you can at the beginning of the conversation, and then restate every important point later in the conversation. Doing so helps work around the tendency of current systems to lose track of what you said in the middle of the conversation.
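If you talk to an LLM programmatically, the advice above can be automated. The sketch below assumes a chat represented as a list of `{"role": ..., "content": ...}` dictionaries, a common convention but not the exact format of any particular API; `with_reminders` is a hypothetical helper name. It states the requirements in the first message and restates them together with the newest request, so they sit at both the start and the end of the context.

```python
def with_reminders(requirements, history, new_request):
    """Hypothetical helper: put every requirement in the opening message and
    repeat them alongside the newest request, so key points appear at both
    the beginning and the end of the conversation."""
    summary = "Requirements:\n" + "\n".join(f"- {r}" for r in requirements)
    opening = {"role": "user", "content": summary}
    closing = {"role": "user", "content": f"Reminder. {summary}\n\n{new_request}"}
    return [opening, *history, closing]


messages = with_reminders(
    ["My application needs to validate the input parameter."],
    [{"role": "assistant", "content": "Understood. What should we build first?"}],
    "Write the request handler.",
)
for message in messages:
    print(message["role"], "|", message["content"].replace("\n", " / "))
```

Some chat interfaces reject two consecutive messages from the same role, so you may need to merge or reorder messages to fit the one you use.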

Nathan Brake is a machine learning engineer and researcher at 3M Health Information Systems.