@caleblawson/rag
Version:
The Retrieval-Augmented Generation (RAG) module contains document processing and embedding utilities.
159 lines (133 loc) • 4.87 kB
text/typescript
import type { MastraLanguageModel } from '@mastra/core/agent';
import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from '@mastra/core/relevance';
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
import type { QueryResult } from '@mastra/core/vector';
import { Big } from 'big.js';
// Default weights for the three scoring components combined by `rerank`.
// They must add up to 1: `rerank` throws when the effective weights (these
// defaults overlaid with any caller-supplied overrides) do not sum to exactly 1.
const DEFAULT_WEIGHTS = {
// Weight of the relevance score returned by the semantic provider.
semantic: 0.4,
// Weight of the score that came with the query result.
vector: 0.4,
// Weight of the score derived from the result's position in the input list.
position: 0.2,
} as const;
// Optional per-component weight overrides accepted by the reranker options.
// Components that are omitted fall back to the matching DEFAULT_WEIGHTS entry
// inside `rerank`.
type WeightConfig = {
semantic?: number;
vector?: number;
position?: number;
};
// Per-component breakdown of how a single result was scored by `rerank`.
interface ScoringDetails {
// Relevance score returned by the semantic provider for the query and the
// result's `metadata.text`; 0 when the result has no such text.
semantic: number;
// The `score` value carried by the incoming query result.
vector: number;
// Score derived from the result's index in the input list
// (see calculatePositionScore).
position: number;
// Present only when a query embedding was supplied to `rerank`.
queryAnalysis?: {
// Square root of the sum of squared embedding components.
magnitude: number;
// Indices of up to five embedding components with the largest absolute value.
dominantFeatures: number[];
};
}
// One entry of the list returned by `rerank`.
export interface RerankResult {
// The query result this entry was computed from, passed through as-is.
result: QueryResult;
// Weighted combination of the semantic, vector and position scores, further
// scaled by adjustScores when a query embedding was supplied.
score: number;
// The individual component scores that produced `score`.
details: ScoringDetails;
}
// For use in the vector store tool
// Same fields as RerankerFunctionOptions except that no query embedding can
// be passed.
export interface RerankerOptions {
// Optional per-component weight overrides.
weights?: WeightConfig;
// Number of results to keep.
// NOTE(review): the default is applied by whoever consumes these options and
// is not visible in this file — confirm against the vector store tool.
topK?: number;
}
// For use in the rerank function
export interface RerankerFunctionOptions {
// Optional per-component weight overrides; the effective weights must sum to 1.
weights?: WeightConfig;
// Embedding of the query. When given, `rerank` analyzes it (magnitude and
// dominant components), scales the final scores accordingly and reports the
// analysis in each result's details.
queryEmbedding?: number[];
// Maximum number of results returned by `rerank`; defaults to 3 there.
topK?: number;
}
// Bundles a language model with optional reranker options.
// NOTE(review): not referenced elsewhere in this file; presumably consumed by
// the vector store tool — confirm against callers.
export interface RerankConfig {
// Optional weights / topK settings.
options?: RerankerOptions;
// Language model associated with this reranking configuration.
model: MastraLanguageModel;
}
// Converts a zero-based rank in the original result list into a score:
// the first entry (position 0) scores 1 and each later entry scores
// proportionally less, decreasing by 1 / totalChunks per step.
function calculatePositionScore(position: number, totalChunks: number): number {
  // Fraction of the list that precedes this entry.
  const fractionAhead = position / totalChunks;
  return 1 - fractionAhead;
}
// Summarizes a query embedding: returns its Euclidean magnitude together with
// the indices of the (at most five) components that have the largest absolute
// value, strongest first.
function analyzeQueryEmbedding(embedding: number[]): {
  magnitude: number;
  dominantFeatures: number[];
} {
  // Accumulate the squared components left to right, then take the root.
  let sumOfSquares = 0;
  embedding.forEach(component => {
    sumOfSquares += component * component;
  });

  // Pair every dimension with its absolute strength and order strongest first.
  const rankedByStrength = embedding
    .map((component, dimension): [number, number] => [dimension, Math.abs(component)])
    .sort(([, strengthA], [, strengthB]) => strengthB - strengthA);

  // Keep only the indices of the five strongest dimensions.
  const dominantFeatures = rankedByStrength.slice(0, 5).map(([dimension]) => dimension);

  return { magnitude: Math.sqrt(sumOfSquares), dominantFeatures };
}
// Scales a combined score according to the query embedding's magnitude:
// a magnitude above 10 multiplies the score by 1.1, and a magnitude above 5
// multiplies it by 1.05 (both apply when the magnitude exceeds 10).
// Only `magnitude` is consulted; `dominantFeatures` is not read here.
function adjustScores(score: number, queryAnalysis: { magnitude: number; dominantFeatures: number[] }): number {
  const { magnitude } = queryAnalysis;
  let adjusted = score;
  if (magnitude > 10) {
    adjusted *= 1.1;
  }
  if (magnitude > 5) {
    adjusted *= 1.05;
  }
  return adjusted;
}
/**
 * Reranks vector-store query results by blending three signals per result:
 * a semantic relevance score from a model-backed provider, the score already
 * attached to the result, and a score derived from the result's position in
 * the incoming list.
 *
 * The semantic provider is a `CohereRelevanceScorer` when `model.modelId` is
 * `'rerank-v3.5'`, otherwise a `MastraAgentRelevanceScorer` built from the model.
 * Results without `metadata.text` get a semantic score of 0 and are not sent
 * to the provider.
 *
 * @param results - Query results to rerank, in their original order.
 * @param query - The query text each result is scored against.
 * @param model - Model used to pick and construct the semantic provider.
 * @param options - Optional weights, query embedding and `topK` (default 3).
 *   Weights that are omitted or `undefined` fall back to `DEFAULT_WEIGHTS`.
 * @returns The `topK` highest-scoring results, best first, each with its
 *   per-component score breakdown.
 * @throws Error when the effective weights do not add up to exactly 1.
 */
export async function rerank(
  results: QueryResult[],
  query: string,
  model: MastraLanguageModel,
  options: RerankerFunctionOptions = {},
): Promise<RerankResult[]> {
  let semanticProvider: RelevanceScoreProvider;
  if (model.modelId === 'rerank-v3.5') {
    semanticProvider = new CohereRelevanceScorer(model.modelId);
  } else {
    semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
  }

  const { queryEmbedding, topK = 3 } = options;

  // Resolve each weight individually so that an explicit `undefined` override
  // falls back to the default instead of clobbering it.
  const weights = {
    semantic: options.weights?.semantic ?? DEFAULT_WEIGHTS.semantic,
    vector: options.weights?.vector ?? DEFAULT_WEIGHTS.vector,
    position: options.weights?.position ?? DEFAULT_WEIGHTS.position,
  };

  // Weights must add up to 1. Summing with Big avoids binary floating-point
  // drift (e.g. 0.4 + 0.4 + 0.2) causing a spurious mismatch.
  const sum = Object.values(weights).reduce((acc: Big, w: number) => acc.plus(w.toString()), new Big(0));
  if (!sum.eq(1)) {
    // Serialize the weights explicitly; interpolating the object directly
    // would render as "[object Object]".
    throw new Error(`Weights must add up to 1. Got ${sum.toString()} from ${JSON.stringify(weights)}`);
  }

  const resultLength = results.length;
  const queryAnalysis = queryEmbedding ? analyzeQueryEmbedding(queryEmbedding) : null;

  // Score every result concurrently; the provider call is the only async step.
  const scoredResults = await Promise.all(
    results.map(async (result, index): Promise<RerankResult> => {
      // Semantic score from the chosen provider (0 when there is no text to score).
      const text = result?.metadata?.text;
      const semanticScore = text ? await semanticProvider.getRelevanceScore(query, text) : 0;

      // Score that came with the result from the vector store.
      const vectorScore = result.score;

      // Score based on where the result sat in the original list.
      const positionScore = calculatePositionScore(index, resultLength);

      // Combine the components using their weights.
      let finalScore =
        weights.semantic * semanticScore + weights.vector * vectorScore + weights.position * positionScore;

      if (queryAnalysis) {
        finalScore = adjustScores(finalScore, queryAnalysis);
      }

      return {
        result,
        score: finalScore,
        details: {
          semantic: semanticScore,
          vector: vectorScore,
          position: positionScore,
          ...(queryAnalysis && {
            queryAnalysis: {
              magnitude: queryAnalysis.magnitude,
              dominantFeatures: queryAnalysis.dominantFeatures,
            },
          }),
        },
      };
    }),
  );

  // `scoredResults` is a fresh array, so sorting in place does not touch the
  // caller's `results`. Highest score first, then keep the top K.
  return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
}