UNPKG

seekmix

Version:

🔍 A local semantic caching library for Node.js.

370 lines (322 loc) 12.2 kB
const { createClient, SchemaFieldTypes, VectorAlgorithms } = require('redis'); const axios = require('axios'); const { pipeline } = require('@huggingface/transformers'); const log = require('lemonlog')('SeekMix'); class BaseEmbeddingProvider { constructor({model, dimensions} = {}) { this.model = model; this.dimensions = dimensions; } async getEmbeddings(text) { throw new Error('The getEmbeddings method must be implemented by derived classes'); } float32Buffer(arr) { return Buffer.from(new Float32Array(arr).buffer); } } class OpenAIEmbeddingProvider extends BaseEmbeddingProvider { constructor({ model = 'text-embedding-ada-002', dimensions = 1536, apiKey = process.env.OPENAI_API_KEY } = {}) { super({ model, dimensions }); this.openaiClient = axios.create({ baseURL: 'https://api.openai.com/v1', headers: { 'Authorization': `Bearer ${apiKey}`, 'Content-Type': 'application/json' } }); } async getEmbeddings(text) { try { const response = await this.openaiClient.post('/embeddings', { model: this.model, input: text, encoding_format: 'float' }); return response.data.data[0].embedding; } catch (error) { log.error('Error generating embeddings with OpenAI:', error); throw error; } } } class OpenAIEmbedding3Provider extends OpenAIEmbeddingProvider { constructor({ model = 'text-embedding-3-small', dimensions = 1536 } = {}) { super({ model, dimensions }); } } class OpenAIEmbedding3LargeProvider extends OpenAIEmbeddingProvider { constructor({ model = 'text-embedding-3-large', dimensions = 3072 } = {}) { super({ model, dimensions }); } } // Clase para la generación de embeddings con Hugging Face Transformers.js class HuggingfaceProvider extends BaseEmbeddingProvider { constructor({ model = 'Xenova/multilingual-e5-large', dimensions = 1024, dtype = 'q8', pipelineOptions = {} } = {}) { super({ model, dimensions }); this.dtype = dtype; this.pipelineOptions = pipelineOptions; this.extractor = null; this.isInitialized = false; } async initialize() { if (!this.isInitialized) { try { const options = { dtype: this.dtype, ...this.pipelineOptions }; log.info('Initializing Hugging Face pipeline (first initialization may take longer while downloading the model)...'); this.extractor = await pipeline('feature-extraction', this.model, options); this.dimensions = this.extractor.model.config.hidden_size; log.info(`Hugging Face pipeline initialized with model: ${this.model}`); this.isInitialized = true; } catch (error) { log.error(`Error initializing Hugging Face pipeline with model ${this.model}:`, error); throw error; } } } // Obtener embeddings usando Hugging Face Transformers.js async getEmbeddings(text) { try { await this.initialize(); if (!this.extractor) { throw new Error('Hugging Face pipeline not initialized.'); } const output = await this.extractor(text, { pooling: 'mean', normalize: true }); const embeddingsList = output.tolist(); let embedding = null; if (embeddingsList && embeddingsList.length > 0) { if (Array.isArray(embeddingsList[0]) && typeof embeddingsList[0][0] === 'number') { embedding = embeddingsList[0]; } else if (typeof embeddingsList[0] === 'number') { embedding = embeddingsList; } } if (!embedding) { log.error('Unexpected embedding output structure:', embeddingsList); throw new Error('Failed to extract embedding from Hugging Face pipeline output.'); } return embedding; } catch (error) { log.error('Error generating embeddings with Hugging Face:', error); throw error; } } } class SeekMix { constructor({ redisUrl = 'redis://localhost:6379', indexName = 'seekmix:idx', keyPrefix = 'seekmix:', ttl = -1, similarityThreshold = 0.87, dropIndex = false, dropKeys = false, embeddingProvider = null } = {}) { // Crear provider de embeddings si no se proporciona uno this.embeddingProvider = embeddingProvider || new HuggingfaceProvider(); this.options = { redisUrl, indexName, keyPrefix, ttl, similarityThreshold, dropIndex, dropKeys }; // Inicializar el cliente Redis this.redisClient = createClient({ url: this.options.redisUrl, }); } // Conectar al cliente Redis y configurar el índice de vectores async connect() { try { await this.redisClient.connect(); // Initialize HuggingfaceProvider if applicable if (this.embeddingProvider instanceof HuggingfaceProvider) { await this.embeddingProvider.initialize(); } this.options.indexName = this.options.indexName + ':' + this.embeddingProvider.model; this.options.keyPrefix = this.options.keyPrefix + this.embeddingProvider.model + ':'; // Eliminar índice existente si existe if (this.options.dropIndex) { try { await this.redisClient.ft.dropIndex(this.options.indexName); log.info(`Index ${this.options.indexName} deleted`); } catch (error) { if (!error.message.includes('Unknown Index name')) { throw error; } } } // Eliminar todas las claves del prefijo si se solicita if (this.options.dropKeys) { this.dropKeys(); } const indices = await this.redisClient.ft._LIST(); if (!indices.includes(this.options.indexName)) { // Crear un índice vectorial en Redis para búsqueda semántica await this.redisClient.ft.create( this.options.indexName, { '$.vector': { type: SchemaFieldTypes.VECTOR, AS: 'vector', ALGORITHM: VectorAlgorithms.HNSW, TYPE: 'FLOAT32', DIM: this.embeddingProvider.dimensions, DISTANCE_METRIC: 'COSINE' }, '$.text': { type: SchemaFieldTypes.TEXT, AS: 'text', SORTABLE: true }, '$.timestamp': { type: SchemaFieldTypes.NUMERIC, AS: 'timestamp', SORTABLE: true } }, { ON: 'JSON', PREFIX: this.options.keyPrefix } ); log.info(`Index ${this.options.indexName} created successfully`); } else { log.info(`Using existing index: ${this.options.indexName}`); } return true; } catch (error) { log.error('Error connecting to Redis or configuring index:', error); throw error; } } async dropKeys() { try { let cursor = 0; do { const scanResult = await this.redisClient.scan(cursor, { MATCH: `${this.options.keyPrefix}*`, COUNT: 1000 }); cursor = scanResult.cursor; if (scanResult.keys.length > 0) { await this.redisClient.del(scanResult.keys); log.info(`Deleted ${scanResult.keys.length} keys with prefix ${this.options.keyPrefix}`); } } while (cursor !== 0); } catch (error) { log.error('Error deleting keys:', error); } } async disconnect() { return this.redisClient.disconnect(); } async set(query, result) { try { const vector = await this.embeddingProvider.getEmbeddings(query); const timestamp = Date.now(); const key = `${this.options.keyPrefix}${this._generateKey(query)}`; await this.redisClient.json.set(key, '$', { query, result, vector, timestamp, text: query }); // Establecer TTL solo si no es -1 (sin caducidad) if (this.options.ttl !== -1) { await this.redisClient.expire(key, this.options.ttl); } return true; } catch (error) { log.error('Error saving to cache:', error); throw error; } } async get(query) { try { const vector = await this.embeddingProvider.getEmbeddings(query); // Crear un buffer para el vector const queryBuffer = this.embeddingProvider.float32Buffer(vector); // Buscar vector similar usando la sintaxis correcta de KNN const results = await this.redisClient.ft.search( this.options.indexName, '*=>[KNN 1 @vector $BLOB AS score]', { PARAMS: { BLOB: queryBuffer }, SORTBY: 'score', DIALECT: 2, RETURN: ['$.query', '$.result', '$.timestamp', 'score'], } ); if (results.total > 0 && results.documents[0].value.score <= (1 - this.options.similarityThreshold)) { return { query: results.documents[0].value['$.query'], result: results.documents[0].value['$.result'], timestamp: results.documents[0].value['$.timestamp'], score: results.documents[0].value['score'], }; } return null; } catch (error) { log.error('Error searching in cache:', error); return null; } } async invalidateOld(maxAgeInSeconds) { try { const cutoffTime = Date.now() - (maxAgeInSeconds * 1000); // Buscar entradas más antiguas que el tiempo de corte const results = await this.redisClient.ft.search( this.options.indexName, `@timestamp:[0 ${cutoffTime}]`, { LIMIT: { from: 0, size: 1000, }, } ); // Eliminar entradas antiguas const deletePromises = results.documents.map(doc => { return this.redisClient.del(doc.id); }); await Promise.all(deletePromises); return deletePromises.length; } catch (error) { log.error('Error invalidating old cache:', error); throw error; } } _generateKey(text) { return Buffer.from(text).toString('base64').substring(0, 32); } } module.exports = { SeekMix, HuggingfaceProvider, BaseEmbeddingProvider, OpenAIEmbeddingProvider, OpenAIEmbedding3Provider, OpenAIEmbedding3LargeProvider };