UNPKG

@lamemind/react-agent-ts

Version:

Streaming ReAct agent in typescript with multiple LLM providers

165 lines (164 loc) 6.43 kB
import { AIMessageChunksCollector } from "./chunk-collector.js";
import { userMessage } from "./messages.js";

/** Default upper bound on LLM/tool round-trips for a single invocation. */
const MAX_ITERATIONS = 10;

/**
 * Agent that alternates between streaming a model response and executing the
 * tool calls found in that response, until a response contains no tool call,
 * the iteration limit is reached, or the state-change callback interrupts it.
 *
 * Message shapes used by this class:
 *  - assistant: `{ role: "assistant", content: [{ type: "text" | "tool_use", ... }] }`
 *  - tool result: `{ role: "tool", tool_call_id, content: <JSON string> }`
 */
export class ReActAgent {
    /**
     * @param model Model exposing `bindTools(tools)`; the bound model must expose `stream(messages)`.
     * @param tools Array of tools, each with a `name` and a `call(input)` method.
     * @param maxIterations Maximum number of loop iterations per invocation.
     */
    constructor(model, tools, maxIterations = MAX_ITERATIONS) {
        // Null-prototype map: a tool name requested by the model (e.g. "constructor")
        // must never resolve to an inherited Object.prototype member.
        this.toolsMap = Object.create(null);
        this.interrupted = false;
        this.isToolCallsComplete = false;
        this._duringRestore_isToolCallRequest = false;
        this.onStateChangeCallback = null;
        this.iteration = 0;
        this.model = model.bindTools(tools);
        this.tools = tools;
        for (const tool of this.tools) {
            this.toolsMap[tool.name] = tool;
        }
        this.maxIterations = maxIterations;
        this.messages = [];
        if (tools.length > 0) {
            console.log(`Tools: ${tools.map((tool) => tool.name).join(", ")}`);
        }
    }

    /**
     * Registers the callback invoked after every message appended by the agent.
     * @param callback Receives the `saveState()` snapshot; a truthy (awaited) return value interrupts the run.
     */
    onStateChange(callback) {
        this.onStateChangeCallback = callback;
    }

    /**
     * Calls the registered state-change callback, if any, and stores its
     * awaited result in `this.interrupted`.
     */
    async notifyStateChange() {
        if (!this.onStateChangeCallback) {
            return;
        }
        this.interrupted = Boolean(await this.onStateChangeCallback(this.saveState()));
        if (this.interrupted) {
            console.log("\nConversazione interrotta dall'utente.");
        }
    }

    /**
     * @returns Snapshot of the agent: a shallow copy of the messages, the
     * completion flag, the assistant text (only when completed) and the
     * iteration counter.
     */
    saveState() {
        return {
            messages: [...this.messages],
            completed: this.isToolCallsComplete,
            llmResponse: this.isToolCallsComplete ? this.extractAiTextResponse() : undefined,
            iteration: this.iteration,
        };
    }

    /**
     * Restores a snapshot produced by `saveState()` and resumes the loop.
     * If the last message is an assistant message ending with a `tool_use`
     * part, the first iteration skips the model call and runs the pending tools.
     * @param state Snapshot with `messages`, `completed` and `iteration`.
     * @returns The state returned by `run()`.
     */
    async invokeState(state) {
        this.messages = [...state.messages];
        // NOTE(review): run() resets this flag to false, so a snapshot that was
        // already completed triggers a new model call — confirm this is intended.
        this.isToolCallsComplete = state.completed;
        this.iteration = state.iteration;
        // Guard against an empty history: there is no last message to inspect.
        const lastMessage = this.messages.length > 0 ? this.messages[this.messages.length - 1] : null;
        const lastPart = lastMessage && lastMessage.role === "assistant" && lastMessage.content
            ? lastMessage.content[lastMessage.content.length - 1]
            : null;
        this._duringRestore_isToolCallRequest = Boolean(lastPart && lastPart.type === "tool_use");
        return this.run();
    }

    /**
     * Starts a run from a new user message or from a full message list.
     * @param messages A string (appended as a user message) or an array that replaces the history.
     * @returns The state returned by `run()`.
     */
    async invokeMessage(messages) {
        if (typeof messages === "string") {
            this.messages.push(userMessage(messages));
        } else {
            // Copy so later pushes do not mutate the caller's array (same as invokeState).
            this.messages = [...messages];
        }
        // The iteration budget applies to one invocation; without this reset it
        // would be consumed across turns and the loop would eventually never run.
        this.iteration = 0;
        return this.run();
    }

    /**
     * Extracts the assistant's latest answer: the `text` parts of every
     * assistant message that follows the last user message, joined with
     * newlines. Other parts (tool calls, etc.) are skipped.
     * @returns The concatenated text, or an empty string when there is none.
     */
    extractAiTextResponse() {
        // Index of the first message after the last user message; 0 when the
        // history holds no user message, so the whole history is scanned.
        let firstAfterUser = 0;
        for (let i = this.messages.length - 1; i >= 0; i--) {
            if (this.messages[i].role === "user") {
                firstAfterUser = i + 1;
                break;
            }
        }
        return this.messages
            .slice(firstAfterUser)
            .filter((m) => m.role === "assistant")
            .flatMap((m) => m.content)
            .filter((c) => c.type === "text")
            .map((c) => c.text)
            .join("\n");
    }

    /**
     * Main loop: model call, then tool execution, repeated until a response
     * has no tool call, the callback interrupts, or `maxIterations` is reached.
     * @returns The `saveState()` snapshot taken when the loop stops.
     */
    async run() {
        this.interrupted = false;
        this.isToolCallsComplete = false;
        while (!this.isToolCallsComplete && this.iteration < this.maxIterations) {
            this.iteration++;
            if (this._duringRestore_isToolCallRequest) {
                // Restored mid-turn: the pending tool calls are already in the history.
                this._duringRestore_isToolCallRequest = false;
            } else {
                await this.callLLM();
                if (this.interrupted) {
                    break;
                }
            }
            await this.callTools();
            if (this.interrupted) {
                break;
            }
        }
        // Warn only when the limit is what stopped the loop, not when the run
        // completed or was interrupted on the last allowed iteration.
        if (!this.isToolCallsComplete && !this.interrupted && this.iteration >= this.maxIterations) {
            console.warn(`\nRaggiunto il numero massimo di iterazioni (${this.maxIterations})`);
        }
        return this.saveState();
    }

    /**
     * Streams one model response for the current history, appends the
     * collected message and notifies the state-change callback.
     */
    async callLLM() {
        const stream = await this.model.stream(this.messages);
        const collector = new AIMessageChunksCollector();
        await collector.consume(stream);
        this.messages.push(collector.formatMessage());
        await this.notifyStateChange();
    }

    /**
     * Executes, in order, every `tool_use` part of the last message and
     * appends one `tool` message per call. Marks the run as complete when the
     * last message contains no tool call. Stops early, leaving the remaining
     * calls unexecuted, if the callback interrupts.
     */
    async callTools() {
        const lastMessage = this.messages[this.messages.length - 1];
        const toolCalls = lastMessage.content.filter((c) => c.type === "tool_use");
        if (toolCalls.length === 0) {
            this.isToolCallsComplete = true;
            return;
        }
        let counter = 1;
        for (const call of toolCalls) {
            console.log(`Call #${counter++} ${call.name} ${JSON.stringify(call.input, null, 2)}`);
            const result = await this.executeToolCall(call);
            // JSON.stringify(undefined) is undefined; send an explicit "null" instead
            // of a tool message without content.
            const serialized = JSON.stringify(result);
            this.messages.push({
                role: "tool",
                tool_call_id: call.id,
                content: serialized === undefined ? "null" : serialized,
            });
            // Must be awaited: the callback decides `this.interrupted`, which is read next.
            await this.notifyStateChange();
            if (this.interrupted) {
                return;
            }
        }
    }

    /**
     * Runs a single tool call.
     * @param call A `tool_use` part with `name` and `input`.
     * @returns The tool's result, or an error string when the tool itself throws.
     * @throws {Error} When no tool is registered under `call.name`.
     */
    async executeToolCall(call) {
        const tool = this.toolsMap[call.name];
        if (!tool) {
            throw new Error(`Tool non trovato: ${call.name}`);
        }
        try {
            return await tool.call(call.input);
        } catch (error) {
            // Tool failures are reported back to the model instead of aborting the run.
            console.error(`Errore durante l'esecuzione del tool ${call.name}: ${error}`);
            return `@@@@ Errore durante l'esecuzione del tool ${call.name}: ${error} @@@@`;
        }
    }

    /** Prints the whole conversation to the console, one line per message. */
    dumpConversation() {
        console.log("\nConversazione completa:");
        console.log("-----------------------");
        this.messages.forEach((msg, i) => {
            if (msg.role === "user") {
                console.log(`[${i}] User: ${msg.content}`);
            } else if (msg.role === "assistant") {
                console.log(`[${i}] Assistant: ${JSON.stringify(msg.content)}...`);
            } else if (msg.role === "tool") {
                console.log(`[${i}] Tool (${msg.tool_call_id}): ${msg.content}`);
            } else {
                console.log(`[${i}] Messaggio sconosciuto: ${JSON.stringify(msg)}`);
            }
        });
        console.log("-----------------------");
    }
}