UNPKG

@astreus-ai/astreus

Version:

AI Agent Framework with Chat Management

675 lines (587 loc) 23.3 kB
import { v4 as uuidv4 } from "uuid";
import {
  MemoryConfig,
  MemoryInstance,
  MemoryEntry,
  MemoryFactory,
  SimilaritySearchResult,
} from "./types";
import { Embedding } from "./providers";
import { logger } from "./utils";
import { validateRequiredParam, validateRequiredParams } from "./utils/validation";
import { DEFAULT_MEMORY_SIZE } from "./constants";

/** Table used when the caller does not configure `tableName`. */
const DEFAULT_TABLE_NAME = "memories";

/** Number of most recent embedded rows that `searchByEmbedding` scores. */
const EMBEDDING_SEARCH_CANDIDATES = 100;

/** Marker left behind when an object was string-coerced before being stored. */
const OBJECT_STRING = "[object Object]";

/** Memory configuration after defaults have been applied. */
type ResolvedMemoryConfig = MemoryConfig & {
  tableName: string;
  maxEntries: number;
  enableEmbeddings: boolean;
};

/**
 * Apply defaults for the optional configuration fields.
 * `||` is used so empty/zero values also fall back to the default.
 */
function applyMemoryDefaults(config: MemoryConfig): ResolvedMemoryConfig {
  return {
    ...config,
    tableName: config.tableName || DEFAULT_TABLE_NAME,
    maxEntries: config.maxEntries || DEFAULT_MEMORY_SIZE,
    enableEmbeddings: config.enableEmbeddings || false,
  };
}

/**
 * Convert a stored timestamp into a Date.
 * Some drivers return Date objects, others may return epoch milliseconds
 * (e.g. SQLite for bound Date values). `new Date("1700000000000")` is an
 * Invalid Date, so numbers and purely numeric strings are read as epoch ms.
 */
function toDate(value: unknown): Date {
  if (value instanceof Date) {
    return value;
  }
  if (typeof value === "number") {
    return new Date(value);
  }
  const text = String(value ?? "");
  return /^\d+$/.test(text) ? new Date(Number(text)) : new Date(text);
}

/**
 * Serialize an embedding for storage.
 * Arrays become JSON text; strings are kept unless they are the
 * "[object Object]" coercion artifact; anything else is stored as null.
 * @param embedding Truthy embedding value taken from the entry
 * @param id Entry ID, used only for logging
 */
function serializeEmbedding(embedding: unknown, id: string): string | null {
  if (typeof embedding === "string") {
    return embedding === OBJECT_STRING ? null : embedding;
  }
  if (Array.isArray(embedding)) {
    return JSON.stringify(embedding);
  }
  // For unexpected object types, store as null to avoid "[object Object]"
  logger.warn(`Unexpected embedding type for entry ${id}, setting to null`);
  return null;
}

/**
 * Turn a stored embedding back into a number array.
 * @returns The array, or undefined when absent, invalid JSON, or not an array
 */
function parseEmbedding(value: unknown, id: string): number[] | undefined {
  if (!value) {
    return undefined;
  }
  if (Array.isArray(value)) {
    return value as number[];
  }
  // Skip non-strings and "[object Object]" strings, which aren't valid JSON.
  if (typeof value !== "string" || value === OBJECT_STRING) {
    return undefined;
  }
  try {
    const parsed: unknown = JSON.parse(value);
    return Array.isArray(parsed) ? (parsed as number[]) : undefined;
  } catch (error) {
    logger.error(`Error parsing embedding for entry ${id}:`, error);
    return undefined;
  }
}

/**
 * Turn stored metadata back into an object.
 * Non-string values are passed through unchanged, since drivers that decode
 * JSON columns hand back an object already. JSON text is parsed. Unparseable
 * text (including "[object Object]") yields undefined instead of throwing.
 * @param value Raw metadata column value
 * @param context Human-readable row description used in the warning
 */
function parseMetadata(
  value: unknown,
  context: string
): Record<string, unknown> | undefined {
  if (typeof value !== "string") {
    return value as Record<string, unknown> | undefined;
  }
  try {
    const parsed: unknown = JSON.parse(value);
    return parsed !== null && typeof parsed === "object"
      ? (parsed as Record<string, unknown>)
      : undefined;
  } catch (error) {
    logger.warn(`Error parsing metadata for ${context}:`, error);
    return undefined;
  }
}

/**
 * Escape LIKE wildcards so user text is matched literally.
 * Must be paired with `ESCAPE '!'` in the SQL.
 */
function escapeLikePattern(text: string): string {
  return text.replace(/[!%_]/g, (match) => `!${match}`);
}

/**
 * Memory manager implementation using database storage
 * Provides storage and retrieval of conversation history and vector embeddings
 */
export class MemoryManager implements MemoryInstance {
  public config: MemoryConfig;

  /**
   * @param config Memory configuration; `database` is required
   */
  constructor(config: MemoryConfig) {
    validateRequiredParam(config, "config", "MemoryManager constructor");
    validateRequiredParams(config, ["database"], "MemoryManager constructor");

    this.config = applyMemoryDefaults(config);
    logger.debug("Memory manager initialized");
  }

  /**
   * Create a new memory instance, ensuring its backing table exists.
   * @param config Configuration object for memory
   * @returns Promise that resolves to the new memory instance
   * @throws Rethrows any database error after logging it
   */
  static async create(config: MemoryConfig): Promise<MemoryInstance> {
    validateRequiredParam(config, "config", "MemoryManager.create");
    validateRequiredParams(config, ["database"], "MemoryManager.create");

    try {
      const fullConfig = applyMemoryDefaults(config);
      const { database, tableName } = fullConfig;

      await database.ensureTable(tableName, (table) => {
        table.string("id").primary();
        table.string("agentId").notNullable().index();
        table.string("sessionId").notNullable().index();
        table.string("userId").nullable().index();
        table.string("role").notNullable();
        table.text("content").notNullable();
        table.timestamp("timestamp").defaultTo(database.knex.fn.now());
        table.json("embedding").nullable();
        table.json("metadata");
      });

      // Tables created before embedding support may lack the column.
      if (fullConfig.enableEmbeddings) {
        logger.info(`Memory initialized with embedding support using table: ${tableName}`);
        const hasEmbeddingColumn = await database.knex.schema.hasColumn(
          tableName,
          "embedding"
        );
        if (!hasEmbeddingColumn) {
          logger.warn(`Adding embedding column to memory table: ${tableName}`);
          await database.knex.schema.table(tableName, (table) => {
            table.json("embedding"); // Store as JSON to properly represent array structure
          });
        }
      }

      logger.info(`Memory created with custom table: ${tableName}`);
      return new MemoryManager(fullConfig);
    } catch (error) {
      logger.error("Error creating memory instance:", error);
      throw error;
    }
  }

  /**
   * Add a new memory entry.
   * The role is normalized via `validateRole`. When embeddings are enabled
   * and none is supplied, one is generated; generation failure is logged and
   * the entry is stored without it. Embedding and metadata objects are stored
   * as JSON text so they survive drivers that would string-coerce them.
   * @param entry Memory entry to add (without id and timestamp)
   * @returns Promise resolving to the ID of the new entry
   * @throws Rethrows any database error after logging it
   */
  async add(entry: Omit<MemoryEntry, "id" | "timestamp">): Promise<string> {
    validateRequiredParam(entry, "entry", "add");
    validateRequiredParams(
      entry,
      ["agentId", "sessionId", "role", "content"],
      "add"
    );

    try {
      const { database, tableName, enableEmbeddings } = this.config;
      const id = uuidv4();
      const timestamp = new Date();

      // Work on a copy so the caller's object is never modified.
      const entryToInsert = { ...entry };
      // Validate and normalize the role for OpenRouter compatibility
      entryToInsert.role = this.validateRole(entryToInsert.role);

      if (enableEmbeddings && !entryToInsert.embedding && entryToInsert.content) {
        try {
          logger.debug(`Generating embedding for memory entry ${id}`);
          const generatedEmbedding = await Embedding.generateEmbedding(entryToInsert.content);
          entryToInsert.embedding = generatedEmbedding;
          logger.debug(`Generated embedding for memory entry ${id} (${generatedEmbedding.length} dimensions)`);
        } catch (embeddingError) {
          logger.warn(`Failed to generate embedding for memory entry ${id}:`, embeddingError);
          // Continue without embedding rather than failing the entire operation
        }
      }

      const dbEntry: Record<string, unknown> = { ...entryToInsert, id, timestamp };
      // Keys are only overwritten when present, so absent fields stay absent.
      if (entryToInsert.embedding) {
        dbEntry.embedding = serializeEmbedding(entryToInsert.embedding, id);
      }
      if (entryToInsert.metadata && typeof entryToInsert.metadata === "object") {
        dbEntry.metadata = JSON.stringify(entryToInsert.metadata);
      }

      await database.getTable(tableName!).insert(dbEntry);
      logger.debug(`Added memory entry ${id}${entryToInsert.embedding ? ' with embedding' : ''} with role: ${entryToInsert.role}`);
      return id;
    } catch (error) {
      logger.error("Error adding memory entry:", error);
      throw error;
    }
  }

  /**
   * Run a filtered, timestamp-ordered query and normalize the rows.
   * @param filter Column/value equality filter
   * @param direction Timestamp sort direction
   * @param limit Maximum rows; falls back to the configured `maxEntries`
   */
  private async queryEntries(
    filter: Record<string, string>,
    direction: "asc" | "desc",
    limit?: number
  ): Promise<MemoryEntry[]> {
    const { database, tableName, maxEntries } = this.config;
    const rows = await database
      .knex(tableName!)
      .where(filter)
      .orderBy("timestamp", direction)
      .limit(limit || maxEntries!);
    return this.processEntriesBeforeReturn(rows);
  }

  /**
   * Get entries by session ID, oldest first.
   * @param sessionId The session ID to get entries for
   * @param limit Maximum number of entries to return (default: `maxEntries`)
   * @returns Promise resolving to array of memory entries
   */
  async getBySession(
    sessionId: string,
    limit?: number
  ): Promise<MemoryEntry[]> {
    validateRequiredParam(sessionId, "sessionId", "getBySession");

    try {
      const entries = await this.queryEntries({ sessionId }, "asc", limit);
      logger.debug(`Retrieved ${entries.length} entries for session ${sessionId}`);
      return entries;
    } catch (error) {
      logger.error(`Error getting entries for session ${sessionId}:`, error);
      throw error;
    }
  }

  /**
   * Get entries by agent ID, oldest first.
   * @param agentId The agent ID to get entries for
   * @param limit Maximum number of entries to return (default: `maxEntries`)
   * @returns Promise resolving to array of memory entries
   */
  async getByAgent(agentId: string, limit?: number): Promise<MemoryEntry[]> {
    validateRequiredParam(agentId, "agentId", "getByAgent");

    try {
      const entries = await this.queryEntries({ agentId }, "asc", limit);
      logger.debug(`Retrieved ${entries.length} entries for agent ${agentId}`);
      return entries;
    } catch (error) {
      logger.error(`Error getting entries for agent ${agentId}:`, error);
      throw error;
    }
  }

  /**
   * Get entries by user ID, most recent first.
   * @param userId The user ID to get entries for
   * @param limit Maximum number of entries to return (default: `maxEntries`)
   * @returns Promise resolving to array of memory entries
   */
  async getByUser(userId: string, limit?: number): Promise<MemoryEntry[]> {
    validateRequiredParam(userId, "userId", "getByUser");

    try {
      const entries = await this.queryEntries({ userId }, "desc", limit);
      logger.debug(`Retrieved ${entries.length} entries for user ${userId}`);
      return entries;
    } catch (error) {
      logger.error(`Error getting entries for user ${userId}:`, error);
      throw error;
    }
  }

  /**
   * Get a memory entry by its ID.
   * @param id The ID of the memory entry to retrieve
   * @returns Promise resolving to the memory entry or null if not found
   */
  async getById(id: string): Promise<MemoryEntry | null> {
    validateRequiredParam(id, "id", "getById");

    try {
      const { database, tableName } = this.config;
      const entry = await database
        .knex(tableName!)
        .where({ id })
        .first();

      if (!entry) {
        logger.debug(`Memory entry with ID ${id} not found`);
        return null;
      }

      logger.debug(`Retrieved memory entry with ID ${id}`);
      return this.processEntriesBeforeReturn([entry])[0];
    } catch (error) {
      logger.error(`Error getting memory entry by ID ${id}:`, error);
      throw error;
    }
  }

  /**
   * Delete a memory entry by its ID. A missing ID is logged, not thrown.
   * @param id The ID of the memory entry to delete
   * @returns Promise that resolves when deletion is complete
   */
  async delete(id: string): Promise<void> {
    validateRequiredParam(id, "id", "delete");

    try {
      const { database, tableName } = this.config;
      const result = await database
        .knex(tableName!)
        .where({ id })
        .delete();

      if (result === 0) {
        logger.warn(`No memory entry found for deletion with ID ${id}`);
      } else {
        logger.debug(`Deleted memory entry with ID ${id}`);
      }
    } catch (error) {
      logger.error(`Error deleting memory entry with ID ${id}:`, error);
      throw error;
    }
  }

  /**
   * Normalize raw database rows into `MemoryEntry` objects.
   * Parses JSON embeddings/metadata, converts timestamps to Dates, normalizes
   * roles, and passes `userId` through when it is stored.
   * @param entries Array of raw entries from database
   * @returns Processed entries
   */
  private processEntriesBeforeReturn(entries: unknown[]): MemoryEntry[] {
    return entries.map((entry) => {
      const dbEntry = entry as Record<string, unknown>;
      const processedEntry: Partial<MemoryEntry> = {
        id: String(dbEntry.id || ''),
        agentId: String(dbEntry.agentId || ''),
        sessionId: String(dbEntry.sessionId || ''),
        // userId is a nullable column; only expose it when a value is stored.
        ...(dbEntry.userId != null ? { userId: String(dbEntry.userId) } : {}),
        role: this.validateRole(String(dbEntry.role || '')),
        content: String(dbEntry.content || ''),
        timestamp: toDate(dbEntry.timestamp),
        metadata: parseMetadata(dbEntry.metadata, `entry ${String(dbEntry.id)}`),
      };

      const embedding = parseEmbedding(dbEntry.embedding, String(dbEntry.id));
      if (embedding !== undefined) {
        processedEntry.embedding = embedding;
      }

      return processedEntry as MemoryEntry;
    });
  }

  /**
   * Case-insensitive substring search over memory content, across all
   * agents and sessions, most recent first. `%` and `_` in the query are
   * matched literally. Errors are logged and yield an empty array.
   * @param query Text to search for in memory content field
   * @param limit Maximum number of results to return
   * @returns Array of memory entries matching the query
   */
  async searchByText(
    query: string,
    limit: number = 10
  ): Promise<MemoryEntry[]> {
    validateRequiredParam(query, "query", "searchByText");

    try {
      const { database, tableName } = this.config;
      if (!query) return [];

      const pattern = `%${escapeLikePattern(query.toLowerCase())}%`;
      const results = await database
        .knex(tableName!)
        .whereRaw("LOWER(content) LIKE ? ESCAPE '!'", [pattern])
        .orderBy("timestamp", "desc")
        .limit(limit);

      logger.debug(`Text search for "${query}" found ${results.length} results`);
      return this.processEntriesBeforeReturn(results);
    } catch (error) {
      logger.error("Error searching memories by text:", error);
      return [];
    }
  }

  /**
   * Search memories by embedding similarity.
   * Only the most recent embedded rows (`EMBEDDING_SEARCH_CANDIDATES`) are
   * scored. Errors are logged and yield an empty array.
   * @param embedding The embedding vector to search for
   * @param limit Maximum number of results to return
   * @param threshold Minimum similarity score to include in results (default 0)
   * @returns Array of memory entries with similarity scores, best first
   */
  async searchByEmbedding(
    embedding: number[],
    limit: number = 5,
    threshold: number = 0
  ): Promise<SimilaritySearchResult[]> {
    validateRequiredParam(embedding, "embedding", "searchByEmbedding");

    try {
      const { database, tableName } = this.config;
      if (!embedding || !embedding.length) {
        return [];
      }

      const rows = await database
        .knex(tableName!)
        .whereNotNull("embedding")
        .orderBy("timestamp", "desc")
        .limit(EMBEDDING_SEARCH_CANDIDATES);

      const scored = this.processEntriesBeforeReturn(rows)
        .filter(
          (entry): entry is MemoryEntry & { embedding: number[] } =>
            Array.isArray(entry.embedding)
        )
        .map((entry): SimilaritySearchResult => ({
          ...entry,
          similarity: Embedding.calculateSimilarity(embedding, entry.embedding),
        }));

      const filteredResults = scored
        .filter((item) => item.similarity >= threshold)
        .sort((a, b) => b.similarity - a.similarity)
        .slice(0, limit);

      logger.debug(`Found ${filteredResults.length} results in embedding search with threshold ${threshold}`);
      return filteredResults;
    } catch (error) {
      logger.error("Error searching memories by embedding:", error);
      return [];
    }
  }

  /**
   * Delete all entries for a session.
   * @param sessionId The session ID to clear entries for
   * @returns Promise that resolves when clearing is complete
   */
  async clear(sessionId: string): Promise<void> {
    validateRequiredParam(sessionId, "sessionId", "clear");

    try {
      const { database, tableName } = this.config;
      const result = await database
        .knex(tableName!)
        .where({ sessionId: sessionId })
        .delete();

      logger.debug(`Cleared ${result} entries for session ${sessionId}`);
    } catch (error) {
      logger.error(`Error clearing entries for session ${sessionId}:`, error);
      throw error;
    }
  }

  /**
   * Generate a one-line summary (message count and start time) of a session.
   * Only the first `maxEntries` messages are considered.
   * @param sessionId The session ID to summarize
   * @returns Text summary of the conversation
   */
  async summarize(sessionId: string): Promise<string> {
    validateRequiredParam(sessionId, "sessionId", "summarize");

    try {
      const memories = await this.getBySession(sessionId);
      if (!memories.length) {
        return "No conversation history available.";
      }

      // Just a simple summary for now - could be enhanced with AI summarization
      const summary = `Conversation with ${memories.length} messages starting at ${memories[0].timestamp.toISOString()}.`;
      logger.debug(`Generated summary for session ${sessionId}`);
      return summary;
    } catch (error) {
      logger.error(`Error summarizing session ${sessionId}:`, error);
      throw error;
    }
  }

  /**
   * Add an entry with a precomputed embedding in one step.
   * @param entry The memory entry without ID, timestamp, or embedding
   * @param embedding The embedding vector to add
   * @returns Promise resolving to the ID of the new entry
   */
  async addWithEmbedding(
    entry: Omit<MemoryEntry, "id" | "timestamp" | "embedding">,
    embedding: number[]
  ): Promise<string> {
    validateRequiredParam(entry, "entry", "addWithEmbedding");
    validateRequiredParam(embedding, "embedding", "addWithEmbedding");
    validateRequiredParams(
      entry,
      ["agentId", "sessionId", "role", "content"],
      "addWithEmbedding"
    );

    try {
      const id = await this.add({ ...entry, embedding });
      logger.debug(`Added memory entry ${id} with embedding`);
      return id;
    } catch (error) {
      logger.error("Error adding memory entry with embedding:", error);
      throw error;
    }
  }

  /**
   * Normalize a role for storage/replay.
   * `task_*` roles map to 'user' for OpenRouter compatibility; 'system',
   * 'user' and 'assistant' pass through; anything else defaults to 'user'.
   */
  private validateRole(role: string): MemoryEntry['role'] {
    if (role.startsWith('task_')) {
      return 'user';
    }
    const passthroughRoles: MemoryEntry['role'][] = ['system', 'user', 'assistant'];
    return passthroughRoles.includes(role as MemoryEntry['role'])
      ? (role as MemoryEntry['role'])
      : 'user'; // Default to user instead of system for better OpenRouter compatibility
  }

  /**
   * List all sessions for a specific agent, most recently active first.
   * @param agentId The agent ID to list sessions for
   * @param limit Maximum number of sessions to return (default: `maxEntries`)
   * @returns Promise resolving to array of session summaries
   */
  async listSessions(agentId: string, limit?: number): Promise<{
    sessionId: string;
    lastMessage?: string;
    messageCount: number;
    lastActivity: Date;
    metadata?: Record<string, unknown>;
  }[]> {
    validateRequiredParam(agentId, "agentId", "listSessions");

    try {
      const { database, tableName, maxEntries } = this.config;
      const knex = database.knex;

      // Correlated subquery returning `column` from the newest row of the
      // outer row's session. Identifiers go through `??` bindings so they are
      // quoted for the active dialect; unquoted camelCase names and aliases
      // are case-folded by some databases (e.g. PostgreSQL).
      const latestValue = (column: string, alias: string) =>
        knex.raw(
          "(SELECT ?? FROM ?? WHERE ?? = ?? AND ?? = ? ORDER BY ?? DESC LIMIT 1) as ??",
          [column, tableName!, "sessionId", "t.sessionId", "agentId", agentId, "timestamp", alias]
        );

      const sessions = await knex(`${tableName!} as t`)
        .select("sessionId")
        .select(knex.raw("COUNT(*) as ??", ["messageCount"]))
        .select(knex.raw("MAX(??) as ??", ["timestamp", "lastActivity"]))
        .select(latestValue("content", "lastMessage"))
        .select(latestValue("metadata", "metadata"))
        .where({ agentId: agentId })
        .groupBy("sessionId")
        .orderBy("lastActivity", "desc")
        .limit(limit || maxEntries!);

      logger.debug(`Retrieved ${sessions.length} sessions for agent ${agentId}`);

      return sessions.map((session: Record<string, unknown>) => ({
        sessionId: String(session.sessionId),
        lastMessage: session.lastMessage ? String(session.lastMessage) : undefined,
        // Some drivers return COUNT(*) as a string.
        messageCount: Number(session.messageCount) || 0,
        lastActivity: toDate(session.lastActivity),
        metadata: session.metadata
          ? parseMetadata(session.metadata, `session ${String(session.sessionId)}`)
          : undefined,
      }));
    } catch (error) {
      logger.error(`Error listing sessions for agent ${agentId}:`, error);
      throw error;
    }
  }
}

/**
 * Factory function to create a new memory instance
 * @param config Configuration for the memory instance
 * @returns Promise that resolves to the new memory instance
 */
export const createMemory: MemoryFactory = async (config: MemoryConfig) => {
  validateRequiredParam(config, "config", "createMemory");
  validateRequiredParams(config, ["database"], "createMemory");

  return MemoryManager.create(config);
};