Solving the lost context problem in document retrieval with the embed-then-chunk approach
A user asks "What was Berlin's population in 2023?" and your system has a chunk containing "The city had 3.85 million inhabitants" — but with no mention of Berlin. The system fails because it doesn't know "the city" refers to Berlin, which was mentioned in a previous chunk.
This is the lost context problem, and it's a fundamental flaw in how most retrieval systems handle documents. The standard approach involves:

1. Splitting the document into chunks
2. Embedding each chunk independently
3. Comparing the query embedding against the chunk embeddings to retrieve the most similar chunks
The problem? When we chunk first and embed later, we destroy contextual connections between chunks. Pronouns lose their referents, terminology becomes ambiguous, and discussions that span multiple chunks become fragmented and less retrievable.
Consider this excerpt:
Berlin is the capital and largest city of Germany.
Its 3.85 million inhabitants make it the European Union's most populous city.
The city is also one of Germany's sixteen federal states.
When chunked by sentence, "Berlin" is separated from the details about its population and status. A query about "Berlin's population" becomes difficult: "Berlin" matches the first sentence, while "population" matches the second. If the model knew that "Its" in chunk 2 referred to "Berlin" in chunk 1, the query would match the second sentence correctly. But because we created the embeddings independently, that context link is broken.
This tutorial introduces late chunking, which addresses this issue by flipping the traditional "chunk-then-embed" strategy into an "embed-then-chunk" approach, so that each chunk embedding keeps the full document context.
By the end of this post, you'll:
📚 If you have not implemented a retrieval system before, your best bet is to start with a full but simple retrieval implementation first. Check out this post to get started on that!
🙏 I used a lot of great resources to put this blog post together and it is so amazing these were all available! Please check them out!
- Jina AI's Late Chunking in Long Context Embedding Models blog post
- Jina AI's What Late Chunking Really Is and What it's Not Part II blog post
- Jina AI's Late Chunking Github Repo with their implementation
- Jina AI's Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models paper: Günther, M., Mohr, I., Williams, D. J., Wang, B., & Xiao, H. (2024)
Let's start with a baseline chunking approach. This will serve as prerequisite knowledge, but it also gives us a baseline to compare against so we can see what kinds of queries late chunking improves in practice.
The standard workflow used in most retrieval and RAG applications today is:

1. Chunk the document into smaller pieces
2. Embed each chunk independently
3. Embed the query at search time and retrieve the most similar chunks
We will start with a simple implementation using the Sentence Transformers library:
We will need an example text to work with. It is critical to have concrete examples in front of you as you work on this stuff. The most common mistake I see, from early exploration and prototyping all the way to production commercial deployments, is not looking at the data enough.
document = """
Berlin is the capital and largest city of Germany. The city has a rich history dating back centuries. It was founded in the 13th century and has been a significant cultural and political center throughout European history.
The metropolis experienced dramatic changes during the 20th century, including two world wars and a period of division. After reunification, it underwent extensive reconstruction and modernization efforts.
Its population reached 3.85 million inhabitants in 2023, making it the most populous urban area in the country. This represents a significant increase from previous decades, driven largely by immigration and economic opportunities.
The city is known for its vibrant cultural scene and historical significance. Many tourists visit its famous landmarks each year, contributing significantly to the local economy. The Brandenburg Gate stands as its most iconic symbol.
"""
First, let's do a simple chunking approach. In this example, we will do word-based chunking, but with an overlap to retain some contextual information.
def chunk_document(document, chunk_size=50, overlap=10):
"""Split a document into chunks of approximately chunk_size words with overlap."""
words = document.split()
chunks = []
for i in range(0, len(words), chunk_size - overlap):
chunks.append(' '.join(words[i:i + chunk_size]))
return chunks
chunks = chunk_document(document, 50, 10)
for i, chunk in enumerate(chunks):
print(f"\nChunk {i+1}:")
print(chunk)
With the traditional approach, we would then embed each of those chunks separately.
from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def traditional_chunking(document, chunk_size=50, overlap=10):
"""Traditional approach: chunk first, then embed each chunk independently."""
# Step 1: Split the document into chunks
chunks = chunk_document(document, chunk_size, overlap)
# Step 2: Embed each chunk independently
chunk_embeddings = model.encode(chunks)
return chunks, chunk_embeddings
chunks, embeddings = traditional_chunking(document)
💡 Other Common Chunking Patterns
While our example uses simple word-based chunking with overlap, several other chunking strategies are popular in practice:
- Sentence-based chunking: Split at sentence boundaries to preserve complete thoughts (a minimal sketch follows this list)
- Paragraph-based chunking: Use natural document structure for more coherent chunks
- Fixed token chunking: Count tokens instead of words for more consistent embedding sizes
- Semantic chunking: Group semantically related content using embeddings or topic modeling
- Recursive chunking: Apply hierarchical chunking strategies for nested document structures
- Sliding windows: Create overlapping chunks with a fixed window size and stride
Each approach has trade-offs between implementation complexity, semantic coherence, and retrieval effectiveness. Regardless of the chunking strategy, traditional approaches all share the same fundamental limitation: each chunk is embedded independently without access to the full document context.
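As a concrete example of the first pattern, here's a minimal sketch of sentence-based chunking (a hypothetical helper, not part of the pipeline we build in this post), splitting on sentence-ending punctuation:

import re

def sentence_chunking(document):
    """A minimal sentence-based chunker: split on '.', '!' or '?' followed by whitespace."""
    sentences = re.split(r'(?<=[.!?])\s+', document.strip())
    # Drop any empty strings left over from extra whitespace
    return [s.strip() for s in sentences if s.strip()]

sentence_chunks = sentence_chunking(document)
for i, chunk in enumerate(sentence_chunks):
    print(f"Sentence chunk {i+1}: {chunk}")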
We can then do retrieval with the word-based chunks and embeddings we created above.
def simple_retrieval(query, chunks, chunk_embeddings):
"""Retrieve the most relevant chunk for a query."""
# Embed the query
query_embedding = model.encode(query)
# Calculate similarity between query and all chunks
similarities = np.dot(chunk_embeddings, query_embedding)
# Find the most similar chunk
best_match_idx = np.argmax(similarities)
return chunks[best_match_idx], similarities[best_match_idx]
Let's look at an example query:
query = "What is Berlin's population?"
best_chunk, similarity = simple_retrieval(query, chunks, embeddings)
print(f"\nQuery: {query}")
print(f"Best matching chunk (similarity: {similarity:.4f}):")
print(best_chunk)
In this example, the approach failed to find the chunk that contains the answer to the query.
The example failed because information from two chunks must be combined to answer the question. The search prioritized matching on "Berlin" instead of matching on "population". What we want is a system smart enough to understand that the population information in chunk 2 is about Berlin.
This traditional chunking process treats each chunk as an independent document, which means:

- Pronouns and references like "Its" or "The city" lose their referents
- Terminology becomes ambiguous once it is separated from its surrounding context
- Discussions that span multiple chunks become fragmented and harder to retrieve
Let's look more closely at this problematic behavior.
To demonstrate this problem more clearly, let's run a few more queries against our chunked document:
test_queries = [
"What happened before reunification?",
"What is Berlin's population?",
"When did Berlin reach 3.85 million people?",
"What famous landmark is in Berlin?",
"How many people live in the German capital?",]
for query in test_queries:
best_chunk, similarity = simple_retrieval(query, chunks, embeddings)
print(f"\nQuery: {query}")
print(f"Best matching chunk (similarity: {similarity:.4f}):")
print(best_chunk)
Notice how the system struggles with queries that require connecting information across chunks. For example, when asking about Berlin's population, the system might return a chunk that mentions Berlin but not its population, or vice versa.
This happens because each chunk is embedded in isolation. When chunk 2 mentions "Its population reached 3.85 million inhabitants," the embedding model has no way to know that "Its" refers to Berlin, which was mentioned in chunk 1.
We can see this lost context in action by examining how references get disconnected across chunk boundaries.
Let's make a tiny eval dataset so we can score this method (and, later, late chunking) on queries where all the necessary information sits within a single chunk, as well as queries where it is spread across multiple chunks.
cross_chunk_queries = [
("What is Berlin's population?", "3.85 million inhabitants"),
("How many people live in the German capital?", "3.85 million inhabitants"),
("What famous landmark is in Berlin?", "Brandenburg Gate"),
("What is the city known for?", "vibrant cultural scene"),
("When was it founded?", "13th century")
]
single_chunk_queries = [
("When was Berlin founded?", "13th century"),
("What happened to Berlin during the 20th century?", "two world wars"),
("What's the most iconic symbol", "Brandenburg Gate"),
("What happened in the 20th century?", "two world wars")
]
We can write a small function to evaluate against this dataset. Since we already retrieve chunks, we can re-use that logic and check whether the answer appears in the chunk that was returned.
def evaluate_retrieval(queries_with_answers, chunks, embeddings):
"""Evaluate retrieval performance on a set of queries with known answers."""
def process_query(query_answer_pair):
query, answer_text = query_answer_pair
answer_chunks = [i for i, chunk in enumerate(chunks) if answer_text in chunk]
best_chunk, similarity = simple_retrieval(query, chunks, embeddings)
best_chunk_idx = chunks.index(best_chunk)
contains_answer = best_chunk_idx in answer_chunks
return {
'query': query,
'retrieved_chunk': best_chunk_idx + 1,
'answer_in_chunks': [i + 1 for i in answer_chunks],
'contains_answer': contains_answer,
'similarity': similarity}
return [process_query(query_pair) for query_pair in queries_with_answers]
We can print and take a look at the accuracy.
cross_chunk_results = evaluate_retrieval(cross_chunk_queries, chunks, embeddings)
single_chunk_results = evaluate_retrieval(single_chunk_queries, chunks, embeddings)
# Calculate accuracy
cross_chunk_accuracy = sum(r['contains_answer'] for r in cross_chunk_results) / len(cross_chunk_results)
single_chunk_accuracy = sum(r['contains_answer'] for r in single_chunk_results) / len(single_chunk_results)
print(f"Accuracy for queries with answers in a single chunk: {single_chunk_accuracy:.2f}")
print(f"Accuracy for queries with answers spanning chunks: {cross_chunk_accuracy:.2f}")
We've got a problem! 0% accuracy on our queries with answers spanning chunks! Let's explore what the problem is and what to do about it.
A common approach to mitigating this problem is overlapping chunks, which we've already incorporated in our example. Overlap helps, but it has significant limitations.
Increasing the overlap only papers over the issue; it doesn't solve the fundamental problem: each chunk is still embedded independently, without access to the full document context.
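To see this concretely, you can re-run the word-based chunk_document function from earlier with a larger overlap and inspect where the boundaries fall; the chunks share more words, but each one is still embedded on its own (a quick, illustrative sketch):

# Compare chunk boundaries with a small vs. a large overlap
for overlap in [10, 25]:
    overlapped_chunks = chunk_document(document, chunk_size=50, overlap=overlap)
    print(f"\noverlap={overlap}: {len(overlapped_chunks)} chunks")
    for i, chunk in enumerate(overlapped_chunks):
        # Show just the first few words of each chunk to see where it starts
        print(f"  Chunk {i+1} starts with: {' '.join(chunk.split()[:8])}...")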
These limitations highlight why we need a better approach to chunking and embedding documents for retrieval. The ideal solution would keep the full document context in every chunk's embedding while still giving us small, retrievable chunks.
This is exactly what late chunking provides. In the next section, we'll explore how late chunking solves these problems by reversing the traditional workflow.
Late chunking flips the traditional approach on its head. Instead of chunking first and then embedding each chunk independently, late chunking follows this workflow:

1. Embed the entire document with a long-context model, producing one contextual embedding per token
2. Split the token embeddings into chunks, for example at sentence boundaries
3. Mean-pool the token embeddings within each span to get a single embedding per chunk
This seemingly simple change makes a big difference in retrieval quality because each chunk's embedding now contains information from the entire document.
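In code, the workflow looks roughly like this (a sketch only, to show the shape of the idea; the full implementation below handles token offsets and sentence boundaries properly):

def late_chunking_sketch(document, model, tokenizer, token_spans):
    """Rough shape of late chunking: embed the whole document once, then pool per chunk.
    `token_spans` is a list of (start_token, end_token) chunk boundaries."""
    inputs = tokenizer(document, return_tensors='pt')
    token_embeddings = model(**inputs)[0]  # one contextual embedding per token
    # Mean-pool each token span into a single chunk embedding
    return [token_embeddings[0, start:end].mean(dim=0) for start, end in token_spans]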
Now let's implement late chunking properly, step by step, to see it in action.
To do late chunking we need a model and tokenizer. We will use Jina AI's jina-embeddings-v2-base-en, a long-context embedding model that Jina AI uses in their late chunking work.
from transformers import AutoModel, AutoTokenizer
import numpy as np
jina_tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
jina_model = AutoModel.from_pretrained( 'jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
First, tokenize the document to convert it into numeric values:
inputs = jina_tokenizer(document, return_tensors='pt', return_offsets_mapping=True)
To work in token space we need the token ids and the offsets that map them back to the original text. Let's look at how that works in a small example before we start actually chunking.
token_ids = inputs['input_ids'][0]
_token_one_id = token_ids[1]
_token_one_id
Token 1 (the second token) is 4068. This is a value we can use to look up an embedding, and it represents some piece of text in our original document. The token offsets give us the mapping back to what it represents in the original document.
token_offsets = inputs['offset_mapping'][0]
_token_one_offset = token_offsets[1]
_token_one_offset
Our token 4068 has a character offset of [1, 7]. We now have enough information to answer "What does 4068 mean in English?". Let's look it up in our original English document.
document[_token_one_offset[0]:_token_one_offset[1]]
Perfect, token 4068 represents "Berlin". We know this because document[1:7] (slicing with the token offsets) returns the text "Berlin".
We can go from word to token ID using the model's tokenizer, which stores a mapping that was learned when the model was trained.
jina_tokenizer.vocab['berlin']
Now we understand how to use token offsets to connect tokens from Jina AI's model back to words in our English document.
Great! Next, let's tokenize the entire document so we can pass it to the model.
inputs = jina_tokenizer(document, return_tensors='pt')
inputs['input_ids'].shape
We can see that after passing the entire document to the tokenizer, we get back an array with 152 tokens in it.
Now that it's in numeric form, we can pass all the tokens to the model to get an embedding for each token.
model_output = jina_model(**inputs)
token_embeddings = model_output[0]
token_embeddings.shape
As we can see, each of the 152 tokens has an embedding with 768 numbers in it.
Each of these represents a word's meaning in the context of this document, not just an isolated meaning of the word itself. That's the genius of this model.
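One quick, illustrative way to poke at this (assuming the token_ids, token_offsets, and token_embeddings variables from the cells above) is to compare the in-context embedding of a pronoun like "Its" with the in-context embedding of the "Berlin" token. This is not a rigorous test, just a way to see that each token's embedding now carries document-level context:

import torch.nn.functional as F

def find_token_positions(text_piece):
    """Return indices of tokens whose original text matches text_piece (case-insensitive)."""
    return [i for i, (s, e) in enumerate(token_offsets)
            if document[s:e].lower() == text_piece.lower()]

berlin_pos = find_token_positions("Berlin")[0]

# Compare each "its" token's contextual embedding to the "Berlin" token's embedding
for pos in find_token_positions("its"):
    s, e = int(token_offsets[pos][0]), int(token_offsets[pos][1])
    sim = F.cosine_similarity(token_embeddings[0, pos], token_embeddings[0, berlin_pos], dim=0)
    print(f"Token {pos} ('{document[s:e]}') vs 'Berlin' token similarity: {sim.item():.3f}")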
To chunk by sentences we can use the period as the separator. We can use what we learned about token ids and offsets to chunk in token space.
punctuation_mark_id = jina_tokenizer.convert_tokens_to_ids('.')
punctuation_mark_id
We can look for that token followed by a space to find the start and end of each chunk. Let's do this and store the start and end indexes in both character (English text) and token space.
chunk_positions, token_span_annotations = [], []
span_start_char, span_start_token = 0,0
for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets)):
    if i < len(token_ids) - 1:
        # A chunk ends at a '.' token that is followed by a space or newline in the original text
        if token_id == punctuation_mark_id and document[end:end+1] in [' ', '\n']:
            # Store both character positions and token positions (start and end)
            chunk_positions.append((span_start_char, int(end)))
            token_span_annotations.append((span_start_token, i + 1))
            # Update start positions for the next chunk
            span_start_char, span_start_token = int(end) + 1, i + 1
Let's print a few chunks to check that we have them in both human-readable (character) and model-readable (token) forms.
for i in range(3):
    char_start, char_end = chunk_positions[i]
    token_start, token_end = token_span_annotations[i]
    chunk_text = document[char_start:char_end].strip()  # strip surrounding whitespace/newlines for display
    print(f"Chunk {i}:")
    print(f"  Character span ({char_start}:{char_end}): {chunk_text}")
    print(f"  Token span ({token_start}:{token_end}): {token_ids[token_start:token_end]}")
    print()
All that's left is to use our chunks in token space to slice the token embeddings.
start_token, end_token = token_span_annotations[0]
chunk_embedding = token_embeddings[0, start_token:end_token]
chunk_embedding.shape
This chunk contains 11 tokens, each with an embedding of length 768. But we need a single embedding for the whole chunk. To get one, we use mean pooling (we simply average the token embeddings).
chunk_embedding = chunk_embedding.mean(dim=0)
chunk_embedding.shape
💡 Note that it isn't a choice between late chunking and chunk overlap. We aren't using overlap with late chunking here, but you certainly can. Go ahead and experiment with the traditional chunking performance tricks alongside late chunking as well.
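For example, one way to experiment with overlap on top of late chunking is to widen each token span before pooling, so that neighbouring sentences also contribute to each chunk's embedding (a hypothetical variation, re-using the token_span_annotations and token_embeddings from above):

def widen_token_spans(token_span_annotations, num_tokens, overlap_tokens=5):
    """Widen each (start, end) token span by overlap_tokens on both sides, clamped to valid indices."""
    return [(max(0, start - overlap_tokens), min(num_tokens, end + overlap_tokens))
            for start, end in token_span_annotations]

overlapped_spans = widen_token_spans(token_span_annotations, num_tokens=token_embeddings.shape[1])
overlapped_embeddings = [token_embeddings[0, start:end].mean(dim=0).detach().cpu().numpy()
                         for start, end in overlapped_spans]
len(overlapped_embeddings)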
Back to the main pipeline: apply mean pooling in a loop to get an embedding for each chunk in the document.
embeddings = []
# For each token span, calculate the mean of its token embeddings
for start_token, end_token in token_span_annotations:
if end_token > start_token: # Ensure span has at least one token
# Mean pooling over the token embeddings for this chunk
chunk_embedding = token_embeddings[0, start_token:end_token].mean(dim=0)
embeddings.append(chunk_embedding.detach().cpu().numpy())
len(embeddings)
Perfect! Now we have our document chunked with an embedding for each chunk. This is late chunking, so let's put all of this in a function so we can do some evaluation and see what kind of impact it made.
def late_chunking(document, model, tokenizer):
"Implements late chunking on a document."
# Tokenize with offset mapping to find sentence boundaries
inputs_with_offsets = tokenizer(document, return_tensors='pt', return_offsets_mapping=True)
token_offsets = inputs_with_offsets['offset_mapping'][0]
token_ids = inputs_with_offsets['input_ids'][0]
# Find chunk boundaries
punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
chunk_positions, token_span_annotations = [], []
span_start_char, span_start_token = 0, 0
for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets)):
if i < len(token_ids)-1:
if token_id == punctuation_mark_id and document[end:end+1] in [' ', '\n']:
# Store both character positions and token positions
chunk_positions.append((span_start_char, int(end)))
token_span_annotations.append((span_start_token, i+1))
# Update start positions for next chunk
span_start_char, span_start_token = int(end)+1, i+1
# Create text chunks from character positions
chunks = [document[start:end].strip() for start, end in chunk_positions]
# Encode the entire document
inputs = tokenizer(document, return_tensors='pt')
model_output = model(**inputs)
token_embeddings = model_output[0]
# Create embeddings for each chunk using mean pooling
embeddings = []
for start_token, end_token in token_span_annotations:
if end_token > start_token: # Ensure span has at least one token
chunk_embedding = token_embeddings[0, start_token:end_token].mean(dim=0)
embeddings.append(chunk_embedding.detach().cpu().numpy())
return chunks, embeddings
late_chunks, late_embeddings = late_chunking(document, jina_model, jina_tokenizer)
late_chunks
We can re-use what we did earlier along with the sample queries we were looking at to run an evaluation.
Remember, this eval is purely illustrative. Properly comparing two approaches like this takes a lot more work, so read the paper and blog posts by Jina AI if you are interested in that. Here, the goal is simply to build intuition about the kinds of queries late chunking can help with.
So let's get started. Let's refresh our memory on the document we are querying.
print(document)
For convenience, let's concatenate the single-chunk queries (everything needed to answer the query is contained in one chunk) and the cross-chunk queries (the necessary information is spread across more than one chunk).
queries_with_answers = cross_chunk_queries + single_chunk_queries
cross_chunk_queries, single_chunk_queries
We can re-use everything we covered earlier in the blog post to check whether the answer to each question is in the chunk that gets returned.
def evaluate_traditional_chunking(queries_with_answers, document):
# Create traditional chunks and embeddings
chunks, embeddings = traditional_chunking(document)
results = []
for query, answer_text in queries_with_answers:
# Find which chunks contain the answer
answer_chunks = [i for i, chunk in enumerate(chunks) if answer_text in chunk]
# Get the best matching chunk using simple_retrieval
best_chunk, similarity = simple_retrieval(query, chunks, embeddings)
best_chunk_idx = chunks.index(best_chunk)
# Check if the best chunk contains the answer
contains_answer = best_chunk_idx in answer_chunks
results.append({
'query': query,
'answer': answer_text,
'traditional_correct': contains_answer,
})
return results
# pd.DataFrame(evaluate_traditional_chunking(queries_with_answers, document))
Let's do the same for the late chunking approach so we can compare the two.
def evaluate_late_chunking(queries_with_answers, document):
# Get chunks and embeddings using late chunking
late_chunks, late_embeddings = late_chunking(document, jina_model, jina_tokenizer)
results = []
for query, answer_text in queries_with_answers:
# Find which chunks contain the answer
answer_chunks = [i for i, chunk in enumerate(late_chunks) if answer_text in chunk]
# Embed the query
query_embedding = jina_model.encode(query)
# Find most similar chunk
similarities = [np.dot(query_embedding, chunk_emb) /
(np.linalg.norm(query_embedding) * np.linalg.norm(chunk_emb))
for chunk_emb in late_embeddings]
best_chunk_idx = np.argmax(similarities)
# Check if the best chunk contains the answer
contains_answer = best_chunk_idx in answer_chunks
results.append({
'query': query,
'answer': answer_text,
'late_chunking_correct': contains_answer,
})
return results
# pd.DataFrame(evaluate_late_chunking(queries_with_answers, document))
Let's put both of these evaluations into a single DataFrame so we can compare them and see whether late chunking did better on the cross-chunk queries.
import pandas as pd

# Helper function to merge results from both methods into one table
def create_comparison_table(queries_with_answers, document):
# Get results for both methods
trad_results = evaluate_traditional_chunking(queries_with_answers, document)
late_results = evaluate_late_chunking(queries_with_answers, document)
# Combine results
combined_results = []
for trad, late in zip(trad_results, late_results):
combined_results.append({
'query': trad['query'],
'answer': trad['answer'],
'traditional_correct': trad['traditional_correct'],
'late_chunking_correct': late['late_chunking_correct'],
'query_type': 'cross_chunk' if trad['query'] in [q for q, _ in cross_chunk_queries] else 'single_chunk'
})
# Create DataFrame
return pd.DataFrame(combined_results)
# Evaluate both query types
comparison_df = create_comparison_table(queries_with_answers, document)
comparison_df
The results show that late chunking maintains high accuracy for single-chunk queries while improving performance on cross-chunk queries. While these are example queries and not a full proper eval, you can use this to build intuition about why late chunking works.
Late chunking solves the lost context problem in several important ways:
Bidirectional context awareness: Each token embedding is influenced by all other tokens in the document, both before and after it. This means references like "the city" can be properly linked to "Berlin" mentioned earlier.
Consistent representation: All chunks from the same document share the same contextual foundation, ensuring that related concepts are represented similarly regardless of which chunk they appear in.
Preservation of long-range dependencies: Information from the beginning of a document can influence the representation of content at the end, maintaining semantic connections across the entire text.
Resilience to boundary selection: Since each token's embedding already contains document-wide context, the specific chunking boundaries become less critical. This means simpler chunking strategies can work just as well as complex ones.
Late chunking requires embedding models that can handle long contexts—ideally 8K tokens or more. These models aren't just standard embedding models with longer input windows; they're specifically designed to maintain coherent representations across thousands of tokens.
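If you want to check what a given model and tokenizer pair can handle, you can inspect their configuration. The exact attribute names vary between model families, so treat this as a rough probe rather than a definitive check:

# Rough probe of the context window; attribute names vary by model family
print("Tokenizer model_max_length:", jina_tokenizer.model_max_length)
print("Config max_position_embeddings:", getattr(jina_model.config, 'max_position_embeddings', 'not set'))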
The key advantages of these long-context models for late chunking are the ability to process an entire document in a single forward pass, attention that lets every token incorporate context from the whole document, and token-level representations that stay coherent across thousands of tokens. Without these capabilities, late chunking wouldn't be possible or effective.