ggai
Version:
OpenAI LLM Agent Interface
136 lines (131 loc) • 3.95 kB
JavaScript
const { getChatCompletion } = require('./openai.js');
const chalk = require('chalk');
const { calculateCost } = require('./pricing.js');
const ask = require('./ask.js');
/**
 * Stateful chat agent built on `getChatCompletion`.
 *
 * Keeps the running message list in `this.request.messages`, executes tool
 * calls requested by the model, feeds their results back, and accumulates the
 * cost reported by `calculateCost` in `this.cost`.
 */
class Agent {
  /**
   * @param {object} [options]
   * @param {string} [options.model='gpt-4o-mini'] - Model name sent with every
   *   request and passed to `calculateCost`.
   * @param {number} [options.temperature=0] - Temperature sent with every request.
   * @param {string} [options.systemPrompt] - When set, becomes the first
   *   (`system`) message of the history.
   * @param {number} [options.historyLimit=-1] - Maximum number of trailing
   *   messages sent per request; values <= 0 send the whole history.
   * @param {*} [options.config] - Sent as the request's `tools` field, but only
   *   when `tools` is provided. Presumably the tool schema list; confirm
   *   against ./openai.js.
   * @param {Object<string, Function>} [options.tools] - Tool implementations,
   *   looked up by the function name found in each tool call.
   * @param {boolean} [options.streamingIndicators=false] - Forwarded to
   *   `getChatCompletion`; also prints a notice after each completed tool call.
   * @param {boolean} [options.verbose=false] - Log tool calls, their parsed
   *   arguments and their results.
   */
  constructor({
    model = 'gpt-4o-mini',
    temperature = 0,
    systemPrompt = undefined,
    historyLimit = -1,
    config = undefined,
    tools = undefined,
    streamingIndicators = false,
    verbose = false
  } = {}) {
    this.model = model;
    this.systemPrompt = systemPrompt;
    this.historyLimit = historyLimit;
    this.config = config;
    this.tools = tools;
    this.cost = 0;
    this.streamingIndicators = streamingIndicators;
    this.verbose = verbose;
    // Base request; `messages` is the full history and is appended to in place.
    this.request = {
      model,
      messages: systemPrompt ? [{ role: 'system', content: systemPrompt }] : [],
      temperature,
      stream: true,
      tool_choice: tools ? 'auto' : undefined,
      tools: tools ? config : undefined,
    };
  }

  /**
   * Sends the current history (trimmed to `historyLimit` when positive) to
   * `getChatCompletion` and returns whatever it resolves to.
   *
   * When the window is exactly `historyLimit` long, leading `tool` messages
   * are dropped (their originating tool call fell outside the window) and the
   * system prompt is put back in front if it is no longer the first message.
   * The stored history itself is never modified here.
   *
   * @returns {Promise<*>} The result of `getChatCompletion`.
   */
  async getResponse() {
    const history = this.request.messages;
    let messages = this.historyLimit > 0 ? history.slice(-this.historyLimit) : history;

    if (messages.length === this.historyLimit) {
      // Skip orphaned tool responses at the start of the window. The bounds
      // check matters: a window made up only of tool messages must end up
      // empty instead of reading `.role` of `undefined`.
      let firstKept = 0;
      while (firstKept < messages.length && messages[firstKept].role === 'tool') {
        firstKept += 1;
      }
      if (firstKept > 0) {
        messages = messages.slice(firstKept);
      }

      const startsWithSystem = messages.length > 0 && messages[0].role === 'system';
      if (this.systemPrompt && !startsWithSystem) {
        messages = [{ role: 'system', content: this.systemPrompt }, ...messages];
      }
    }

    return await getChatCompletion({
      ...this.request,
      messages
    }, this.streamingIndicators);
  }

  /**
   * Runs each requested tool in order and appends one `tool` message per call
   * to the history: the JSON-serialized result on success, or an error
   * description when argument parsing or the tool itself fails.
   *
   * @param {Array<{id: string, function: {name: string, arguments: string}}>} toolCalls
   *   Tool calls as found on the assistant message.
   * @returns {Promise<void>}
   */
  async processToolCalls(toolCalls) {
    if (this.verbose) {
      console.log(chalk.red(JSON.stringify(toolCalls, null, 2)));
      console.log();
    }
    for (const toolCall of toolCalls) {
      const toolCallId = toolCall.id;
      const { name, arguments: args } = toolCall.function;
      try {
        let input;
        try {
          input = JSON.parse(args);
        } catch (err) {
          throw new Error('Invalid input');
        }
        console.log();
        if (this.verbose) {
          console.log();
          console.log(chalk.magenta(`${name}(${JSON.stringify(input)})`));
        }
        const result = await this.tools[name](input);
        // JSON.stringify(undefined) is `undefined`, which would leave the tool
        // message without any content; fall back to the JSON literal `null`.
        const serialized = JSON.stringify(result);
        this.request.messages.push({
          role: 'tool',
          tool_call_id: toolCallId,
          content: serialized === undefined ? 'null' : serialized
        });
        if (this.streamingIndicators) {
          console.log();
          console.log(chalk.dim(`${name}() call complete`));
          console.log();
        }
        if (this.verbose) {
          console.log(chalk.green(JSON.stringify(result, null, 2)));
          console.log();
        }
      } catch (error) {
        // Report the failure back through the history so the call still has
        // a matching tool message.
        console.log(`Tool call error: ${error}`);
        this.request.messages.push({
          role: 'tool',
          tool_call_id: toolCallId,
          content: `${name} tool call error: ${error}`
        });
      }
    }
  }

  /**
   * Appends a user message (unless `skipInput` is set), requests a completion,
   * records its cost, and stores the assistant message. If the assistant asked
   * for tool calls they are executed and `send` recurses (with `skipInput`)
   * until a reply without tool calls arrives.
   *
   * @param {string} input - User message; required unless `skipInput` is true.
   * @param {boolean} [skipInput=false] - Do not append a user message.
   * @returns {Promise<*>} The `content` of the final assistant message.
   * @throws {Error} 'Input is required' when `input` is falsy and `skipInput`
   *   is false.
   */
  async send(input, skipInput = false) {
    if (!skipInput) {
      if (!input) throw new Error('Input is required');
      this.request.messages.push({ role: 'user', content: input });
    }
    const data = await this.getResponse();
    const { choices, usage } = data;
    const { prompt_tokens, completion_tokens } = usage;
    const cost = calculateCost({
      promptTokens: prompt_tokens,
      completionTokens: completion_tokens
    }, this.model);
    this.cost += cost;
    const response = choices[0];
    const { content, role, function_call, tool_calls } = response;
    const message = {
      role,
      content,
      function_call,
      tool_calls
    };
    this.request.messages.push(message);
    if (tool_calls) {
      try {
        await this.processToolCalls(tool_calls);
      } catch (err) {
        // Best effort: keep the conversation going, but surface the cause.
        console.log(chalk.red(`Error processing tool calls: ${err}`));
      }
      return await this.send('', true);
    }
    return content;
  }
}
// Public API: the Agent class defined above, plus `ask` re-exported from ./ask.js.
module.exports = {
Agent, ask
}