/*
 * Source listing captured from UNPKG.
 *
 * Package: @astreus-ai/astreus
 * Version: (not captured)
 * Description: AI Agent Framework with Chat Management
 * Size: 823 lines (714 loc), 26.4 kB
 */
import { ChatConfig, ChatInstance, ChatMetadata, ChatSummary, ChatMessage } from "./types/chat";
import { MemoryEntry, ProviderMessage, StructuredCompletionResponse, TaskConfig, TaskResult, ProviderModel } from "./types";
import { logger } from "./utils/logger";
import { validateRequiredParam } from "./utils/validation";
import { Embedding } from "./providers";
import { DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS } from "./constants";

/** Maximum number of characters of a message kept as an auto-generated title. */
const TITLE_MAX_LENGTH = 50;

/** Maximum number of characters of a message kept in the `lastMessage` column. */
const LAST_MESSAGE_PREVIEW_LENGTH = 200;

/** System prompt asking the model whether a request should be split into tasks. */
const TASK_PLANNING_PROMPT = [
  "You are a task planning assistant. Analyze the user's request and determine if it should be broken down into tasks.",
  "If tasks are needed, format your response as a JSON array with this structure:",
  "[",
  "  {",
  '    "name": "Task name",',
  '    "description": "Detailed description of the task",',
  '    "input": {Any input data for the task}',
  "  }",
  "]",
  "The system will automatically determine which tools and plugins to use based on the task name and description.",
  'If no tasks are needed, respond with "NO_TASKS_NEEDED".',
].join("\n");

/** Outcome of one tool call requested by the model. */
interface ToolExecutionResult {
  name: string;
  arguments: unknown;
  success: boolean;
  result?: unknown;
  error?: string;
}

/**
 * Chat management system that works with existing memory system
 * Chat IDs are the same as session IDs for backward compatibility
 * Can be used standalone or integrated with Agent instances
 */
export class ChatManager implements ChatInstance {
  public config: ChatConfig;
  private tableName: string;
  private maxChats: number;
  private autoGenerateTitles: boolean;

  /**
   * @param config Chat configuration; `database` and `memory` are required.
   *   `tableName` defaults to "chats", `maxChats` to 50, and title
   *   auto-generation is on unless `autoGenerateTitles` is explicitly `false`.
   */
  constructor(config: ChatConfig) {
    validateRequiredParam(config, "config", "ChatManager constructor");
    validateRequiredParam(config.database, "config.database", "ChatManager constructor");
    validateRequiredParam(config.memory, "config.memory", "ChatManager constructor");

    this.config = config;
    // Use user-provided table name or default
    this.tableName = config.tableName || "chats";
    this.maxChats = config.maxChats || 50;
    this.autoGenerateTitles = config.autoGenerateTitles !== false;
  }

  /**
   * Creates a chat row, or returns the existing one when `chatId` is already present.
   *
   * @param params Chat attributes; `agentId` is required, `chatId` is generated when omitted.
   * @returns The metadata of the created (or pre-existing) chat.
   */
  async createChat(params: {
    chatId?: string;
    userId?: string;
    agentId: string;
    title?: string;
    metadata?: Record<string, unknown>;
  }): Promise<ChatMetadata> {
    validateRequiredParam(params.agentId, "params.agentId", "createChat");

    const chatId = params.chatId || generateChatId();
    const now = new Date();

    // Check if chat already exists
    const existingChat = await this.getChat(chatId);
    if (existingChat) {
      return existingChat;
    }

    const chatMetadata: ChatMetadata = {
      id: chatId,
      title: params.title,
      userId: params.userId,
      agentId: params.agentId,
      status: 'active',
      createdAt: now,
      updatedAt: now,
      messageCount: 0,
      metadata: params.metadata
    };

    // Insert chat metadata (the metadata object is stored as a JSON string)
    await this.config.database.knex(this.tableName).insert({
      id: chatId,
      title: params.title,
      userId: params.userId,
      agentId: params.agentId,
      status: 'active',
      createdAt: now,
      updatedAt: now,
      lastMessageAt: null,
      messageCount: 0,
      lastMessage: null,
      metadata: params.metadata ? JSON.stringify(params.metadata) : null
    });

    logger.info(`Chat created: ${chatId}`);
    return chatMetadata;
  }

  /**
   * Ensures the chat exists (creating it if needed) and then processes
   * `message` through the model.
   *
   * @param params Chat attributes plus the message, model and completion options.
   * @returns The assistant's reply as plain text.
   */
  async createChatWithMessage(params: {
    chatId?: string;
    userId?: string;
    agentId: string;
    title?: string;
    metadata?: Record<string, unknown>;
    // Chat functionality parameters
    message: string;
    model: any; // ProviderModel
    systemPrompt?: string;
    tools?: any[]; // Plugin[]
    taskManager?: any; // TaskManagerInstance
    embedding?: number[];
    useTaskSystem?: boolean;
    temperature?: number;
    maxTokens?: number;
  }): Promise<string> {
    validateRequiredParam(params.agentId, "params.agentId", "createChatWithMessage");
    validateRequiredParam(params.message, "params.message", "createChatWithMessage");
    validateRequiredParam(params.model, "params.model", "createChatWithMessage");

    const chatId = params.chatId || generateChatId();

    // Check if chat exists, if not create it
    const existingChat = await this.getChat(chatId);
    if (!existingChat) {
      await this.createChat({
        chatId,
        agentId: params.agentId,
        userId: params.userId,
        title: params.title,
        metadata: params.metadata
      });
      logger.info(`Created new chat: ${chatId} for agent: ${params.agentId}`);
    }

    // Process the message
    return await this.processMessage({
      chatId,
      agentId: params.agentId,
      userId: params.userId,
      message: params.message,
      model: params.model,
      systemPrompt: params.systemPrompt,
      tools: params.tools || [],
      taskManager: params.taskManager,
      metadata: params.metadata || {},
      embedding: params.embedding,
      useTaskSystem: params.useTaskSystem !== false,
      // `??` rather than `||`: an explicit temperature of 0 is a valid request
      // and must not be replaced by the default.
      temperature: params.temperature ?? DEFAULT_TEMPERATURE,
      // `||` kept on purpose: a limit of 0 tokens is not meaningful.
      maxTokens: params.maxTokens || DEFAULT_MAX_TOKENS
    });
  }

  /**
   * Runs one conversational turn: stores the user message, asks the model
   * (through the task system or a regular completion), executes any tool
   * calls the model requested, stores the assistant reply and returns it.
   */
  private async processMessage(params: {
    chatId: string;
    agentId: string;
    userId?: string;
    message: string;
    model: any;
    systemPrompt?: string;
    tools: any[];
    taskManager?: any;
    metadata: Record<string, unknown>;
    embedding?: number[];
    useTaskSystem: boolean;
    temperature: number;
    maxTokens: number;
  }): Promise<string> {
    const {
      chatId,
      agentId,
      userId,
      message,
      model,
      systemPrompt,
      tools,
      taskManager,
      metadata,
      embedding,
      useTaskSystem,
      temperature,
      maxTokens
    } = params;

    // Get conversation history from chat system (before the new message is stored,
    // so the new message is not duplicated in the prompt)
    const history = await this.getMessages(chatId);

    // Add user message to chat
    await this.addMessage({
      chatId,
      agentId,
      userId,
      role: "user",
      content: message,
      metadata: {
        ...metadata,
        timestamp: new Date().toISOString(),
        useTaskSystem,
        conversationLength: history.length
      }
    });

    // Get available tools
    const availableTools = tools.map((tool) => ({
      name: tool.name,
      description: tool.description,
      parameters: tool.parameters,
    }));

    // Prepare messages for model
    const messages = this.buildProviderMessages(systemPrompt, history, message, availableTools);

    let response: string | StructuredCompletionResponse;

    // Determine if we should use task-based approach
    if (useTaskSystem && tools.length > 0 && taskManager) {
      response = await this.completeWithTaskSystem(chatId, model, messages, taskManager);
    } else {
      // Regular chat completion
      response = await model.complete(messages, {
        tools: availableTools.length > 0 ? availableTools : undefined,
        toolCalling: availableTools.length > 0,
        temperature,
        maxTokens
      });
    }

    // Handle tool execution if the response contains tool calls
    if (typeof response === 'object' && response.tool_calls && Array.isArray(response.tool_calls)) {
      logger.debug(`Chat ${chatId} received ${response.tool_calls.length} tool calls to execute`);

      const toolResults = await this.executeToolCalls(response.tool_calls, tools);

      // Generate a final response based on tool results
      if (toolResults.length > 0) {
        try {
          logger.debug(`Generating final response based on ${toolResults.length} tool results`);

          const renderedResults = toolResults
            .map((tr) =>
              [
                `Tool: ${tr.name}`,
                `Arguments: ${JSON.stringify(tr.arguments)}`,
                `Result: ${tr.success ? JSON.stringify(tr.result) : 'ERROR: ' + tr.error}`,
              ].join('\n')
            )
            .join('\n\n');

          const toolResultsMessage = {
            role: "system" as const,
            content: [
              "You called the following tools and got these results:",
              "",
              renderedResults,
              "",
              "Based on these tool results, generate a helpful response to the user. Be natural and conversational - don't mention the technical details of the tool calls.",
            ].join('\n')
          };

          // Call the model again with the tool results to generate the final response
          const finalResponse = await model.complete([
            ...messages,
            toolResultsMessage
          ]);

          // Update response to the final result
          response = typeof finalResponse === 'string' ? finalResponse : finalResponse.content;

          const finalText = responseToText(response);
          logger.debug(`Generated final response after tool execution: ${finalText.length} characters`);
        } catch (error) {
          logger.error('Error generating final response from tool results:', error);
          // Fallback to original content plus tool results summary
          const originalContent = responseToText(response);
          response = `${originalContent}\n\nTool execution completed with ${toolResults.filter(r => r.success).length} successful results.`;
        }
      }
    }

    // Generate embedding for assistant response if needed
    // NOTE(review): the generated embedding is not stored or returned anywhere
    // in this method — confirm whether it should be passed on to memory.
    let assistantEmbedding: number[] | undefined = undefined;
    const enableEmbeddings = this.config.memory.config?.enableEmbeddings;

    if (enableEmbeddings && embedding) {
      try {
        // Convert response to string for embedding generation
        assistantEmbedding = await Embedding.generateEmbedding(responseToText(response));
      } catch (error) {
        logger.warn(
          "Error generating embedding for assistant response:",
          error
        );
      }
    }

    const replyText = responseToText(response);

    // Add assistant response to chat
    await this.addMessage({
      chatId,
      agentId,
      userId,
      role: "assistant",
      content: replyText,
      metadata: {
        ...metadata,
        timestamp: new Date().toISOString(),
        modelUsed: model.name,
        taskSystemUsed: useTaskSystem,
        temperature,
        responseLength: replyText.length
      }
    });

    // Return the response (converted to string if it was a structured response)
    return replyText;
  }

  /**
   * Builds the prompt: optional system prompt (with the tool catalogue
   * appended when tools exist), then the stored history, then the new user message.
   */
  private buildProviderMessages(
    systemPrompt: string | undefined,
    history: MemoryEntry[],
    message: string,
    availableTools: Array<{ name: unknown; description: unknown; parameters: unknown }>
  ): ProviderMessage[] {
    const messages: ProviderMessage[] = [];

    // Add system prompt; tools are only advertised when a system prompt exists
    if (systemPrompt) {
      const toolCatalogue =
        availableTools.length > 0
          ? `\n\nYou have access to the following tools:\n${JSON.stringify(availableTools, null, 2)}`
          : "";
      messages.push({
        role: "system",
        content: systemPrompt + toolCatalogue,
      });
    }

    // Add conversation history
    for (const entry of history) {
      messages.push({
        role: entry.role as "system" | "user" | "assistant",
        content: entry.content,
      });
    }

    // Add the new user message
    messages.push({
      role: "user",
      content: message,
    });

    return messages;
  }

  /**
   * Task-based completion: asks the model to plan tasks, runs them through
   * `taskManager`, and asks the model to summarise the results. Falls back to
   * a plain completion when no tasks are planned or task processing fails.
   */
  private async completeWithTaskSystem(
    chatId: string,
    model: any,
    messages: ProviderMessage[],
    taskManager: any
  ): Promise<string | StructuredCompletionResponse> {
    // First, analyze the message to determine if it requires tasks
    const analysisResponse = await model.complete([
      ...messages,
      {
        role: "system",
        content: TASK_PLANNING_PROMPT,
      },
    ]);

    // A structured response without content is treated as "no tasks"
    const responseText: string =
      (typeof analysisResponse === 'string' ? analysisResponse : analysisResponse?.content) ?? "";

    // Check if the analysis identified tasks
    if (!(responseText.includes("[") && responseText.includes("]"))) {
      // No tasks needed, just complete the response normally
      return await model.complete(messages);
    }

    try {
      // Extract JSON array from the response (outermost '[' … ']')
      const jsonString = responseText.substring(
        responseText.indexOf("["),
        responseText.lastIndexOf("]") + 1
      );

      // Parse the tasks
      const taskConfigs: TaskConfig[] = JSON.parse(jsonString);

      // Set session ID in task manager
      if (taskManager.setSessionId) {
        taskManager.setSessionId(chatId);
      }

      // Create and execute tasks
      for (const taskConfig of taskConfigs) {
        taskManager.addExistingTask(taskConfig, model);
      }

      const taskResults = await taskManager.run();

      // Generate response based on task results
      const resultSummary: any[] = [];
      for (const [taskId, result] of taskResults.entries()) {
        resultSummary.push({
          taskId,
          success: result.success,
          output: result.output,
          error: result.error ? result.error.message : undefined,
        });
      }

      const taskResponse = await model.complete([
        ...messages,
        {
          role: "system",
          content: [
            "You are assisting with a task-based workflow. Multiple tasks were executed based on the user's request.",
            "Here are the results of those tasks:",
            JSON.stringify(resultSummary, null, 2),
            "",
            "Analyze these results and generate a helpful, coherent response to the user that summarizes what was done and the outcome.",
            "Do not mention that tasks were executed behind the scenes - just provide the information the user needs in a natural way.",
          ].join("\n"),
        },
      ]);

      return typeof taskResponse === 'string' ? taskResponse : taskResponse.content;
    } catch (error) {
      logger.error("Error processing tasks:", error);
      // Fallback to standard completion if task processing fails
      return await model.complete(messages);
    }
  }

  /**
   * Executes the function-type tool calls requested by the model, one after
   * another, and records success or failure for each. Calls whose type is not
   * 'function' or that have no name are skipped.
   */
  private async executeToolCalls(toolCalls: any[], tools: any[]): Promise<ToolExecutionResult[]> {
    const toolResults: ToolExecutionResult[] = [];

    for (const toolCall of toolCalls) {
      try {
        if (toolCall.type === 'function' && toolCall.name) {
          logger.debug(`Executing tool: ${toolCall.name} with arguments:`, toolCall.arguments);

          // Find the tool to execute
          const tool = tools.find(t => t.name === toolCall.name);

          if (tool && tool.execute) {
            // Execute the tool
            const result = await tool.execute(toolCall.arguments || {});

            toolResults.push({
              name: toolCall.name,
              arguments: toolCall.arguments,
              result: result,
              success: true
            });

            logger.debug(`Tool ${toolCall.name} executed successfully`);
          } else {
            logger.warn(`Tool ${toolCall.name} not found or not executable`);
            toolResults.push({
              name: toolCall.name,
              arguments: toolCall.arguments,
              error: `Tool ${toolCall.name} not found or not executable`,
              success: false
            });
          }
        }
      } catch (error) {
        logger.error(`Error executing tool ${toolCall.name}:`, error);
        toolResults.push({
          name: toolCall.name,
          arguments: toolCall.arguments,
          error: error instanceof Error ? error.message : String(error),
          success: false
        });
      }
    }

    return toolResults;
  }

  /**
   * Loads a chat's metadata row.
   *
   * @returns The chat metadata, or `null` when no row has this id.
   */
  async getChat(chatId: string): Promise<ChatMetadata | null> {
    validateRequiredParam(chatId, "chatId", "getChat");

    const result = await this.config.database.knex(this.tableName)
      .where({ id: chatId })
      .first();

    if (!result) return null;

    return {
      id: result.id,
      title: result.title,
      userId: result.userId,
      agentId: result.agentId,
      status: result.status,
      createdAt: new Date(result.createdAt),
      updatedAt: new Date(result.updatedAt),
      lastMessageAt: result.lastMessageAt ? new Date(result.lastMessageAt) : undefined,
      messageCount: result.messageCount,
      metadata: result.metadata ? JSON.parse(result.metadata) : undefined
    };
  }

  /**
   * Updates a chat's `title`, `status` and/or `metadata`; other fields of
   * `updates` are ignored. `updatedAt` is always refreshed.
   */
  async updateChat(chatId: string, updates: Partial<ChatMetadata>): Promise<void> {
    validateRequiredParam(chatId, "chatId", "updateChat");

    const updateData: any = {
      updatedAt: new Date()
    };

    if (updates.title !== undefined) updateData.title = updates.title;
    if (updates.status !== undefined) updateData.status = updates.status;
    if (updates.metadata !== undefined) {
      updateData.metadata = updates.metadata ? JSON.stringify(updates.metadata) : null;
    }

    await this.config.database.knex(this.tableName)
      .where({ id: chatId })
      .update(updateData);

    logger.info(`Chat updated: ${chatId}`);
  }

  /** Deletes the chat row and clears the chat's messages from memory. */
  async deleteChat(chatId: string): Promise<void> {
    validateRequiredParam(chatId, "chatId", "deleteChat");

    // Delete chat metadata
    await this.config.database.knex(this.tableName)
      .where({ id: chatId })
      .del();

    // Clear messages from memory
    await this.config.memory.clear(chatId);

    logger.info(`Chat deleted: ${chatId}`);
  }

  /** Marks the chat as archived. */
  async archiveChat(chatId: string): Promise<void> {
    await this.updateChat(chatId, { status: 'archived' });
    logger.info(`Chat archived: ${chatId}`);
  }

  /**
   * Lists chats, most recently updated first, optionally filtered by user,
   * agent and status.
   *
   * @param params `limit` defaults to the configured `maxChats`, `offset` to 0.
   */
  async listChats(params: {
    userId?: string;
    agentId?: string;
    status?: 'active' | 'archived' | 'deleted';
    limit?: number;
    offset?: number;
  } = {}): Promise<ChatSummary[]> {
    let query = this.config.database.knex(this.tableName)
      .select('*')
      .orderBy('updatedAt', 'desc');

    if (params.userId) {
      query = query.where({ userId: params.userId });
    }

    if (params.agentId) {
      query = query.where({ agentId: params.agentId });
    }

    if (params.status) {
      query = query.where({ status: params.status });
    }

    const limit = params.limit || this.maxChats;
    const offset = params.offset || 0;

    query = query.limit(limit).offset(offset);

    const results = await query;

    return results.map((row: any) => this.toChatSummary(row));
  }

  /**
   * Finds chats whose title or last-message preview contains `params.query`,
   * most recently updated first, optionally restricted to a user and/or agent.
   */
  async searchChats(params: {
    query: string;
    userId?: string;
    agentId?: string;
    limit?: number;
  }): Promise<ChatSummary[]> {
    validateRequiredParam(params.query, "params.query", "searchChats");

    const pattern = `%${params.query}%`;

    // The two LIKE conditions are grouped so that SQL evaluates
    // `(title LIKE ? OR lastMessage LIKE ?) AND userId = ? AND agentId = ?`.
    // Without the grouping, AND binds tighter than OR and a title match would
    // bypass the user/agent filters.
    let query = this.config.database.knex(this.tableName)
      .select('*')
      .where((builder: any) => {
        builder
          .where('title', 'like', pattern)
          .orWhere('lastMessage', 'like', pattern);
      })
      .orderBy('updatedAt', 'desc');

    if (params.userId) {
      query = query.andWhere({ userId: params.userId });
    }

    if (params.agentId) {
      query = query.andWhere({ agentId: params.agentId });
    }

    const limit = params.limit || this.maxChats;
    query = query.limit(limit);

    const results = await query;

    return results.map((row: any) => this.toChatSummary(row));
  }

  /** Converts a raw chat table row into a `ChatSummary`. */
  private toChatSummary(row: any): ChatSummary {
    return {
      id: row.id,
      title: row.title,
      userId: row.userId,
      agentId: row.agentId,
      status: row.status,
      lastMessage: row.lastMessage,
      lastMessageAt: row.lastMessageAt ? new Date(row.lastMessageAt) : undefined,
      messageCount: row.messageCount,
      createdAt: new Date(row.createdAt),
      updatedAt: new Date(row.updatedAt)
    };
  }

  /**
   * Stores a message in memory (with `sessionId` = `chatId`), updates the
   * chat's counters/preview, and auto-generates a title from the first user
   * message when enabled and the chat has none.
   *
   * @returns The id returned by the memory store for the new message.
   */
  async addMessage(params: {
    chatId: string;
    agentId: string;
    userId?: string;
    role: "system" | "user" | "assistant" | "task_context" | "task_event" | "task_tool" | "task_result";
    content: string;
    metadata?: Record<string, unknown>;
  }): Promise<string> {
    validateRequiredParam(params.chatId, "params.chatId", "addMessage");
    validateRequiredParam(params.agentId, "params.agentId", "addMessage");
    validateRequiredParam(params.role, "params.role", "addMessage");
    validateRequiredParam(params.content, "params.content", "addMessage");

    // Add message to memory (using sessionId = chatId)
    const messageId = await this.config.memory.add({
      agentId: params.agentId,
      sessionId: params.chatId, // sessionId = chatId for compatibility
      userId: params.userId,
      role: params.role,
      content: params.content,
      metadata: params.metadata
    });

    // Update chat metadata
    const now = new Date();
    const db = this.config.database;
    await db.knex(this.tableName)
      .where({ id: params.chatId })
      .update({
        lastMessageAt: now,
        updatedAt: now,
        lastMessage: params.content.substring(0, LAST_MESSAGE_PREVIEW_LENGTH), // Store first 200 chars
        // `??` is knex's identifier binding: the camelCase column name is
        // quoted per dialect instead of being case-folded by the database.
        messageCount: db.knex.raw('?? + 1', ['messageCount'])
      });

    // Auto-generate title if this is the first user message and no title exists
    if (this.autoGenerateTitles && params.role === 'user') {
      const chat = await this.getChat(params.chatId);
      if (chat && !chat.title && chat.messageCount <= 1) {
        const title = generateTitleFromMessage(params.content);
        await this.updateChat(params.chatId, { title });
      }
    }

    return messageId;
  }

  /**
   * Returns the chat's messages from memory, each with a `chatId` field added.
   *
   * @param limit Passed through to the memory store's `getBySession`.
   */
  async getMessages(chatId: string, limit?: number): Promise<ChatMessage[]> {
    validateRequiredParam(chatId, "chatId", "getMessages");

    // Get messages from memory using sessionId = chatId
    const messages = await this.config.memory.getBySession(chatId, limit);

    // Convert to ChatMessage format
    return messages.map((msg: MemoryEntry) => ({
      ...msg,
      chatId: msg.sessionId // Add chatId for compatibility
    }));
  }

  /**
   * Deletes a single message from memory.
   * The chat row's counters are not adjusted by this method.
   */
  async deleteMessage(messageId: string): Promise<void> {
    validateRequiredParam(messageId, "messageId", "deleteMessage");
    await this.config.memory.delete(messageId);
  }

  /** Clears all messages of a chat from memory and resets its counters/preview. */
  async clearMessages(chatId: string): Promise<void> {
    validateRequiredParam(chatId, "chatId", "clearMessages");

    // Clear messages from memory
    await this.config.memory.clear(chatId);

    // Reset message count in chat metadata
    await this.config.database.knex(this.tableName)
      .where({ id: chatId })
      .update({
        messageCount: 0,
        lastMessage: null,
        lastMessageAt: null,
        updatedAt: new Date()
      });
  }

  /**
   * Aggregates chat counts and the total message count, optionally restricted
   * to a user and/or agent. Missing aggregates are reported as 0.
   */
  async getChatStats(params: {
    userId?: string;
    agentId?: string;
  } = {}): Promise<{
    totalChats: number;
    activeChats: number;
    archivedChats: number;
    totalMessages: number;
  }> {
    const db = this.config.database;
    let query = db.knex(this.tableName);

    if (params.userId) {
      query = query.where({ userId: params.userId });
    }

    if (params.agentId) {
      query = query.where({ agentId: params.agentId });
    }

    // Column names and aliases go through `??` identifier bindings so their
    // camelCase spelling survives databases that fold unquoted identifiers.
    const stats = await query
      .select(
        db.knex.raw('COUNT(*) as ??', ['totalChats']),
        db.knex.raw('SUM(CASE WHEN ?? = ? THEN 1 ELSE 0 END) as ??', ['status', 'active', 'activeChats']),
        db.knex.raw('SUM(CASE WHEN ?? = ? THEN 1 ELSE 0 END) as ??', ['status', 'archived', 'archivedChats']),
        db.knex.raw('SUM(??) as ??', ['messageCount', 'totalMessages'])
      )
      .first();

    return {
      totalChats: toCount(stats?.totalChats),
      activeChats: toCount(stats?.activeChats),
      archivedChats: toCount(stats?.archivedChats),
      totalMessages: toCount(stats?.totalMessages)
    };
  }

  /**
   * Sends `message` in chat `chatId`, creating the chat first when it does
   * not exist yet.
   *
   * @param params `tools` defaults to `[]`, `useTaskSystem` to `true`,
   *   `temperature`/`maxTokens` to the module defaults.
   * @returns The assistant's reply as plain text.
   */
  async chat(params: {
    message: string;
    chatId: string;
    agentId: string;
    userId?: string;
    model: any; // ProviderModel
    systemPrompt?: string;
    tools?: any[]; // Plugin[]
    taskManager?: any; // TaskManagerInstance
    chatTitle?: string;
    metadata?: Record<string, unknown>;
    embedding?: number[];
    useTaskSystem?: boolean;
    temperature?: number;
    maxTokens?: number;
  }): Promise<string> {
    validateRequiredParam(params.message, "params.message", "chat");
    validateRequiredParam(params.chatId, "params.chatId", "chat");
    validateRequiredParam(params.agentId, "params.agentId", "chat");
    validateRequiredParam(params.model, "params.model", "chat");

    const {
      message,
      chatId,
      agentId,
      userId,
      model,
      systemPrompt,
      tools = [],
      taskManager,
      chatTitle,
      metadata = {},
      embedding,
      useTaskSystem = true,
      temperature = DEFAULT_TEMPERATURE,
      maxTokens = DEFAULT_MAX_TOKENS
    } = params;

    // Use createChatWithMessage which handles both creation and message processing
    return await this.createChatWithMessage({
      chatId,
      agentId,
      userId,
      title: chatTitle,
      metadata,
      message,
      model,
      systemPrompt,
      tools,
      taskManager,
      embedding,
      useTaskSystem,
      temperature,
      maxTokens
    });
  }
}

/**
 * Creates a chat management system that works with existing memory system
 * Chat IDs are the same as session IDs for backward compatibility
 * Can be used standalone or integrated with Agent instances
 *
 * Ensures the chat table (custom name or "chats") exists via
 * `config.database.ensureTable` before returning the manager.
 */
export async function createChat(config: ChatConfig): Promise<ChatInstance> {
  const chatManager = new ChatManager(config);

  // Ensure chats table exists using user's custom table name
  const tableName = config.tableName || "chats";

  await config.database.ensureTable(tableName, (table: any) => {
    table.string('id').primary();
    table.string('title').nullable();
    table.string('userId').nullable();
    table.string('agentId').notNullable();
    table.string('status').defaultTo('active');
    table.timestamp('createdAt').defaultTo(config.database.knex.fn.now());
    table.timestamp('updatedAt').defaultTo(config.database.knex.fn.now());
    table.timestamp('lastMessageAt').nullable();
    table.integer('messageCount').defaultTo(0);
    table.string('lastMessage').nullable();
    table.text('metadata').nullable();

    // Add indexes for better performance. Index names are derived from the
    // table name so that two chat tables in one database do not collide; for
    // the default table "chats" the names are the same as before.
    table.index(['userId', 'status', 'updatedAt'], `${tableName}_userid_status_updated_idx`);
    table.index(['agentId', 'status', 'updatedAt'], `${tableName}_agentid_status_updated_idx`);
    table.index(['status', 'updatedAt'], `${tableName}_status_updated_idx`);
  });

  logger.info(`Chat manager initialized with custom table: ${tableName}`);
  return chatManager;
}

/** Returns the text of a completion, whether it is a plain string or a structured response. */
function responseToText(response: string | StructuredCompletionResponse): string {
  return typeof response === 'string' ? response : response.content;
}

/** Parses an aggregate value returned by the database into an integer, defaulting to 0. */
function toCount(value: unknown): number {
  const parsed = parseInt(String(value), 10);
  return Number.isNaN(parsed) ? 0 : parsed;
}

/**
 * Helper function to generate unique chat ID.
 * Format: `chat-<epoch millis>-<up to 9 base-36 characters>`.
 */
function generateChatId(): string {
  return `chat-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
}

/**
 * Helper function to generate title from first message.
 * Whitespace runs (including line breaks) are collapsed to single spaces
 * before truncating to 50 characters; an ellipsis is appended only when text
 * was actually cut off. Whitespace-only input yields 'New Chat'.
 */
function generateTitleFromMessage(content: string): string {
  // Normalise first so that whitespace does not consume the character budget
  const normalized = content.replace(/\s+/g, ' ').trim();

  if (!normalized) {
    return 'New Chat';
  }

  if (normalized.length <= TITLE_MAX_LENGTH) {
    return normalized;
  }

  return `${normalized.substring(0, TITLE_MAX_LENGTH).trimEnd()}...`;
}