UNPKG

@astreus-ai/astreus

Version:

AI Agent Framework with Chat Management

789 lines (689 loc) 29 kB
import { v4 as uuidv4 } from "uuid";
import {
  VectorRAGConfig,
  VectorRAGInstance,
  Document,
  Chunk,
  RAGResult,
  VectorRAGFactory,
  VectorDatabaseType,
  VectorDatabaseConfig,
} from "../types";
import { logger } from "../utils";
import { validateRequiredParam, validateRequiredParams } from "../utils/validation";
import { createVectorDatabaseConnector, VectorDatabaseConnector } from "./vector-db";
import {
  DEFAULT_CHUNK_SIZE,
  DEFAULT_CHUNK_OVERLAP,
  DEFAULT_VECTOR_SIMILARITY_THRESHOLD,
  DEFAULT_MAX_RESULTS,
} from "../constants";

/** Document input accepted by `addDocument`; the ID and embedding are generated internally. */
type NewDocument = Omit<Document, "id" | "embedding">;

/** A chunk produced by `chunkDocument` before it is assigned an ID or embedding. */
type PendingChunk = Omit<Chunk, "id" | "embedding">;

/** Non-optional form of the `vectorDatabase` section of the config. */
type VectorDbConfig = NonNullable<VectorRAGConfig["vectorDatabase"]>;

/** Documents table used when vectors live in the main database. */
const DOCUMENTS_TABLE = "rag_documents";

/** Chunk -> document association table used with external vector databases. */
const ASSOCIATIONS_TABLE = "rag_chunk_associations";

/** Chunk table name used when `config.tableName` is not provided. */
const DEFAULT_CHUNKS_TABLE = "rag_chunks";

/** Chunk text is truncated to this many characters before embedding, to stay under provider token limits. */
const MAX_EMBEDDING_INPUT_CHARS = 8000;

/** Number of chunks embedded and stored concurrently while ingesting a document. */
const EMBEDDING_BATCH_SIZE = 10;

/**
 * Vector-based RAG implementation.
 *
 * Documents are split into overlapping character windows ("chunks"). Each chunk is embedded and
 * stored through a {@link VectorDatabaseConnector}. Two storage layouts are supported:
 *
 * - `VectorDatabaseType.SAME_AS_MAIN` (the default): documents are written to `rag_documents`, and
 *   chunks are written to `config.tableName` in the main database.
 * - Any other vector database type: chunk text and metadata travel with the vectors into the
 *   external store. The main database only keeps `rag_chunk_associations` (chunk ID -> document
 *   ID + chunk index).
 */
export class VectorRAG implements VectorRAGInstance {
  public config: VectorRAGConfig;
  private readonly vectorDatabaseConnector: VectorDatabaseConnector;

  /**
   * Build an instance without touching the database.
   *
   * Prefer {@link VectorRAG.create}, which also creates the required tables. The caller's
   * `config` object (including its nested `vectorDatabase` object) is not mutated.
   *
   * @param config Configuration object. `database` is required; the other fields fall back to
   *   defaults.
   */
  constructor(config: VectorRAGConfig) {
    validateRequiredParam(config, "config", "VectorRAG constructor");
    validateRequiredParams(config, ["database"], "VectorRAG constructor");

    const tableName = config.tableName || DEFAULT_CHUNKS_TABLE;
    const requestedVectorDatabase: VectorDbConfig =
      config.vectorDatabase || { type: VectorDatabaseType.SAME_AS_MAIN };

    // When vectors share the main database, the connector must write to the same chunk table that
    // initializeDatabase() creates. Build a fresh object instead of assigning `options` onto the
    // caller's (shallow-copied) vectorDatabase config.
    const vectorDatabase: VectorDbConfig =
      requestedVectorDatabase.type === VectorDatabaseType.SAME_AS_MAIN
        ? {
            ...requestedVectorDatabase,
            options: { ...requestedVectorDatabase.options, tableName },
          }
        : requestedVectorDatabase;

    this.config = {
      ...config,
      tableName,
      maxResults: config.maxResults || DEFAULT_MAX_RESULTS,
      chunkSize: config.chunkSize || DEFAULT_CHUNK_SIZE,
      // `??` rather than `||`, so that an explicit overlap of 0 is honoured.
      chunkOverlap: config.chunkOverlap ?? DEFAULT_CHUNK_OVERLAP,
      vectorDatabase,
    };

    this.vectorDatabaseConnector = createVectorDatabaseConnector(vectorDatabase, config.database);

    logger.debug("Vector RAG system initialized");
  }

  /**
   * Create a vector RAG instance and make sure its database tables exist.
   * @param config Configuration object for vector RAG (`database` is required)
   * @returns Promise that resolves to the initialized vector RAG instance
   * @throws Re-throws (after logging) any construction or table-creation error
   */
  static async create(config: VectorRAGConfig): Promise<VectorRAGInstance> {
    validateRequiredParam(config, "config", "VectorRAG.create");
    validateRequiredParams(config, ["database"], "VectorRAG.create");

    try {
      const instance = new VectorRAG(config);
      await instance.initializeDatabase();
      return instance;
    } catch (error) {
      logger.error("Error creating vector RAG instance:", error);
      throw error;
    }
  }

  /**
   * Create the tables this instance needs in the main database.
   *
   * With an external vector store, only the association table is created. Otherwise the
   * documents table and the chunk table are created, and the chunk table gets a cascading
   * foreign key to the documents table.
   */
  private async initializeDatabase(): Promise<void> {
    try {
      const { database } = this.config;

      if (!this.usesMainDatabaseForVectors()) {
        // External vector store: the main DB only tracks which chunks belong to which document.
        await database.ensureTable(ASSOCIATIONS_TABLE, (table) => {
          table.string("chunkId").primary();
          table.string("documentId").notNullable().index();
          table.integer("chunkIndex").notNullable();
          table.timestamp("createdAt").defaultTo(database.knex.fn.now());
        });
        logger.info(`Vector RAG using external vector database - created ${ASSOCIATIONS_TABLE} table`);
      } else {
        const chunksTableName = this.chunksTableName();

        await database.ensureTable(DOCUMENTS_TABLE, (table) => {
          table.string("id").primary();
          table.text("content").notNullable();
          table.json("metadata").notNullable();
          table.timestamp("createdAt").defaultTo(database.knex.fn.now());
        });

        await database.ensureTable(chunksTableName, (table) => {
          table.string("id").primary();
          table.string("documentId").notNullable().index();
          table.text("content").notNullable();
          table.json("metadata").notNullable();
          table.json("embedding").notNullable();
          table.timestamp("createdAt").defaultTo(database.knex.fn.now());
          table.foreign("documentId").references("id").inTable(DOCUMENTS_TABLE).onDelete("CASCADE");
        });

        logger.info(`Vector RAG initialized with custom tables: ${DOCUMENTS_TABLE} and ${chunksTableName}`);
      }

      logger.debug("Vector RAG database initialized");
    } catch (error) {
      logger.error("Error initializing vector RAG database:", error);
      throw error;
    }
  }

  /**
   * Add a document to the vector RAG system.
   *
   * The document is chunked, and each chunk is embedded and stored. Without any embedding
   * support, chunks are written with an empty embedding, but only when the main database holds
   * the vectors.
   *
   * @param document The document to add (`content` and `metadata` are required)
   * @returns Promise resolving to the generated document ID
   * @throws If chunking is misconfigured or any storage call fails (logged, then re-thrown)
   */
  async addDocument(document: NewDocument): Promise<string> {
    validateRequiredParam(document, "document", "addDocument");
    validateRequiredParams(document, ["content", "metadata"], "addDocument");

    try {
      const { database } = this.config;
      const documentId = uuidv4();
      const sameDatabase = this.usesMainDatabaseForVectors();

      // Chunk first. It has no side effects and throws on an invalid chunking config, so a
      // failure here leaves no orphaned document row behind.
      const chunks = this.chunkDocument(documentId, document);

      if (sameDatabase) {
        await database.knex(DOCUMENTS_TABLE).insert({
          id: documentId,
          content: document.content,
          metadata: JSON.stringify(document.metadata),
          createdAt: new Date(),
        });
        logger.debug(`Stored document ${documentId} in main database`);
      } else {
        // For an external vector DB, document metadata travels with each chunk.
        logger.debug(`Using external vector DB - document ${documentId} will be stored with chunks`);
      }

      if (this.hasEmbeddingSupport()) {
        // Bounded concurrency: embed and store at most EMBEDDING_BATCH_SIZE chunks at a time.
        for (let offset = 0; offset < chunks.length; offset += EMBEDDING_BATCH_SIZE) {
          const batch = chunks.slice(offset, offset + EMBEDDING_BATCH_SIZE);
          await Promise.all(batch.map((chunk) => this.storeChunkWithEmbedding(chunk)));
        }
        logger.debug(`Added ${chunks.length} chunks with embeddings for document ${documentId}`);
      } else if (sameDatabase) {
        logger.warn("No provider or memory with embedding support provided, chunks stored without embeddings");
        const chunksTableName = this.chunksTableName();
        // One row per insert keeps each statement well under SQL bind-variable limits.
        for (const chunk of chunks) {
          await database.knex(chunksTableName).insert({
            id: uuidv4(),
            documentId: chunk.documentId,
            content: chunk.content,
            metadata: JSON.stringify(chunk.metadata),
            embedding: JSON.stringify([]), // Empty embedding
            createdAt: new Date(),
          });
        }
      } else {
        logger.warn(
          "No provider or memory with embedding support provided; cannot store chunks without embeddings in external vector database"
        );
      }

      return documentId;
    } catch (error) {
      logger.error("Error adding document to vector RAG:", error);
      throw error;
    }
  }

  /**
   * Split a document into overlapping character windows.
   *
   * Windows are `chunkSize` characters long and start every `chunkSize - chunkOverlap`
   * characters. Chunking stops once a window reaches the end of the text, because any later
   * window would only repeat a suffix of it. Whitespace-only windows are skipped.
   * `chunk_index` counts only the chunks that are actually emitted.
   *
   * @param documentId The ID of the parent document
   * @param document The document to chunk
   * @returns Chunks whose metadata extends the document metadata with `chunk_index`, `start_char`
   *   and `end_char` (the end offset is exclusive)
   * @throws If `chunkSize` is not positive, or `chunkOverlap` is outside `[0, chunkSize)`. Either
   *   case would make the window never advance.
   */
  private chunkDocument(documentId: string, document: NewDocument): PendingChunk[] {
    const chunkSize = this.config.chunkSize || DEFAULT_CHUNK_SIZE;
    const chunkOverlap = this.config.chunkOverlap ?? DEFAULT_CHUNK_OVERLAP;
    const step = chunkSize - chunkOverlap;

    // Written as negated comparisons so that NaN values are rejected too.
    if (!(chunkSize > 0) || !(chunkOverlap >= 0) || !(step > 0)) {
      throw new Error(
        `Invalid chunking configuration: chunkSize (${chunkSize}) must be positive and chunkOverlap (${chunkOverlap}) must be in [0, chunkSize)`
      );
    }

    const text = document.content;
    const chunks: PendingChunk[] = [];

    for (let start = 0; start < text.length; start += step) {
      const end = Math.min(start + chunkSize, text.length);
      const chunkContent = text.substring(start, end);

      if (chunkContent.trim()) {
        chunks.push({
          documentId,
          content: chunkContent,
          metadata: {
            ...document.metadata,
            chunk_index: chunks.length,
            start_char: start,
            end_char: end,
          },
        });
      }

      // This window already covers the tail of the text.
      if (end >= text.length) break;
    }

    return chunks;
  }

  /**
   * Embed a chunk and store it through the vector database connector.
   *
   * With an external vector store, a chunk -> document association row is also written to the
   * main database. If no embedding can be produced, the chunk is still sent with an empty vector.
   *
   * @param chunk The chunk to store
   * @returns Promise resolving to the generated chunk ID
   */
  private async storeChunkWithEmbedding(chunk: PendingChunk): Promise<string> {
    try {
      const embedding = await this.embedText(chunk.content.substring(0, MAX_EMBEDDING_INPUT_CHARS));
      if (embedding.length === 0) {
        logger.warn(`Storing chunk of document ${chunk.documentId} without an embedding`);
      } else {
        logger.debug(`Generated embedding (${embedding.length} dimensions)`);
      }

      const chunkId = uuidv4();
      const chunkIndex = chunk.metadata.chunk_index ?? 0;

      // The chunk text and document ID travel in the vector's metadata, so an external store
      // holds everything needed to rebuild the chunk.
      await this.vectorDatabaseConnector.addVectors([
        {
          id: chunkId,
          vector: embedding,
          metadata: {
            ...chunk.metadata,
            documentId: chunk.documentId,
            content: chunk.content,
          },
        },
      ]);

      if (!this.usesMainDatabaseForVectors()) {
        await this.config.database.knex(ASSOCIATIONS_TABLE).insert({
          chunkId,
          documentId: chunk.documentId,
          chunkIndex,
          createdAt: new Date(),
        });
        logger.debug(`Added chunk association for chunk ${chunkId} to document ${chunk.documentId}`);
      }

      return chunkId;
    } catch (error) {
      logger.error("Error storing chunk with embedding:", error);
      throw error;
    }
  }

  /**
   * Get a document by its ID.
   *
   * With the main database, the stored document row is returned. With an external vector store,
   * the document is rebuilt from its chunks in `chunkIndex` order. Each chunk's overlap with the
   * previous one is trimmed using its `start_char`/`end_char` offsets, so overlapping text is not
   * duplicated. The rebuilt metadata comes from the first readable chunk, minus the per-chunk
   * fields added by chunking.
   *
   * @param id The document ID to retrieve
   * @returns Promise resolving to the document, or null if it is not found
   */
  async getDocumentById(id: string): Promise<Document | null> {
    validateRequiredParam(id, "id", "getDocumentById");

    try {
      const { database } = this.config;

      if (this.usesMainDatabaseForVectors()) {
        const row = await database.knex(DOCUMENTS_TABLE).where("id", id).first();
        if (!row) {
          return null;
        }
        return {
          id: row.id,
          content: row.content,
          // Some drivers (e.g. Postgres) return JSON columns already parsed.
          metadata: VectorRAG.parseMetadata(row.metadata, `document ${id}`),
        };
      }

      const associations = await database
        .knex(ASSOCIATIONS_TABLE)
        .where("documentId", id)
        .orderBy("chunkIndex", "asc")
        .select("chunkId");

      if (!associations || associations.length === 0) {
        return null;
      }

      let content = "";
      // Exclusive end offset, in the original text, of what `content` already holds.
      let coveredEnd = 0;
      // eslint-disable-next-line @typescript-eslint/no-explicit-any -- metadata is free-form user data
      let metadata: Record<string, any> | null = null;

      for (const association of associations) {
        try {
          const payload = VectorRAG.extractVectorPayload(
            await this.vectorDatabaseConnector.getVectorMetadata(association.chunkId)
          );
          if (!payload || !payload.content) {
            continue;
          }

          const start = payload.metadata.start_char;
          const end = payload.metadata.end_char;
          if (typeof start === "number" && typeof end === "number") {
            // Consecutive chunks overlap by `chunkOverlap` characters. Append only the part of
            // this chunk that earlier chunks have not already covered.
            if (end > coveredEnd) {
              content += payload.content.slice(Math.max(0, coveredEnd - start));
              coveredEnd = end;
            }
          } else {
            // Without offsets, fall back to plain concatenation.
            content += payload.content;
          }

          if (metadata === null) {
            const documentMetadata = { ...payload.metadata };
            delete documentMetadata.chunk_index;
            delete documentMetadata.start_char;
            delete documentMetadata.end_char;
            metadata = documentMetadata;
          }
        } catch (error) {
          logger.warn(`Error getting chunk ${association.chunkId} for document ${id}:`, error);
        }
      }

      if (!content) {
        return null;
      }

      return { id, content, metadata: metadata ?? {} };
    } catch (error) {
      logger.error("Error getting document by ID:", error);
      throw error;
    }
  }

  /**
   * Delete a document and all of its chunks.
   *
   * Chunks are deleted explicitly rather than relying on the cascading foreign key, which some
   * databases do not enforce by default (e.g. SQLite without `PRAGMA foreign_keys = ON`).
   *
   * @param id The document ID to delete
   */
  async deleteDocument(id: string): Promise<void> {
    validateRequiredParam(id, "id", "deleteDocument");

    try {
      const { database } = this.config;

      if (this.usesMainDatabaseForVectors()) {
        const deletedChunks = await database
          .knex(this.chunksTableName())
          .where("documentId", id)
          .delete();
        logger.debug(`Deleted ${deletedChunks} chunks for document ${id}`);

        await database.knex(DOCUMENTS_TABLE).where("id", id).delete();
        logger.debug(`Deleted document ${id} from main database`);
      } else {
        const associations = await database
          .knex(ASSOCIATIONS_TABLE)
          .where("documentId", id)
          .select("chunkId");

        if (associations && associations.length > 0) {
          // eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped knex rows
          const chunkIds = associations.map((row: any) => row.chunkId);

          // Remove the vectors first, so a failure leaves the associations in place for a retry.
          await this.vectorDatabaseConnector.deleteVectors(chunkIds);
          await database.knex(ASSOCIATIONS_TABLE).where("documentId", id).delete();

          logger.debug(`Deleted ${chunkIds.length} vectors from external vector database for document ${id}`);
        } else {
          logger.warn(`No chunks found for document ${id} in external vector database`);
        }
      }

      logger.debug(`Successfully deleted document ${id}`);
    } catch (error) {
      logger.error("Error deleting document:", error);
      throw error;
    }
  }

  /**
   * Search for relevant chunks using a text query.
   *
   * Uses semantic (vector) search when a query embedding can be produced. Otherwise it falls back
   * to keyword search.
   *
   * @param query The search query
   * @param limit Maximum number of results to return (defaults to `config.maxResults`)
   * @returns Promise resolving to an array of search results
   */
  async search(query: string, limit?: number): Promise<RAGResult[]> {
    validateRequiredParam(query, "query", "search");

    try {
      const maxResults = limit || this.config.maxResults || DEFAULT_MAX_RESULTS;

      if (this.hasEmbeddingSupport()) {
        logger.debug(`Generating embedding for search query: "${query}"`);
        const queryEmbedding = await this.embedText(query);

        if (queryEmbedding.length > 0) {
          logger.debug(`Generated embedding with ${queryEmbedding.length} dimensions`);
          // `return await` so that a rejection is logged by the catch below.
          return await this.searchByVector(queryEmbedding, maxResults);
        }

        logger.warn("Could not generate a query embedding, falling back to keyword search");
      } else {
        logger.debug("No provider or memory with embedding support, using keyword search");
      }

      return await this.searchByKeyword(query, maxResults);
    } catch (error) {
      logger.error("Error searching in vector RAG:", error);
      throw error;
    }
  }

  /**
   * Search for chunks using vector similarity.
   *
   * @param embedding The query embedding vector
   * @param limit Maximum number of results to return (defaults to `config.maxResults`)
   * @param threshold Minimum similarity threshold (0-1), passed through to the connector
   * @returns Promise resolving to search results, ordered from most to least similar
   */
  async searchByVector(
    embedding: number[],
    limit?: number,
    threshold: number = DEFAULT_VECTOR_SIMILARITY_THRESHOLD
  ): Promise<RAGResult[]> {
    try {
      const maxResults = limit || this.config.maxResults || DEFAULT_MAX_RESULTS;
      logger.debug(`Searching by vector with ${embedding.length} dimensions, limit: ${maxResults}`);

      const similarVectors = await this.vectorDatabaseConnector.searchVectors(embedding, maxResults, threshold);
      logger.debug(`Found ${similarVectors.length} similar vectors`);

      if (similarVectors.length === 0) {
        return [];
      }

      if (this.usesMainDatabaseForVectors()) {
        const tableName = this.chunksTableName();
        logger.debug(`Fetching chunk details from table: ${tableName}`);

        const rows = await this.config.database
          .knex(tableName)
          .whereIn(
            "id",
            similarVectors.map((match) => match.id)
          )
          .select("*");
        logger.debug(`Retrieved ${rows.length} chunk details`);

        // whereIn() returns rows in storage order; index them by ID to pair each with its match.
        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped knex rows
        const rowsById = new Map<string, any>();
        for (const row of rows) {
          rowsById.set(String(row.id), row);
        }

        const results: RAGResult[] = [];
        for (const match of similarVectors) {
          const row = rowsById.get(String(match.id));
          if (!row) {
            continue;
          }
          results.push({
            content: row.content,
            metadata: VectorRAG.parseMetadata(row.metadata, `chunk ${row.id}`),
            similarity: match.similarity ?? 0,
            sourceId: row.id,
          });
        }

        return results.sort((a, b) => (b.similarity ?? 0) - (a.similarity ?? 0));
      }

      // External vector store: chunk content and metadata live in each vector's metadata.
      // The lookups are independent, so run them concurrently; Promise.all preserves ranking order.
      const results = await Promise.all(
        similarVectors.map(async (match): Promise<RAGResult | null> => {
          try {
            const payload = VectorRAG.extractVectorPayload(
              await this.vectorDatabaseConnector.getVectorMetadata(match.id)
            );
            if (!payload) {
              logger.warn(`Metadata for vector ${match.id} not found, skipping result`);
              return null;
            }
            return {
              content: payload.content,
              metadata: payload.metadata,
              similarity: match.similarity,
              sourceId: match.id,
            };
          } catch (error) {
            logger.warn(`Error processing vector result ${match.id}, skipping:`, error);
            return null;
          }
        })
      );

      return results.filter((result): result is RAGResult => result !== null);
    } catch (error) {
      logger.error("Error searching by vector:", error);
      throw error;
    }
  }

  /**
   * Search chunks with a case-insensitive substring match.
   *
   * This is only possible when chunk rows live in the main database. With an external vector
   * store there is no chunk table to query, so the method logs a warning and returns no results.
   * LIKE wildcards in the query (`%`, `_`) are escaped, so they match literally.
   *
   * @param query The search query
   * @param limit Maximum number of results to return (defaults to `config.maxResults`)
   * @returns Promise resolving to an array of search results (without similarity scores)
   */
  private async searchByKeyword(query: string, limit?: number): Promise<RAGResult[]> {
    if (!this.usesMainDatabaseForVectors()) {
      logger.warn("Keyword search requires chunks in the main database; returning no results for external vector database");
      return [];
    }

    const maxResults = limit || this.config.maxResults || DEFAULT_MAX_RESULTS;
    const pattern = `%${VectorRAG.escapeLikePattern(query.toLowerCase())}%`;

    // '!' is used as the escape character because, unlike a backslash, it is not special inside
    // string literals in any common SQL dialect (e.g. MySQL).
    const rows = await this.config.database
      .knex(this.chunksTableName())
      .whereRaw("LOWER(content) LIKE ? ESCAPE '!'", [pattern])
      .limit(maxResults)
      .select("*");

    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped knex rows
    return rows.map((row: any) => ({
      content: row.content,
      metadata: VectorRAG.parseMetadata(row.metadata, `chunk ${row.id}`),
      sourceId: row.id,
    }));
  }

  /**
   * Produce an embedding for `text`. This method never throws.
   *
   * It tries `config.provider.generateEmbedding` first. If that is unavailable, throws, or returns
   * an empty vector, and `config.memory.searchByEmbedding` exists, it falls back to the shared
   * `Embedding` utility from `../providers`. The memory object itself is not called; its
   * capability only enables the fallback.
   *
   * @param text The text to embed (callers are responsible for any truncation)
   * @returns The embedding, or an empty array when no strategy succeeded
   */
  private async embedText(text: string): Promise<number[]> {
    const { provider, memory } = this.config;

    if (provider && provider.generateEmbedding) {
      try {
        const embedding = await provider.generateEmbedding(text);
        if (embedding && embedding.length > 0) {
          return embedding;
        }
        logger.warn("Provider returned an empty embedding");
      } catch (error) {
        logger.warn("Failed to generate embedding using provider:", error);
      }
    }

    if (memory && memory.searchByEmbedding) {
      try {
        const { Embedding } = await import("../providers");
        return await Embedding.generateEmbedding(text);
      } catch (error) {
        logger.warn("Failed to generate embedding using memory fallback:", error);
      }
    }

    return [];
  }

  /** True when chunk vectors are stored in the main database rather than an external vector store. */
  private usesMainDatabaseForVectors(): boolean {
    return this.config.vectorDatabase?.type === VectorDatabaseType.SAME_AS_MAIN;
  }

  /** Name of the main-database chunk table. */
  private chunksTableName(): string {
    return this.config.tableName || DEFAULT_CHUNKS_TABLE;
  }

  /** True when a provider or a memory with embedding support is configured. */
  private hasEmbeddingSupport(): boolean {
    return Boolean(this.config.provider?.generateEmbedding || this.config.memory?.searchByEmbedding);
  }

  /**
   * Normalize a metadata column value.
   *
   * A string is parsed as JSON, an object is returned as-is, and anything else (or unparsable
   * JSON, which is logged) becomes `{}`.
   *
   * @param raw The raw column value
   * @param label Human-readable owner of the value, used in the warning message
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- metadata is free-form user data
  private static parseMetadata(raw: unknown, label: string): Record<string, any> {
    if (typeof raw === "string") {
      try {
        const parsed: unknown = JSON.parse(raw);
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        return parsed && typeof parsed === "object" ? (parsed as Record<string, any>) : {};
      } catch (parseError) {
        logger.warn(`Error parsing metadata for ${label}, using empty object:`, parseError);
        return {};
      }
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    return raw && typeof raw === "object" ? (raw as Record<string, any>) : {};
  }

  /**
   * Split a vector's stored metadata into chunk text and chunk metadata.
   *
   * `storeChunkWithEmbedding` stores a flat object: the chunk metadata plus `documentId` and
   * `content`. A connector might instead return a nested `{ content, metadata }` shape
   * (NOTE(review): depends on the connector implementation; both shapes are accepted).
   *
   * @param raw Value returned by `getVectorMetadata`
   * @returns The chunk text (`""` if absent) and its metadata, or null when `raw` is not an object
   */
  private static extractVectorPayload(
    raw: unknown
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- metadata is free-form user data
  ): { content: string; metadata: Record<string, any> } | null {
    if (!raw || typeof raw !== "object") {
      return null;
    }

    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const record = raw as Record<string, any>;
    const content = typeof record.content === "string" ? record.content : "";

    if (record.metadata && typeof record.metadata === "object") {
      return { content, metadata: record.metadata };
    }

    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const metadata: Record<string, any> = {};
    for (const [key, value] of Object.entries(record)) {
      if (key !== "content" && key !== "documentId") {
        metadata[key] = value;
      }
    }
    return { content, metadata };
  }

  /** Escape LIKE wildcards (and the '!' escape character itself) so they match literally. */
  private static escapeLikePattern(value: string): string {
    return value.replace(/[!%_]/g, "!$&");
  }
}

/**
 * Factory function to create a VectorRAG instance.
 * @param config Configuration for the vector RAG
 * @returns Promise resolving to a configured VectorRAG instance with its tables created
 */
export const createVectorRAG: VectorRAGFactory = async (config: VectorRAGConfig) => {
  return VectorRAG.create(config);
};