/**
 * graphzep
 * GraphZep: A temporal knowledge graph memory system for AI agents based on the Zep paper
 */
import {
ZepMemory,
ZepSearchParams,
ZepSearchResult,
RerankingStrategy,
RerankingConfig,
} from './types.js';
import { BaseEmbedderClient } from '../embedders/client.js';
import { GraphDriver } from '../types/index.js';
/**
 * Retrieval layer over `ZepMemory` nodes stored in a graph database.
 *
 * Supports four search modes (semantic, keyword, hybrid, MMR) plus optional
 * post-hoc reranking. All database access goes through the injected
 * `GraphDriver`; query embeddings come from the injected `BaseEmbedderClient`.
 */
export class ZepRetrieval {
  /** Row limit sent to the database when the caller gives no usable limit. */
  private static readonly DEFAULT_QUERY_LIMIT = 20;
  /** Number of results produced by MMR search when the caller gives no usable limit. */
  private static readonly DEFAULT_MMR_LIMIT = 10;
  /** Relevance/diversity trade-off used when no lambda is configured. */
  private static readonly DEFAULT_MMR_LAMBDA = 0.5;

  private readonly embedder: BaseEmbedderClient;
  private readonly driver: GraphDriver;

  /**
   * @param embedder Client used to embed the query text for semantic search.
   * @param driver Graph driver used to execute Cypher queries.
   */
  constructor(embedder: BaseEmbedderClient, driver: GraphDriver) {
    this.embedder = embedder;
    this.driver = driver;
  }

  /**
   * Search for memories.
   *
   * Pipeline: run the search selected by `params.searchType` (default
   * `'hybrid'`), optionally rerank, drop results scoring below
   * `params.minRelevance`, then truncate to `params.limit`.
   *
   * Note that in hybrid mode `score` is the fused RRF value (at most
   * `1/61` per contributing list), not a cosine similarity, and
   * `minRelevance` is compared against that value.
   *
   * @param params Search parameters (query text, filters, mode, limit).
   * @returns Matching memories, best first. An unrecognized `searchType`
   *   yields an empty list.
   */
  async search(params: ZepSearchParams): Promise<ZepSearchResult[]> {
    const searchType = params.searchType || 'hybrid';
    let results: ZepSearchResult[] = [];

    switch (searchType) {
      case 'semantic':
        results = await this.semanticSearch(params);
        break;
      case 'keyword':
        results = await this.keywordSearch(params);
        break;
      case 'hybrid':
        results = await this.hybridSearch(params);
        break;
      case 'mmr':
        results = await this.mmrSearch(params);
        break;
    }

    // NOTE(review): no RerankingConfig is passed here, so rerankResults always
    // takes its default strategy (RRF), which is a pass-through.
    if (params.rerank) {
      results = await this.rerankResults(results, params.query);
    }

    if (params.minRelevance !== undefined) {
      const minRelevance = params.minRelevance;
      results = results.filter((r) => r.score >= minRelevance);
    }

    // Use the same validated limit the queries use, so the two never disagree.
    const limit = this.normalizeLimit(params.limit);
    if (limit !== undefined) {
      results = results.slice(0, limit);
    }

    return results;
  }

  /**
   * Turn a caller-supplied limit into a positive safe integer.
   *
   * The limit is interpolated into Cypher text, so anything that is not a
   * plain positive integer (NaN, negative, zero, non-numeric strings, huge
   * values that would print in exponent form) is rejected.
   *
   * @returns The floored limit, or `undefined` when the value is unusable.
   */
  private normalizeLimit(limit: unknown): number | undefined {
    let numeric = NaN;
    if (typeof limit === 'number') {
      numeric = limit;
    } else if (typeof limit === 'string' && limit.trim() !== '') {
      numeric = Number(limit);
    }

    const whole = Math.floor(numeric);
    return Number.isSafeInteger(whole) && whole > 0 ? whole : undefined;
  }

  /**
   * Build the Cypher predicates (and their bound parameters) for the
   * session, user, memory-type and time-range filters in `params`.
   *
   * Shared by semantic and keyword search so both honour the same filters.
   */
  private buildMemoryFilters(params: ZepSearchParams): {
    conditions: string[];
    values: Record<string, unknown>;
  } {
    const conditions: string[] = [];
    const values: Record<string, unknown> = {};

    if (params.sessionId) {
      conditions.push('m.sessionId = $sessionId');
      values.sessionId = params.sessionId;
    }
    if (params.userId) {
      conditions.push('m.userId = $userId');
      values.userId = params.userId;
    }
    if (params.memoryTypes && params.memoryTypes.length > 0) {
      conditions.push('m.memoryType IN $memoryTypes');
      values.memoryTypes = params.memoryTypes;
    }
    if (params.timeRange) {
      if (params.timeRange.start) {
        conditions.push('m.createdAt >= datetime($startTime)');
        values.startTime = params.timeRange.start.toISOString();
      }
      if (params.timeRange.end) {
        conditions.push('m.createdAt <= datetime($endTime)');
        values.endTime = params.timeRange.end.toISOString();
      }
    }

    return { conditions, values };
  }

  /**
   * Semantic search: embeds the query and ranks memories by cosine
   * similarity, computed inside the Cypher query.
   *
   * Only memories with a stored embedding and similarity above 0.1 are
   * returned. `distance` is reported as `1 - similarity`.
   */
  private async semanticSearch(params: ZepSearchParams): Promise<ZepSearchResult[]> {
    const queryEmbedding = await this.embedder.embed(params.query);
    const filters = this.buildMemoryFilters(params);
    const limit = this.normalizeLimit(params.limit) ?? ZepRetrieval.DEFAULT_QUERY_LIMIT;
    const queryParams: Record<string, unknown> = {
      embedding: queryEmbedding,
      ...filters.values,
    };

    let query = `
      MATCH (m:ZepMemory)
      WHERE m.embedding IS NOT NULL
    `;
    if (filters.conditions.length > 0) {
      query += ' AND ' + filters.conditions.join(' AND ');
    }
    query += `
      WITH m,
        CASE
          WHEN size(m.embedding) = 0 OR size($embedding) = 0 THEN 0.0
          ELSE
            reduce(dot = 0.0, i IN range(0, size(m.embedding)-1) |
              dot + (m.embedding[i] * $embedding[i])
            ) /
            (sqrt(reduce(norm1 = 0.0, x IN m.embedding | norm1 + x * x)) *
             sqrt(reduce(norm2 = 0.0, x IN $embedding | norm2 + x * x))
            )
        END AS similarity
      WHERE similarity > 0.1
      OPTIONAL MATCH (m)-[:HAS_FACT]->(f:ZepFact)
      RETURN m, collect(f) as facts, similarity
      ORDER BY similarity DESC
      LIMIT ${limit}
    `;

    const rows = await this.driver.executeQuery<any>(query, queryParams);
    return rows.map((record: any) => ({
      memory: this.parseMemoryFromNode(record.m, record.facts),
      score: record.similarity,
      distance: 1 - record.similarity,
    }));
  }

  /**
   * Keyword search: case-insensitive substring match of the query against
   * memory content (no full-text index is used).
   *
   * Every hit gets a constant score of 1.0; results are ordered newest
   * first. The same session/user/type/time filters as semantic search apply.
   */
  private async keywordSearch(params: ZepSearchParams): Promise<ZepSearchResult[]> {
    const filters = this.buildMemoryFilters(params);
    const limit = this.normalizeLimit(params.limit) ?? ZepRetrieval.DEFAULT_QUERY_LIMIT;
    const queryParams: Record<string, unknown> = {
      query: params.query,
      ...filters.values,
    };

    let query = `
      MATCH (m:ZepMemory)
      WHERE toLower(m.content) CONTAINS toLower($query)
    `;
    if (filters.conditions.length > 0) {
      query += ' AND ' + filters.conditions.join(' AND ');
    }
    query += `
      WITH m, 1.0 as score
      OPTIONAL MATCH (m)-[:HAS_FACT]->(f:ZepFact)
      RETURN m, collect(f) as facts, score
      ORDER BY m.createdAt DESC
      LIMIT ${limit}
    `;

    const rows = await this.driver.executeQuery<any>(query, queryParams);
    return rows.map((record: any) => ({
      memory: this.parseMemoryFromNode(record.m, record.facts),
      score: record.score,
    }));
  }

  /**
   * Hybrid search: runs semantic and keyword search concurrently and merges
   * the two rankings with Reciprocal Rank Fusion.
   */
  private async hybridSearch(params: ZepSearchParams): Promise<ZepSearchResult[]> {
    const [semanticResults, keywordResults] = await Promise.all([
      this.semanticSearch(params),
      this.keywordSearch(params),
    ]);

    return this.reciprocalRankFusion(semanticResults, keywordResults);
  }

  /**
   * Maximal Marginal Relevance search.
   *
   * Fetches three times the requested number of semantic candidates, then
   * greedily picks up to `limit` of them (default 10), balancing relevance
   * against similarity to what was already picked (lambda = 0.5).
   */
  private async mmrSearch(params: ZepSearchParams): Promise<ZepSearchResult[]> {
    const limit = this.normalizeLimit(params.limit) ?? ZepRetrieval.DEFAULT_MMR_LIMIT;

    // Over-fetch so the diversity step has something to choose from.
    const candidates = await this.semanticSearch({ ...params, limit: limit * 3 });

    return this.selectByMMR(candidates, ZepRetrieval.DEFAULT_MMR_LAMBDA, limit);
  }

  /**
   * Greedy MMR selection shared by `mmrSearch` and `applyMMRReranking`.
   *
   * The first candidate is always taken. Each later pick maximizes
   * `lambda * score - (1 - lambda) * maxSimilarityToSelected`; ties go to
   * the earlier candidate. Pairs where either embedding is missing
   * contribute no similarity penalty.
   *
   * @param candidates Results ordered by relevance; not mutated.
   * @param lambda Weight of relevance versus diversity.
   * @param limit Maximum number of results to return.
   */
  private selectByMMR(
    candidates: ZepSearchResult[],
    lambda: number,
    limit: number = candidates.length,
  ): ZepSearchResult[] {
    const [first, ...remaining] = candidates;
    if (first === undefined || limit <= 0) return [];

    const selected: ZepSearchResult[] = [first];

    while (selected.length < limit && remaining.length > 0) {
      // Start at index 0 so a pick is always made: if every remaining score
      // is NaN/undefined nothing compares greater than -Infinity, and without
      // this fallback the loop would never shrink `remaining`.
      let bestIndex = 0;
      let bestScore = -Infinity;

      remaining.forEach((candidate, index) => {
        let maxSimilarity = 0;
        for (const chosen of selected) {
          const candidateEmbedding = candidate.memory.embedding;
          const chosenEmbedding = chosen.memory.embedding;
          if (candidateEmbedding && chosenEmbedding) {
            maxSimilarity = Math.max(
              maxSimilarity,
              this.calculateSimilarity(candidateEmbedding, chosenEmbedding),
            );
          }
        }

        const mmrScore = lambda * candidate.score - (1 - lambda) * maxSimilarity;
        if (mmrScore > bestScore) {
          bestScore = mmrScore;
          bestIndex = index;
        }
      });

      selected.push(...remaining.splice(bestIndex, 1));
    }

    return selected;
  }

  /**
   * Reciprocal Rank Fusion over any number of ranked result lists.
   *
   * Each appearance of a memory at 0-based rank `r` contributes
   * `1 / (60 + r + 1)`. The returned entry for a memory is its first
   * occurrence across the lists, with `score` replaced by the fused value.
   */
  private reciprocalRankFusion(...resultSets: ZepSearchResult[][]): ZepSearchResult[] {
    const k = 60; // RRF constant
    const fusedScores = new Map<string, number>();
    const firstSeen = new Map<string, ZepSearchResult>();

    for (const results of resultSets) {
      results.forEach((result, rank) => {
        const uuid = result.memory.uuid;
        const contribution = 1 / (k + rank + 1);
        const previous = fusedScores.get(uuid);

        if (previous === undefined) {
          fusedScores.set(uuid, contribution);
          firstSeen.set(uuid, result);
        } else {
          fusedScores.set(uuid, previous + contribution);
        }
      });
    }

    const fused: ZepSearchResult[] = [];
    for (const [uuid, score] of fusedScores) {
      const result = firstSeen.get(uuid);
      if (result !== undefined) {
        fused.push({ ...result, score });
      }
    }

    return fused.sort((a, b) => b.score - a.score);
  }

  /**
   * Rerank results with the strategy named in `config` (default: RRF, which
   * is a pass-through because fusion already happened in hybrid search).
   *
   * @param results Results to rerank.
   * @param query Original query text (currently unused by every strategy).
   * @param config Optional strategy and tuning; `diversityLambda` defaults to 0.5.
   */
  private async rerankResults(
    results: ZepSearchResult[],
    query: string,
    config?: RerankingConfig,
  ): Promise<ZepSearchResult[]> {
    const strategy = config?.strategy || RerankingStrategy.RECIPROCAL_RANK_FUSION;

    switch (strategy) {
      case RerankingStrategy.RECIPROCAL_RANK_FUSION:
        // Already applied in hybrid search
        return results;
      case RerankingStrategy.MAXIMAL_MARGINAL_RELEVANCE:
        // `??` rather than `||` so an explicit lambda of 0 is respected.
        return this.applyMMRReranking(
          results,
          config?.diversityLambda ?? ZepRetrieval.DEFAULT_MMR_LAMBDA,
        );
      case RerankingStrategy.GRAPH_BASED:
        return this.graphBasedReranking(results);
      case RerankingStrategy.CROSS_ENCODER:
        // Would require additional model for cross-encoding
        console.warn('Cross-encoder reranking not implemented yet');
        return results;
      default:
        return results;
    }
  }

  /**
   * Reorder existing results by MMR, keeping all of them.
   *
   * @param results Results ordered by relevance.
   * @param lambda Weight of relevance versus diversity.
   * @returns The input itself when it has at most one entry, otherwise a
   *   new array in MMR order.
   */
  private applyMMRReranking(results: ZepSearchResult[], lambda: number): ZepSearchResult[] {
    if (results.length <= 1) return results;
    return this.selectByMMR(results, lambda, results.length);
  }

  /**
   * Boost each result's score by its graph connectivity and re-sort.
   *
   * The multiplier is `1 + ln(1 + connections) * 0.1`. The input array and
   * its entries are left untouched; boosted copies are returned. Connection
   * counts are fetched concurrently.
   */
  private async graphBasedReranking(results: ZepSearchResult[]): Promise<ZepSearchResult[]> {
    const boosted = await Promise.all(
      results.map(async (result) => {
        const connectionCount = await this.getConnectionCount(result.memory.uuid);
        const boost = 1 + Math.log(1 + connectionCount) * 0.1;
        return { ...result, score: result.score * boost };
      }),
    );

    return boosted.sort((a, b) => b.score - a.score);
  }

  /**
   * Count relationships (either direction) attached to a memory node.
   *
   * @param memoryId UUID of the `ZepMemory` node.
   * @returns The count, or 0 when no row or no usable value comes back.
   */
  private async getConnectionCount(memoryId: string): Promise<number> {
    const query = `
      MATCH (m:ZepMemory {uuid: $memoryId})
      OPTIONAL MATCH (m)-[r]-()
      RETURN count(r) as connections
    `;
    const rows = await this.driver.executeQuery<any>(query, { memoryId });

    // NOTE(review): coerced defensively in case the driver returns counts as
    // bigint or an integer wrapper rather than a plain number — confirm
    // against the GraphDriver implementation.
    const count = Number(rows[0]?.connections ?? 0);
    return Number.isFinite(count) && count > 0 ? count : 0;
  }

  /**
   * Cosine similarity between two embeddings.
   *
   * @returns 0 when the lengths differ or either vector has zero norm.
   */
  private calculateSimilarity(embedding1: number[], embedding2: number[]): number {
    if (embedding1.length !== embedding2.length) return 0;

    let dotProduct = 0;
    let squaredNorm1 = 0;
    let squaredNorm2 = 0;
    for (let i = 0; i < embedding1.length; i++) {
      dotProduct += embedding1[i] * embedding2[i];
      squaredNorm1 += embedding1[i] * embedding1[i];
      squaredNorm2 += embedding2[i] * embedding2[i];
    }

    const norm1 = Math.sqrt(squaredNorm1);
    const norm2 = Math.sqrt(squaredNorm2);
    if (norm1 === 0 || norm2 === 0) return 0;

    return dotProduct / (norm1 * norm2);
  }

  /**
   * Convert a database node (and its collected fact nodes) into a `ZepMemory`.
   *
   * `metadata` is stored as a JSON string and parsed here; a missing value
   * becomes `{}`. Each fact's `sourceMemoryIds` is set to this memory's UUID.
   *
   * @param node Memory node; fields are read from `node.properties`.
   * @param facts Fact nodes linked to the memory (may be empty or nullish).
   * @throws SyntaxError if a stored `metadata` string is not valid JSON.
   */
  private parseMemoryFromNode(node: any, facts: any[]): ZepMemory {
    const props = node.properties;
    const optionalDate = (value: any): Date | undefined => (value ? new Date(value) : undefined);

    return {
      uuid: props.uuid,
      sessionId: props.sessionId,
      userId: props.userId,
      content: props.content,
      memoryType: props.memoryType,
      embedding: props.embedding,
      metadata: JSON.parse(props.metadata || '{}'),
      createdAt: new Date(props.createdAt),
      lastAccessedAt: optionalDate(props.lastAccessedAt),
      accessCount: props.accessCount,
      relevanceScore: props.relevanceScore,
      summary: props.summary,
      validFrom: new Date(props.validFrom),
      validUntil: optionalDate(props.validUntil),
      facts: (facts ?? []).map((f) => ({
        uuid: f.properties.uuid,
        subject: f.properties.subject,
        predicate: f.properties.predicate,
        object: f.properties.object,
        confidence: f.properties.confidence,
        sourceMemoryIds: [props.uuid],
        validFrom: new Date(f.properties.validFrom),
        validUntil: optionalDate(f.properties.validUntil),
        metadata: JSON.parse(f.properties.metadata || '{}'),
      })),
    };
  }
}