UNPKG

@n2flowjs/nbase

Version:

Neural Vector Database for efficient similarity search

232 lines (203 loc) 8.74 kB
import { RerankingOptions, SearchResult } from '../types';
import * as distanceMetrics from '../utils/distance_metrics'; // Import distance functions

/** Default MMR trade-off between relevance (1) and diversity (0); leans towards relevance. */
const DEFAULT_MMR_LAMBDA = 0.7;

/**
 * A utility class for reranking search results using various strategies.
 *
 * The SearchReranker provides algorithms to refine initial search results
 * beyond simple distance/similarity sorting. This can improve result relevance
 * and user experience by considering factors like diversity or weighted attributes.
 *
 * Supports three reranking strategies:
 * - `standard`: Preserves the original ranking, optionally limiting to top k results
 * - `diversity`: Maximal Marginal Relevance (MMR) to balance relevance and diversity
 * - `weighted`: Adjusts ranking based on weighted numeric metadata attributes
 *
 * Unrecognised methods fall back to `standard`. Every strategy returns a new array;
 * the input array is never reordered in place.
 *
 * @example
 * ```typescript
 * const reranker = new SearchReranker();
 *
 * // Standard reranking (limit to top 5)
 * const topResults = reranker.rerank(initialResults, { method: 'standard', k: 5 });
 *
 * // Diversity reranking
 * const diverseResults = reranker.rerank(initialResults, {
 *   method: 'diversity',
 *   queryVector: query,
 *   vectorsMap: vectors,
 *   lambda: 0.7
 * });
 *
 * // Weighted reranking based on metadata
 * const weightedResults = reranker.rerank(initialResults, {
 *   method: 'weighted',
 *   metadataMap: metadata,
 *   weights: { recency: 0.3, popularity: 0.5 }
 * });
 * ```
 */
export class SearchReranker {
  /**
   * Rerank search results using the specified method.
   * This is the main public entry point for reranking.
   *
   * @param results The initial list of search results, typically sorted by distance/similarity.
   * @param options Configuration for the reranking process, including the method to use
   *   (defaults to `standard`) and an optional result limit `k`.
   * @returns A new list of reranked search results; `[]` if `results` is not an array.
   */
  public rerank(results: SearchResult[], options: RerankingOptions = {}): SearchResult[] {
    const { method = 'standard' } = options;

    // Guard against callers passing something other than an array at runtime.
    if (!Array.isArray(results)) {
      console.error('Reranker received invalid input: results is not an array.');
      return [];
    }

    switch (method) {
      case 'diversity':
        return this._diversityReranking(results, options);
      case 'weighted':
        return this._weightedReranking(results, options);
      case 'standard':
        return this._standardReranking(results, options);
      default:
        console.warn(`Unknown reranking method "${String(method)}"; falling back to standard reranking.`);
        return this._standardReranking(results, options);
    }
  }

  /**
   * Normalise a requested result count to an integer in `[0, total]`.
   *
   * `undefined` means "keep everything"; NaN, negative values and `-Infinity` mean "keep nothing";
   * fractional values are truncated (as `Array.prototype.slice` does). Using one rule for every
   * strategy keeps `k` semantics consistent between slicing and the iterative MMR loop.
   */
  private _resolveK(k: number | undefined, total: number): number {
    if (k === undefined) return total;
    if (!Number.isFinite(k)) return k > 0 ? total : 0;
    return Math.min(total, Math.max(0, Math.trunc(k)));
  }

  /**
   * Basic reranking: returns the results in their original order, capped at `k`.
   * Does not change the order based on content or metadata.
   */
  private _standardReranking(results: SearchResult[], options: RerankingOptions): SearchResult[] {
    return results.slice(0, this._resolveK(options.k, results.length));
  }

  /**
   * Diversity-based reranking using Maximal Marginal Relevance (MMR).
   *
   * Each step picks the candidate maximising
   *   `lambda * sim(query, d) - (1 - lambda) * max over selected s of sim(d, s)`,
   * where `sim(x) = 1 / (1 + distance)` is used for BOTH terms so they share the range (0, 1]
   * and `lambda` is a meaningful trade-off (1 = pure relevance, 0 = pure diversity).
   * Relevance uses each result's existing `dist`; inter-result distances use `distanceMetric`.
   *
   * Requires `queryVector` and a non-empty `vectorsMap`; otherwise the top `k` results are
   * returned unchanged. Results whose vector is missing from `vectorsMap` are left out of the output.
   */
  private _diversityReranking(initialResults: SearchResult[], options: RerankingOptions): SearchResult[] {
    const {
      queryVector,
      lambda: requestedLambda = DEFAULT_MMR_LAMBDA,
      vectorsMap,
      distanceMetric = 'euclidean', // Default distance metric
    } = options;
    const k = this._resolveK(options.k, initialResults.length);

    // Without this guard the first MMR pick would be returned even when no results were requested.
    if (k === 0) return [];

    // --- Input Validation ---
    if (!queryVector || !vectorsMap || vectorsMap.size === 0 || initialResults.length <= 1) {
      console.warn('Diversity reranking skipped: Missing queryVector, vectorsMap, or insufficient results.');
      return initialResults.slice(0, k); // Return original top K
    }

    let lambda = requestedLambda;
    if (!Number.isFinite(lambda)) {
      console.warn(`Invalid MMR lambda (${lambda}); using default ${DEFAULT_MMR_LAMBDA}.`);
      lambda = DEFAULT_MMR_LAMBDA;
    } else if (lambda < 0 || lambda > 1) {
      console.warn(`MMR lambda ${lambda} is outside [0, 1]; clamping.`);
      lambda = Math.min(1, Math.max(0, lambda));
    }

    // --- Setup ---
    const distanceFunc = distanceMetrics.getDistanceFunction(distanceMetric);

    // `nearestSelected` caches the distance to the closest already-selected result so it can be
    // updated incrementally (one distance per candidate per pick) instead of recomputed every round.
    type PoolEntry = { result: SearchResult; vector: Float32Array; nearestSelected: number };
    const pool = new Map<number | string, PoolEntry>();

    for (const res of initialResults) {
      const vector = vectorsMap.get(res.id);
      if (vector) {
        pool.set(res.id, { result: res, vector, nearestSelected: Infinity });
      } else {
        console.warn(`Vector for result ID ${res.id} not found in vectorsMap. Skipping for diversity rerank.`);
      }
    }

    if (pool.size === 0) {
      console.warn('No results with available vectors for diversity reranking.');
      return initialResults.slice(0, k);
    }

    const reranked: SearchResult[] = [];

    // Move an entry into the output and refresh every remaining candidate's nearest-selected distance.
    const select = (chosen: PoolEntry): void => {
      reranked.push(chosen.result);
      pool.delete(chosen.result.id);
      if (reranked.length >= k) return; // No further picks, so the refresh would be wasted work.
      for (const other of pool.values()) {
        other.nearestSelected = Math.min(other.nearestSelected, distanceFunc(other.vector, chosen.vector));
      }
    };

    // --- MMR Algorithm ---
    // 1. Seed with the most relevant (smallest-distance) result; ties keep the earliest one.
    let first: PoolEntry | undefined;
    for (const entry of pool.values()) {
      if (first === undefined || entry.result.dist < first.result.dist) {
        first = entry;
      }
    }
    if (!first) {
      console.error('Could not determine the first result for MMR.');
      return initialResults.slice(0, k);
    }
    select(first);

    // 2. Iteratively pick the candidate with the best relevance/redundancy trade-off.
    while (reranked.length < k && pool.size > 0) {
      let best: PoolEntry | undefined;
      let bestScore = -Infinity;

      for (const entry of pool.values()) {
        const relevance = 1 / (1 + entry.result.dist);
        const redundancy = 1 / (1 + entry.nearestSelected); // similarity to the closest selected result
        const score = lambda * relevance - (1 - lambda) * redundancy;
        if (score > bestScore) {
          bestScore = score;
          best = entry;
        }
      }

      if (!best) {
        // Only reachable when every score is NaN (e.g. NaN distances).
        console.warn('MMR iteration finished without selecting a candidate.');
        break;
      }
      select(best);
    }

    return reranked;
  }

  /**
   * Weighted reranking based on numeric metadata attributes.
   *
   * Each result's score starts at its `dist` and, for every `weights` entry, is reduced by
   * `metadata[key] * weight`; results are then sorted ascending by that score and capped at `k`.
   * Non-numeric or non-finite metadata values and non-finite weights are ignored so they cannot
   * turn the score into NaN and break the sort. Requires `metadataMap` and non-empty `weights`;
   * otherwise the top `k` results are returned unchanged.
   *
   * @returns Copies of the results, each with an added `weightedScore` field.
   */
  private _weightedReranking(results: SearchResult[], options: RerankingOptions): SearchResult[] {
    const { weights = {}, metadataMap } = options;
    const k = this._resolveK(options.k, results.length);

    if (!metadataMap || metadataMap.size === 0 || Object.keys(weights).length === 0) {
      console.warn('Weighted reranking skipped: Missing metadataMap or weights.');
      return results.slice(0, k); // Apply k limit even if not reranking
    }

    const weightedResults = results.map((result) => {
      const itemMetadata = metadataMap.get(result.id) || {};
      let weightedScore = result.dist; // Start with original distance

      for (const [key, weight] of Object.entries(weights)) {
        const value: unknown = itemMetadata[key];
        if (typeof value === 'number' && Number.isFinite(value) && Number.isFinite(weight)) {
          weightedScore -= value * weight;
        }
      }

      return { ...result, weightedScore };
    });

    // `weightedResults` is a fresh array, so sorting it in place does not touch the caller's input.
    return weightedResults.sort((a, b) => a.weightedScore - b.weightedScore).slice(0, k);
  }
}

export default SearchReranker;