UNPKG

@astreus-ai/astreus

Version:

AI Agent Framework with Chat Management

664 lines (573 loc) 20 kB
import { v4 as uuidv4 } from "uuid"; import { AgentConfig, AgentInstance, AgentFactory, Plugin, ProviderModel, ProviderInstance, MemoryInstance, ChatInstance, ChatMetadata, ChatSummary } from "./types"; import { createDatabase } from "./database"; import { PluginManager } from "./plugin"; import { validateRequiredParams, validateRequiredParam } from "./utils/validation"; import { logger } from "./utils/logger"; import { createRAGTools } from "./utils/rag-tools"; import { DEFAULT_AGENT_NAME } from "./constants"; // Agent implementation class Agent implements AgentInstance { public id: string; public config: AgentConfig; private memory: MemoryInstance; // Replace any with MemoryInstance private tools: Map<string, Plugin>; private chatManager?: ChatInstance; constructor(config: AgentConfig) { // Validate required parameters validateRequiredParam(config, "config", "Agent constructor"); validateRequiredParams( config, ["memory"], // 'name' is optional now since we have a default "Agent constructor" ); // Ensure either model or provider is specified if (!config.model && !config.provider) { throw new Error("Either 'model' or 'provider' must be specified in agent config"); } // If provider is given but model is not, use default model from provider if (config.provider && !config.model) { const defaultModelName = config.provider.getDefaultModel?.() || config.provider.listModels()[0]; if (defaultModelName) { config.model = config.provider.getModel(defaultModelName); } else { throw new Error("No default model available in provider"); } } // Ensure we have a model at this point if (!config.model) { throw new Error("No model could be determined for the agent"); } // Set default values for optional parameters this.id = config.id || uuidv4(); this.config = { ...config, name: config.name || DEFAULT_AGENT_NAME, description: config.description || `Agent ${config.name || DEFAULT_AGENT_NAME}`, tools: config.tools || [], plugins: config.plugins || [] }; this.memory = config.memory; this.tools = new Map(); this.chatManager = config.chat; // Initialize tools if provided if (this.config.tools) { this.config.tools.forEach((tool) => { this.tools.set(tool.name, tool); }); } // Create RAG tools if RAG instance is provided if (this.config.rag) { const ragTools = createRAGTools(this.config.rag); ragTools.forEach((tool) => { this.tools.set(tool.name, tool); }); logger.debug(`Added ${ragTools.length} RAG tools to agent ${this.config.name}`); } // Initialize plugins and register their tools if provided if (this.config.plugins) { for (const plugin of this.config.plugins) { // Check if plugin has getTools method (PluginInstance) if (plugin && 'getTools' in plugin && typeof plugin.getTools === 'function') { const pluginTools = plugin.getTools(); if (pluginTools && Array.isArray(pluginTools)) { pluginTools.forEach((tool: Plugin) => { if (tool && tool.name) { this.tools.set(tool.name, tool); // Also register with the global registry PluginManager.register(tool); } }); } } // Check if it's a direct Plugin object else if (plugin && 'name' in plugin && plugin.name && 'execute' in plugin) { // This is already a tool/plugin, register it directly const toolPlugin = plugin as Plugin; this.tools.set(toolPlugin.name, toolPlugin); PluginManager.register(toolPlugin); } } } } // Helper method to safely get the model getModel(): ProviderModel { if (!this.config.model) { throw new Error("No model specified for agent"); } return this.config.model; } // Get the provider instance getProvider(): ProviderInstance | undefined { return this.config.provider; } // Memory access methods async getHistory(sessionId: string, limit?: number): Promise<any[]> { validateRequiredParam(sessionId, "sessionId", "getHistory"); return await this.memory.getBySession(sessionId, limit); } async clearHistory(sessionId: string): Promise<void> { validateRequiredParam(sessionId, "sessionId", "clearHistory"); await this.memory.clear(sessionId); } async addToMemory(params: { sessionId: string; role: 'user' | 'assistant' | 'system'; content: string; metadata?: Record<string, unknown>; }): Promise<string> { validateRequiredParam(params.sessionId, "params.sessionId", "addToMemory"); validateRequiredParam(params.role, "params.role", "addToMemory"); validateRequiredParam(params.content, "params.content", "addToMemory"); return await this.memory.add({ agentId: this.id, sessionId: params.sessionId, role: params.role, content: params.content, metadata: params.metadata || {} }); } // List all sessions for this agent async listSessions(limit?: number): Promise<{ sessionId: string; lastMessage?: string; messageCount: number; lastActivity: Date; metadata?: Record<string, unknown>; }[]> { try { // Get all sessions from memory for this agent const sessions = await this.memory.listSessions(this.id, limit); return sessions.map((session: any) => ({ sessionId: session.sessionId, lastMessage: session.lastMessage || session.content, messageCount: session.messageCount || 1, lastActivity: session.lastActivity || session.createdAt || new Date(), metadata: session.metadata || {} })); } catch (error) { logger.error(`Error listing sessions for agent ${this.id}:`, error); return []; } } // Chat method without streaming async chat(params: { message: string; sessionId?: string; systemPrompt?: string; temperature?: number; maxTokens?: number; metadata?: Record<string, unknown>; }): Promise<string> { validateRequiredParam(params.message, "params.message", "chat"); const { message, sessionId = `session-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`, systemPrompt = this.config.systemPrompt, temperature = 0.7, maxTokens = 2000, metadata = {} } = params; // Get conversation history const history = sessionId ? await this.getHistory(sessionId) : []; // Prepare messages for the model const messages = [ ...(systemPrompt ? [{ role: 'system' as const, content: systemPrompt }] : []), ...history.map((msg: any) => ({ role: (msg.role === 'user' ? 'user' : 'assistant') as 'user' | 'assistant', content: msg.content })), { role: 'user' as const, content: message } ]; // Get response from model const model = this.getModel(); const response = await model.complete(messages, { temperature, maxTokens }); const responseContent = typeof response === 'string' ? response : response.content; // Save to memory await this.addToMemory({ sessionId, role: 'user', content: message, metadata }); await this.addToMemory({ sessionId, role: 'assistant', content: responseContent, metadata }); return responseContent; } // Streaming chat method async streamChat(params: { message: string; sessionId?: string; systemPrompt?: string; temperature?: number; maxTokens?: number; metadata?: Record<string, unknown>; onChunk?: (chunk: string) => void; }): Promise<string> { validateRequiredParam(params.message, "params.message", "streamChat"); const { message, sessionId = `session-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`, systemPrompt = this.config.systemPrompt, temperature = 0.7, maxTokens = 2000, metadata = {}, onChunk } = params; // Get conversation history const history = sessionId ? await this.getHistory(sessionId) : []; // Prepare messages for the model const messages = [ ...(systemPrompt ? [{ role: 'system' as const, content: systemPrompt }] : []), ...history.map((msg: any) => ({ role: (msg.role === 'user' ? 'user' : 'assistant') as 'user' | 'assistant', content: msg.content })), { role: 'user' as const, content: message } ]; let fullResponse = ''; const model = this.getModel(); const provider = this.getProvider(); // Try to get OpenAI client for real streaming const openaiClient = (provider as any)?.client || (model as any)?.client; if (openaiClient && openaiClient.chat && openaiClient.chat.completions) { logger.debug(`Agent ${this.config.name}: Using real OpenAI streaming`); // Use OpenAI streaming directly for incremental chunks const stream = await openaiClient.chat.completions.create({ model: model.name || 'gpt-4o-mini', messages: messages, stream: true, temperature, max_tokens: maxTokens }); for await (const chunk of stream) { const content = chunk.choices[0]?.delta?.content || ''; if (content) { fullResponse += content; if (onChunk) { onChunk(content); // Send only the new chunk } } } } else if (model.complete) { logger.debug(`Agent ${this.config.name}: Using simulated streaming`); // Fallback to complete method with simulated streaming const response = await model.complete(messages, { temperature, maxTokens }); const responseContent = typeof response === 'string' ? response : response.content; fullResponse = responseContent; if (onChunk) { // Simulate streaming by sending word by word const words = responseContent.split(' '); for (let i = 0; i < words.length; i++) { const word = words[i] + (i < words.length - 1 ? ' ' : ''); onChunk(word); // Small delay for realistic streaming effect await new Promise(resolve => setTimeout(resolve, 50)); } } } else { throw new Error('No suitable model method available for streaming'); } // Save to memory await this.addToMemory({ sessionId, role: 'user', content: message, metadata }); await this.addToMemory({ sessionId, role: 'assistant', content: fullResponse, metadata }); return fullResponse; } /** * Get available tool names * @returns Array of tool names available to the agent */ getAvailableTools(): string[] { return Array.from(this.tools.keys()); } addTool(tool: Plugin): void { // Validate required parameters validateRequiredParam(tool, "tool", "addTool"); validateRequiredParams( tool, ["name", "description", "execute"], "addTool" ); this.tools.set(tool.name, tool); } /** * Get the chat manager instance if available */ getChatManager(): ChatInstance | undefined { return this.chatManager; } /** * Set or update the chat manager instance */ setChatManager(chatManager: ChatInstance): void { this.chatManager = chatManager; } // Chat management methods async createChat(params: { chatId?: string; userId?: string; title?: string; metadata?: Record<string, unknown>; }): Promise<ChatMetadata> { if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } return await this.chatManager.createChat({ chatId: params.chatId, userId: params.userId, agentId: this.id, title: params.title, metadata: params.metadata }); } async getChat(chatId: string): Promise<ChatMetadata | null> { validateRequiredParam(chatId, "chatId", "getChat"); if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } return await this.chatManager.getChat(chatId); } async updateChat(chatId: string, updates: Partial<ChatMetadata>): Promise<void> { validateRequiredParam(chatId, "chatId", "updateChat"); if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } await this.chatManager.updateChat(chatId, updates); } async deleteChat(chatId: string): Promise<void> { validateRequiredParam(chatId, "chatId", "deleteChat"); if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } await this.chatManager.deleteChat(chatId); } async archiveChat(chatId: string): Promise<void> { validateRequiredParam(chatId, "chatId", "archiveChat"); if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } await this.chatManager.archiveChat(chatId); } async listChats(params?: { userId?: string; status?: 'active' | 'archived' | 'deleted'; limit?: number; offset?: number; }): Promise<ChatSummary[]> { if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } return await this.chatManager.listChats({ ...params, agentId: this.id }); } async searchChats(params: { query: string; userId?: string; limit?: number; }): Promise<ChatSummary[]> { validateRequiredParam(params.query, "params.query", "searchChats"); if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } return await this.chatManager.searchChats({ ...params, agentId: this.id }); } async getChatStats(params?: { userId?: string; }): Promise<{ totalChats: number; activeChats: number; archivedChats: number; totalMessages: number; }> { if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } return await this.chatManager.getChatStats({ ...params, agentId: this.id }); } // Enhanced chat methods with chat ID support async chatWithId(params: { message: string; chatId: string; userId?: string; systemPrompt?: string; temperature?: number; maxTokens?: number; metadata?: Record<string, unknown>; }): Promise<string> { validateRequiredParam(params.message, "params.message", "chatWithId"); validateRequiredParam(params.chatId, "params.chatId", "chatWithId"); if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } // Check if chat exists, if not create it const existingChat = await this.chatManager.getChat(params.chatId); if (!existingChat) { await this.chatManager.createChat({ chatId: params.chatId, userId: params.userId, agentId: this.id, metadata: params.metadata }); } return await this.chatManager.chat({ message: params.message, chatId: params.chatId, agentId: this.id, userId: params.userId, model: this.getModel(), systemPrompt: params.systemPrompt || this.config.systemPrompt, tools: Array.from(this.tools.values()), metadata: params.metadata, temperature: params.temperature, maxTokens: params.maxTokens }); } async streamChatWithId(params: { message: string; chatId: string; userId?: string; systemPrompt?: string; temperature?: number; maxTokens?: number; metadata?: Record<string, unknown>; onChunk?: (chunk: string) => void; }): Promise<string> { validateRequiredParam(params.message, "params.message", "streamChatWithId"); validateRequiredParam(params.chatId, "params.chatId", "streamChatWithId"); if (!this.chatManager) { throw new Error("Chat manager not configured for this agent"); } // Check if chat exists, if not create it const existingChat = await this.chatManager.getChat(params.chatId); if (!existingChat) { await this.chatManager.createChat({ chatId: params.chatId, userId: params.userId, agentId: this.id, metadata: params.metadata }); } // For now, use the regular chat method and simulate streaming // This can be enhanced later with true streaming support in ChatManager const response = await this.chatManager.chat({ message: params.message, chatId: params.chatId, agentId: this.id, userId: params.userId, model: this.getModel(), systemPrompt: params.systemPrompt || this.config.systemPrompt, tools: Array.from(this.tools.values()), metadata: params.metadata, temperature: params.temperature, maxTokens: params.maxTokens }); // Simulate streaming by calling onChunk with the full response if (params.onChunk) { params.onChunk(response); } return response; } } // Agent factory function export const createAgent: AgentFactory = async (config: AgentConfig) => { // Validate required parameters validateRequiredParam(config, "config", "createAgent"); validateRequiredParams( config, ["memory"], "createAgent" ); // Ensure either model or provider is specified if (!config.model && !config.provider) { throw new Error("Either 'model' or 'provider' must be specified in agent config"); } // Create a new agent instance const agent = new Agent(config); // Save agent to database try { // Use database from config if provided, otherwise create a new one const db = config.database || await createDatabase(); const tableNames = db.getTableNames(); // Ensure agents table exists await db.ensureTable(tableNames.agents, (table) => { table.string("id").primary(); table.string("name").notNullable(); table.text("description").nullable(); table.text("systemPrompt").nullable(); table.string("modelName").notNullable(); table.timestamp("createdAt").defaultTo(db.knex.fn.now()); table.timestamp("updatedAt").defaultTo(db.knex.fn.now()); table.json("configuration").nullable(); }); const agentsTable = db.getTable(tableNames.agents); // Check if agent already exists const existingAgent = await agentsTable.findOne({ id: agent.id }); if (!existingAgent) { // Save new agent await agentsTable.insert({ id: agent.id, name: agent.config.name, description: agent.config.description || null, systemPrompt: agent.config.systemPrompt || null, modelName: agent.config.model?.name || "unknown", createdAt: new Date(), updatedAt: new Date(), configuration: JSON.stringify({ hasTools: agent.getAvailableTools().length > 0, supportsTaskSystem: true, }), }); logger.agent(agent.config.name, `Agent saved to database with ID: ${agent.id}`); } else { // Update existing agent await agentsTable.update( { id: agent.id }, { name: agent.config.name, description: agent.config.description || null, systemPrompt: agent.config.systemPrompt || null, modelName: agent.config.model?.name || "unknown", updatedAt: new Date(), configuration: JSON.stringify({ hasTools: agent.getAvailableTools().length > 0, supportsTaskSystem: true, }), } ); logger.agent(agent.config.name, `Agent updated in database with ID: ${agent.id}`); } } catch (error) { logger.error("Error saving agent to database:", error); } return agent; };