Please see the accompanying notebook for this article.
Late chunking
The motivation for this discussion is to pull meaningful information out of an earnings call. Specifically, we'll look at the Zoom 2024 Q3 earnings call, a transcript of which is available at The Motley Fool. The call is about 300 lines of text, where each line is associated with a speaker and their role in the company.
Some things to note:
- It's about 300 lines of text
- The call is basically a disclosure from members of Zoom leadership about Zoom's operations
- There are a number of analysts also on the call who ask questions
Conventional RAG
In a conventional RAG approach, we'd use an embedding model to compute the embedding of each line on its own.
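import torch

# `lines`, `tokenizer`, and `model` are assumed to be defined earlier
# (see the accompanying notebook).
line_embeddings: list[torch.Tensor] = []
for line in lines:
    # Ask the tokenizer for PyTorch tensors so the model can consume them directly.
    tokenized_line = tokenizer(line, return_tensors="pt")
    # No gradients are needed when we only want embeddings.
    with torch.no_grad():
        line_embedding = model(**tokenized_line)["pooler_output"]
    line_embeddings.append(line_embedding)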
The idea here is that if we compute the embedding of a given query, like 'Who were the attendees on the call?', the nearest embedded lines should be the most relevant to our question.
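A minimal sketch of that nearest-neighbor lookup, written in plain Python so it runs without a model. It assumes cosine similarity as the measure of closeness (the article does not say which metric the notebook uses), and `cosine_similarity` and `top_k_lines` are hypothetical helper names, not part of any library.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # Treat a zero vector as similar to nothing.
    return dot / (norm_a * norm_b)


def top_k_lines(
    query_embedding: list[float],
    line_embeddings: list[list[float]],
    k: int = 3,
) -> list[int]:
    """Return the indices of the k lines most similar to the query."""
    scores = [cosine_similarity(query_embedding, e) for e in line_embeddings]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


# Toy 2-d vectors standing in for real embeddings (an assumption for illustration).
toy_line_embeddings = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
toy_query = [0.9, 0.1]
nearest = top_k_lines(toy_query, toy_line_embeddings, k=2)
```

With real embeddings, the returned indices would be looked up in `lines` to show the matching transcript text.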
The issue with conventional chunking
Unfortunately, in this example line-by-line embeddings aren't much help in answering this specific question. Unless some line of text explicitly calls out the attendee names, we aren't likely to match a line that helps us. This is a significant weakness of the conventional chunking approach in general: the size of a chunk limits how much context its embedding contains. In other words, the only way to get all the meaning of our call transcript into an embedding is to embed the entire call transcript! That might help us identify the call as relevant, but it won't help us find the relevant lines within the call.
A short dive into embeddings
There is a solution to the above problem, but to explain it, we'll need to zoom out and look at how embeddings are computed in the first place. Recall that you can think of the embedding of a sentence as a spatial representation of that sentence, where sentences that are more similar in meaning are closer together as embeddings.
Word2Vec
Arguably the first widely adopted technology for producing 'embeddings' was word2vec (published in 2013). Word2vec learns an embedding for each word in the vocabulary it's trained on, so it only indicates which words are similar to each other, not whole sentences. While you're unlikely to find a modern use for word2vec, it was a significant technology along the way to modern embedding models and is worth understanding a little. Word2vec predates the widespread adoption of the attention mechanism in language modeling, and was trained a little differently than modern embedding models: it assigns each word a vector and adjusts those vectors to maximize the probability of the words that actually appear near each other in the training documents.
Modern embedding models
Modern embedding models usually embed whole sentences of words, not just words. Generally, they're trained to minimize the distance between the embeddings of similar sentences. The embedding model we will use is jina-embeddings-v2-base-en, a BERT model that has been fine-tuned on a 'query-target matching' task. By 'query-target matching', we mean that the training data looks like pairs of a query and a matching document chunk for that query, and the model is trained to minimize the distance between these pairs.
Specifically, recall that a BERT model (and indeed any transformer) produces a 'hidden state' vector for each token in its input.
For the training task though, we need a single embedding for the query and a single embedding for the target phrase. We get a single embedding for a sentence by 'pooling' all the hidden state vectors somehow. Here is the BERT pooler that Jina v2 uses:
import torch
from torch import nn


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
You can see that the BERT pooler simply picks the hidden state of the first token as the representative for the whole sentence. This works for a couple of reasons:
1. BERT applies bi-directional attention, which means that hidden states for tokens at the beginning of the sentence will incorporate information from the entire sentence.
2. The BERT tokenizer prepends a special [CLS] token to the start of input - so it's the hidden state of this [CLS] token that we're using to 'pool' the whole sentence, and not a random token that actually appears in the sentence.
Late chunking
We're ready to discuss a solution to the issue of giving greater context to our embeddings. In keeping with the earnings call example, when we embed each line, we want to incorporate as much context as possible into that embedding - ideally more than just the line itself.
Consider the output of the BERT model prior to computing a pooled embedding over, say, 10 lines of the transcript.
We have a hidden state vector for every token in those 10 lines - but by default only the hidden state of the [CLS] token would be used to create an embedding.
In 'late chunking', we'll group the hidden state vectors for each token into a group for each line, and then pool each group to produce an embedding for each line.
We don't have a special [CLS] token we can rely on here though, so we'll need to pool with something like 'mean pooling', where we simply average the values across the vectors in a group.
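The grouping-and-pooling step can be sketched in plain Python as follows. This is a minimal illustration rather than the notebook's code: `line_token_spans`, `mean_pool`, and `late_chunk` are hypothetical helper names, hidden states are represented as plain lists of floats, and the sketch assumes every line has at least one token and that the token counts already account for any special tokens the tokenizer adds.

```python
def line_token_spans(tokens_per_line: list[int]) -> list[tuple[int, int]]:
    """Turn per-line token counts into (start, end) index ranges, end exclusive."""
    spans = []
    start = 0
    for count in tokens_per_line:
        spans.append((start, start + count))
        start += count
    return spans


def mean_pool(vectors: list[list[float]]) -> list[float]:
    """Average a non-empty group of equal-length vectors element by element."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def late_chunk(
    hidden_states: list[list[float]],
    spans: list[tuple[int, int]],
) -> list[list[float]]:
    """Pool the token-level hidden states of each span into one embedding per line."""
    return [mean_pool(hidden_states[start:end]) for start, end in spans]


# Toy example: 5 tokens with 2-d hidden states, split into lines of 2 and 3 tokens.
# In practice the hidden states come from ONE forward pass over all the lines together.
toy_hidden_states = [[1.0, 0.0], [3.0, 2.0], [0.0, 3.0], [0.0, 6.0], [6.0, 0.0]]
toy_spans = line_token_spans([2, 3])
late_line_embeddings = late_chunk(toy_hidden_states, toy_spans)
```

With PyTorch, if `hidden_states` is a 2-d tensor of shape (tokens, hidden size), the pooling for one line reduces to `hidden_states[start:end].mean(dim=0)`. Token counts obtained by tokenizing each line separately may not match the tokenization of the concatenated text exactly, so the tokenizer's offset information is likely a more robust source for the spans.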
The reason this is an improvement over the default embedding method is that each token's embedding has the context of all 10 lines.
This gives us matches like the following for the query "Who were the attendees on the call?":
line # 18 Kelcey McKinley -- Event Consultant
line # 16 Mark Murphy -- Analyst
line # 11 Siti Panigrahi -- Analyst
instead of the lines returned with conventional line-by-line chunking:
line # 256 So, I'm just curious how many customers are showing that interest? What kind of scenarios they can design and, therefore, maybe how to think through the monetization potential at that price point?
line # 57 And our first question will come from Meta Marshall with Morgan Stanley.
line # 113 But as Eric and I began our conversations over the interview process, I got more and more excited about where I saw Zoom going to an AI-first platform company and could see a lot of the seeds, if you will, of growth being planted and starting to come to fruition. So, got very excited about that. And maybe, you know, my learning sense has been delightful, honestly, to see the customer love and the pace of innovation. I think you'd heard about it before, but to be among it, I think, has been a delight.
Final thoughts
In principle, late chunking is superior to traditional chunking and should be preferred for use cases like the above, where we are interested in finding specific lines from a larger document. Given the increasing context window sizes of the state-of-the-art embedding models (for example NV-Embed-v2), it's feasible to incorporate more and more context into the embeddings of individual lines within a document.
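When a document is still longer than the model's context window, the lines have to be split into windows (like the 10-line group used earlier) before late chunking can be applied to each window. The following is one possible greedy way to do that, offered as a sketch under assumptions: `group_lines_into_windows` is a hypothetical helper, and `max_tokens` stands for whatever token budget the chosen model allows.

```python
def group_lines_into_windows(
    tokens_per_line: list[int],
    max_tokens: int,
) -> list[list[int]]:
    """Greedily pack consecutive line indices into windows of at most max_tokens.

    A single line longer than max_tokens still gets its own window; handling
    that case (for example by truncation) is left to the caller.
    """
    windows: list[list[int]] = []
    current: list[int] = []
    current_tokens = 0
    for index, count in enumerate(tokens_per_line):
        # Start a new window when adding this line would exceed the budget.
        if current and current_tokens + count > max_tokens:
            windows.append(current)
            current, current_tokens = [], 0
        current.append(index)
        current_tokens += count
    if current:
        windows.append(current)
    return windows


# Toy token counts for five lines, with a budget of 8 tokens per window.
toy_windows = group_lines_into_windows([3, 4, 2, 6, 1], max_tokens=8)
```

Each window is then embedded in a single forward pass, so a line only sees the context of the other lines in its own window. Overlapping windows would be a natural refinement, but that is beyond this sketch.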
At this time at least, late chunking has yet to be widely adopted, so to implement it you will need to run an open source embedding model and compute the hidden states yourself. Additionally, embedding models are not usually trained to produce their embeddings via the late-chunking pooling method, so ideally one would fine-tune a model using the late-chunking pooling method before using it.
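As a rough sketch of what "computing the hidden states yourself" could look like with a Hugging Face style tokenizer and model: the function below assumes the model's output exposes `last_hidden_state`, as BERT-style models in the `transformers` library do. `token_hidden_states` is a hypothetical helper name, and loading the tokenizer and model is left to the caller.

```python
def token_hidden_states(text, tokenizer, model):
    """Return one hidden-state vector per token for a single piece of text.

    `tokenizer` and `model` are assumed to follow the Hugging Face calling
    convention; for inference you would normally wrap the call in
    `torch.no_grad()`.
    """
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model(**inputs)
    # Shape is (batch, tokens, hidden size); take the only item in the batch.
    return outputs.last_hidden_state[0]
```

The result is the per-token sequence that late chunking pools over, in place of the single `pooler_output` vector used in the conventional approach.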
I'm convinced it's worthwhile to invest in exploring late chunking approaches, so I'll be working on a library that abstracts over some of the details.
Happy embedding!
- Liam