UNPKG

@forge-ml/rag

Version:

A RAG (Retrieval-Augmented Generation) package for Forge ML

102 lines (101 loc) 3.75 kB
import { createClient, SchemaFieldTypes, VectorAlgorithms, } from "redis"; import { VECTOR_MODEL_DIM } from "../types"; const INDEX_KEY = "idx:chunks"; const CHUNK_KEY_PREFIX = `chunks`; //@TODO pass in the user's embedding model and adjust the DIM accordingly const GenericIndex = (dim) => ({ "$.chunkEmbeddings": { type: SchemaFieldTypes.VECTOR, TYPE: "FLOAT32", ALGORITHM: VectorAlgorithms.FLAT, DIM: dim, // this needs to be set to the dimesension set by the embedding model, 3072 for text-embedding-3-large or 1536 for text-embedding-3-small, 768 for nomic v1.5 embedder DISTANCE_METRIC: "L2", AS: "chunkEmbeddings", }, "$.chunkId": { type: SchemaFieldTypes.TEXT, NOSTEM: true, SORTABLE: true, AS: "chunkId", }, "$.documentId": { type: SchemaFieldTypes.TEXT, NOSTEM: true, SORTABLE: true, AS: "documentId", }, }); const defaultCreateIndexOpts = { dim: VECTOR_MODEL_DIM.NOMIC_V1_5, }; class RedisVectorStore { client; constructor(url) { this.client = createClient({ url }); this.client.connect().catch(console.error); } async createIndex(opts = defaultCreateIndexOpts) { try { await this.client.ft.dropIndex(INDEX_KEY).catch(() => { }); } catch (indexErr) { console.error(indexErr); } await this.client.ft.create(INDEX_KEY, GenericIndex(opts.dim), { ON: "JSON", PREFIX: CHUNK_KEY_PREFIX, }); } async addEmbedding(embedding) { return this.client.json.set(`${CHUNK_KEY_PREFIX}:${embedding.chunkId}`, "$", { chunkId: embedding.chunkId.replace(/[^a-zA-Z0-9]/g, "."), documentId: embedding.documentId.replace(/[^a-zA-Z0-9]/g, "."), chunkEmbeddings: embedding.embedding, }); } async storeEmbeddings(embeddings) { await Promise.all(embeddings.map((embedding) => this.addEmbedding(embedding))); } async queryEmbeddings(params) { const results = await this.knnSearchEmbeddings({ inputVector: params.query, k: params.k, documentIds: params.documentIds, }); return results.documents.map((doc) => ({ chunkId: doc.value.chunkId?.toString().replace(/[^a-zA-Z0-9]/g, "-") || "", documentId: doc.value.documentId?.toString().replace(/[^a-zA-Z0-9]/g, "-") || "", score: doc.value.score, })); } async knnSearchEmbeddings({ inputVector, k, documentIds, }) { try { let filter = "*"; if (documentIds && documentIds.length > 0) { const escapedIds = documentIds.map(id => id.replace(/[^a-zA-Z0-9]/g, ".")); filter = `@documentId:(${escapedIds.join('|')})`; } const query = `${filter}=>[KNN ${k} @chunkEmbeddings $searchBlob AS score]`; const searchParams = { PARAMS: { searchBlob: Buffer.from(new Float32Array(inputVector).buffer), }, RETURN: ["score", "chunkId", "documentId"], SORTBY: { BY: "score", }, DIALECT: 2, }; const results = await this.client.ft.search(INDEX_KEY, query, searchParams); if (!results || !results.documents) { throw new Error('No results returned from Redis search'); } return results; } catch (error) { console.error('Error in knnSearchEmbeddings:', error); throw error; } } } export default RedisVectorStore;