UNPKG

ggai

Version:

OpenAI LLM Agent Interface

136 lines (131 loc) 3.95 kB
const { getChatCompletion } = require('./openai.js');
const chalk = require('chalk');
const { calculateCost } = require('./pricing.js');
const ask = require('./ask.js');

/**
 * Conversational agent that keeps a running message history, sends it through
 * `getChatCompletion`, executes any tool calls the model requests, and
 * accumulates the cost of every request in `this.cost`.
 */
class Agent {
  /**
   * @param {object} [options]
   * @param {string} [options.model='gpt-4o-mini'] - Model sent with each request; also used for cost calculation.
   * @param {number} [options.temperature=0] - Sampling temperature sent with each request.
   * @param {string} [options.systemPrompt] - When set, the history starts with this system message.
   * @param {number} [options.historyLimit=-1] - Number of trailing history messages sent per request;
   *   values <= 0 send the whole history. The system prompt may be re-added on top of this window.
   * @param {Array<object>} [options.config] - Tool definitions sent as the request's `tools` field
   *   (only sent when `tools` is also provided).
   * @param {Object<string, Function>} [options.tools] - Map of tool name -> implementation; each is
   *   called with the parsed JSON arguments of the model's tool call and may return a promise.
   * @param {boolean} [options.streamingIndicators=false] - Forwarded to `getChatCompletion`; also
   *   prints a "call complete" line after each tool call.
   * @param {boolean} [options.verbose=false] - Logs tool calls, their inputs and their results.
   */
  constructor({
    model = 'gpt-4o-mini',
    temperature = 0,
    systemPrompt = undefined,
    historyLimit = -1,
    config = undefined,
    tools = undefined,
    streamingIndicators = false,
    verbose = false,
  } = {}) {
    this.model = model;
    this.systemPrompt = systemPrompt;
    this.historyLimit = historyLimit;
    this.config = config;
    this.tools = tools;
    this.cost = 0;
    this.streamingIndicators = streamingIndicators;
    this.verbose = verbose;
    this.request = {
      model,
      messages: systemPrompt ? [{ role: 'system', content: systemPrompt }] : [],
      temperature,
      stream: true,
      tool_choice: tools ? 'auto' : undefined,
      // `tools` holds the implementations; `config` holds the schema the API expects.
      tools: tools ? config : undefined,
    };
  }

  /**
   * Sends the (possibly truncated) history to the model.
   * @returns {Promise<object>} Whatever `getChatCompletion` resolves to.
   */
  async getResponse() {
    let messages = this.historyLimit > 0
      ? this.request.messages.slice(-this.historyLimit)
      : this.request.messages;

    if (messages.length === this.historyLimit) {
      // A truncated window may start with tool responses whose originating
      // assistant `tool_calls` message was cut off; drop those orphans. If the
      // whole window is tool messages, the result is empty rather than a crash.
      const firstNonTool = messages.findIndex((message) => message.role !== 'tool');
      messages = firstNonTool === -1 ? [] : messages.slice(firstNonTool);

      // Truncation can also cut off the system prompt; put it back in front.
      const first = messages[0];
      if (this.systemPrompt && !(first && first.role === 'system')) {
        messages = [{ role: 'system', content: this.systemPrompt }, ...messages];
      }
    }
    return await getChatCompletion({ ...this.request, messages }, this.streamingIndicators);
  }

  /**
   * Runs every requested tool and appends one `tool` message per call to the
   * history. Per-call failures (bad JSON, unknown tool, tool throwing) are
   * reported back to the model as the tool's content instead of being thrown.
   * @param {Array<object>} toolCalls - Tool calls from the assistant message.
   */
  async processToolCalls(toolCalls) {
    if (this.verbose) {
      console.log(chalk.red(JSON.stringify(toolCalls, null, 2)));
      console.log();
    }

    for (const toolCall of toolCalls) {
      const toolCallId = toolCall.id;
      const { name, arguments: args } = toolCall.function;
      try {
        let input;
        try {
          input = JSON.parse(args);
        } catch (err) {
          throw new Error('Invalid input', { cause: err });
        }
        // NOTE(review): this blank line is printed regardless of `verbose`.
        console.log();
        if (this.verbose) {
          console.log();
          console.log(chalk.magenta(`${name}(${JSON.stringify(input)})`));
        }

        // `name` is chosen by the model: only dispatch to the map's own function
        // properties so inherited members (e.g. `constructor`) can't be invoked.
        const isKnownTool = Boolean(this.tools)
          && Object.prototype.hasOwnProperty.call(this.tools, name)
          && typeof this.tools[name] === 'function';
        if (!isKnownTool) throw new Error(`Unknown tool: ${name}`);

        // Called as a method so tools that rely on `this` keep working.
        const result = await this.tools[name](input);
        // JSON.stringify returns undefined for undefined/functions; a tool
        // message still needs string content.
        const serialized = JSON.stringify(result);
        this.request.messages.push({
          role: 'tool',
          tool_call_id: toolCallId,
          content: serialized === undefined ? 'null' : serialized,
        });

        if (this.streamingIndicators) {
          console.log();
          console.log(chalk.dim(`${name}() call complete`));
          console.log();
        }
        if (this.verbose) {
          console.log(chalk.green(JSON.stringify(result, null, 2)));
          console.log();
        }
      } catch (error) {
        console.log(`Tool call error: ${error}`);
        this.request.messages.push({
          role: 'tool',
          tool_call_id: toolCallId,
          content: `${name} tool call error: ${error}`,
        });
      }
    }
  }

  /**
   * Appends the user input (unless `skipInput`), queries the model, and keeps
   * resolving tool calls until the model answers with plain content.
   * @param {string} input - User message; required unless `skipInput` is true.
   * @param {boolean} [skipInput=false] - Re-query with the current history only.
   * @returns {Promise<string>} The final assistant content.
   * @throws {Error} When `input` is empty and `skipInput` is false.
   */
  async send(input, skipInput = false) {
    if (!skipInput) {
      if (!input) throw new Error('Input is required');
      this.request.messages.push({ role: 'user', content: input });
    }

    const data = await this.getResponse();
    const { choices, usage } = data;
    const { prompt_tokens, completion_tokens } = usage;
    const cost = calculateCost({ promptTokens: prompt_tokens, completionTokens: completion_tokens }, this.model);
    this.cost += cost;

    const response = choices[0];
    const { content, role, function_call, tool_calls } = response;
    // An empty `tool_calls` array is not a tool request, and sending one back
    // in the history is not a valid assistant message; normalize it away.
    const hasToolCalls = Array.isArray(tool_calls) && tool_calls.length > 0;
    const message = { role, content, function_call, tool_calls: hasToolCalls ? tool_calls : undefined };
    this.request.messages.push(message);

    if (!hasToolCalls) return content;

    try {
      await this.processToolCalls(tool_calls);
    } catch (err) {
      console.log(chalk.red(`Error processing tool calls: ${err}`));
    }
    // Let the model see the tool results and continue the turn.
    return await this.send('', true);
  }
}

module.exports = { Agent, ask };