openai-agents
A TypeScript library extending the OpenAI Node.js SDK for building highly customizable agents and simplifying 'function calling'. Easily create and manage tools to extend LLM capabilities.
import OpenAI from 'openai';
import {
    ChatCompletionError,
    DirectoryAccessError,
    FileImportError,
    FileReadError,
    FunctionCallError,
    InvalidToolError,
    MessageValidationError,
    RedisConnectionError,
    StorageError,
    ToolCompletionError,
    ToolNotFoundError,
    RedisKeyValidationError,
    ValidationError,
} from './errors';
import {
    importToolFunctions,
    loadToolsDirFunctions,
    ToolsRegistry,
} from './modules/tools-registry';
import { getCompletionsUsage, handleNResponses } from './utils';
import { AgentStorage } from './storage';
/**
* @class OpenAIAgent
* @description Extends the OpenAI API client to manage chat completions, tool interactions, and persistent storage of conversation history. Provides methods for creating chat completions, managing tools, and interacting with a Redis-backed store.
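* @example
* // Minimal usage sketch; the model name and prompt below are illustrative, not defaults.
* // Requires the OPENAI_API_KEY environment variable to be set.
* const agent = new OpenAIAgent({
*     model: 'gpt-4o-mini',
*     system_instruction: 'You are a helpful assistant.',
* });
* const result = await agent.createChatCompletion('Hello!');
* console.log(result.choices, result.total_usage);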
*/
export class OpenAIAgent extends OpenAI {
static REQUIRED_ENV_VARS = ['OPENAI_API_KEY'];
completionParams;
defaultHistoryMessages = null;
storage = null;
system_instruction = null;
/**
* @constructor
* @param {AgentOptions} agentOptions - Options for configuring the agent, including the model, system instruction, initial messages template, etc.
* @param {ClientOptions} [options] - Optional OpenAI client options.
* @throws {ValidationError} If the model is not specified in the agent options.
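* @example
* // Construction sketch; the model name, temperature, and client options shown are illustrative.
* const agent = new OpenAIAgent(
*     { model: 'gpt-4o-mini', temperature: 0.7 },
*     { maxRetries: 2 }, // standard OpenAI ClientOptions, passed to the underlying client
* );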
*/
constructor(agentOptions, options) {
OpenAIAgent.validateEnvironment();
super(options);
if (!agentOptions.model) {
throw new ValidationError('Model is required to initialize the agent instance');
}
if (agentOptions.system_instruction)
this.system_instruction = agentOptions.system_instruction;
delete agentOptions.system_instruction;
if (agentOptions.messages)
this.defaultHistoryMessages = agentOptions.messages;
delete agentOptions.messages;
this.completionParams = agentOptions;
}
get model() {
return this.completionParams.model;
}
set model(value) {
this.completionParams.model = value;
}
get temperature() {
return this.completionParams.temperature;
}
set temperature(value) {
this.completionParams.temperature = value;
}
get top_p() {
return this.completionParams.top_p;
}
set top_p(value) {
this.completionParams.top_p = value;
}
get max_completion_tokens() {
return this.completionParams.max_completion_tokens;
}
set max_completion_tokens(value) {
this.completionParams.max_completion_tokens = value;
}
get max_tokens() {
return this.completionParams.max_tokens;
}
set max_tokens(value) {
this.completionParams.max_tokens = value;
}
get n() {
return this.completionParams.n;
}
set n(value) {
this.completionParams.n = value;
}
get frequency_penalty() {
return this.completionParams.frequency_penalty;
}
set frequency_penalty(value) {
this.completionParams.frequency_penalty = value;
}
get presence_penalty() {
return this.completionParams.presence_penalty;
}
set presence_penalty(value) {
this.completionParams.presence_penalty = value;
}
get tool_choice() {
return this.completionParams.tool_choice;
}
set tool_choice(value) {
this.completionParams.tool_choice = value;
}
get parallel_tool_calls() {
return this.completionParams.parallel_tool_calls;
}
set parallel_tool_calls(value) {
this.completionParams.parallel_tool_calls = value;
}
get audioParams() {
return this.completionParams.audio;
}
set audioParams(value) {
this.completionParams.audio = value;
}
get response_format() {
return this.completionParams.response_format;
}
set response_format(value) {
this.completionParams.response_format = value;
}
get logit_bias() {
return this.completionParams.logit_bias;
}
set logit_bias(value) {
this.completionParams.logit_bias = value;
}
get logprobs() {
return this.completionParams.logprobs;
}
set logprobs(value) {
this.completionParams.logprobs = value;
}
get top_logprobs() {
return this.completionParams.top_logprobs;
}
set top_logprobs(value) {
this.completionParams.top_logprobs = value;
}
get metadata() {
return this.completionParams.metadata;
}
set metadata(value) {
this.completionParams.metadata = value;
}
get stop() {
return this.completionParams.stop;
}
set stop(value) {
this.completionParams.stop = value;
}
get modalities() {
return this.completionParams.modalities;
}
set modalities(value) {
this.completionParams.modalities = value;
}
get prediction() {
return this.completionParams.prediction;
}
set prediction(value) {
this.completionParams.prediction = value;
}
get seed() {
return this.completionParams.seed;
}
set seed(value) {
this.completionParams.seed = value;
}
get service_tier() {
return this.completionParams.service_tier;
}
set service_tier(value) {
this.completionParams.service_tier = value;
}
get store() {
return this.completionParams.store;
}
set store(value) {
this.completionParams.store = value;
}
/**
* Validates that required environment variables are set.
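* @throws {ValidationError} If any required environment variable is missing.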
*/
static validateEnvironment() {
const missingVars = OpenAIAgent.REQUIRED_ENV_VARS.filter((varName) => !process.env[varName]);
if (missingVars.length > 0) {
throw new ValidationError(`Missing required environment variables: ${missingVars.join(', ')}`);
}
}
/**
* Determines the system instruction message to use based on default and custom instructions.
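* @param {string | null} defaultInstruction - The agent's default system instruction, if any.
* @param {string} [customInstruction] - A per-request instruction that takes precedence over the default.
* @returns A system message whose content is the custom instruction when provided, otherwise the default (empty if neither is set).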
*/
handleSystemInstructionMessage(defaultInstruction, customInstruction) {
const systemInstructionMessage = {
role: 'system',
content: '',
};
if (defaultInstruction && !customInstruction) {
systemInstructionMessage.content = defaultInstruction;
}
else if (customInstruction) {
systemInstructionMessage.content = customInstruction;
}
return systemInstructionMessage;
}
/**
* Retrieves context messages from the default history and/or from persistent storage.
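* @param queryParams - The completion parameters for the current request; `user` is used as the storage key (falling back to 'default').
* @param {HistoryOptions} [historyOptions] - Options controlling how stored history is retrieved.
* @returns {Promise<ChatCompletionMessageParam[]>} The default history messages followed by any messages retrieved from storage.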
*/
async handleContextMessages(queryParams, historyOptions) {
const userId = queryParams.user ? queryParams.user : 'default';
let contextMessages = [];
if (this.defaultHistoryMessages) {
contextMessages = [...this.defaultHistoryMessages]; // copy so later pushes do not mutate the default history template
}
if (this.storage) {
const storedMessages = await this.storage.getChatHistory(userId, historyOptions);
contextMessages.push(...storedMessages);
}
console.log('Context length:', contextMessages.length);
return contextMessages;
}
/**
* Executes the functions called by the model and returns their responses.
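* @param responseMessage - The assistant message containing the `tool_calls` to execute.
* @param functions - A map of tool names to their implementations.
* @returns A promise resolving to `tool` role messages containing each function's JSON-serialized response.
* @throws {FunctionCallError} If a requested function is missing, receives malformed arguments, returns no response, or throws.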
*/
async callChosenFunctions(responseMessage, functions) {
if (!responseMessage.tool_calls?.length) {
throw new Error('No tool calls found in the response message');
}
const toolMessages = [];
for (const tool of responseMessage.tool_calls) {
const { id, function: { name, arguments: args }, } = tool;
try {
const currentFunction = functions[name];
if (!currentFunction) {
throw new Error(`Function '${name}' not found`);
}
let parsedArgs;
try {
parsedArgs = JSON.parse(args);
}
catch (error) {
console.error(error);
throw new Error(`Invalid arguments format for function '${name}': ${args}`);
}
const functionResponse = await Promise.resolve(currentFunction(parsedArgs));
if (functionResponse === undefined) {
throw new Error(`Function '${name}' returned no response`);
}
toolMessages.push({
tool_call_id: id,
role: 'tool',
content: JSON.stringify(functionResponse),
});
}
catch (error) {
toolMessages.push({
tool_call_id: id,
role: 'tool',
content: JSON.stringify({
error: error instanceof Error
? error.message
: 'Unknown error',
}),
});
throw new FunctionCallError(name, error instanceof Error ? error.message : 'Unknown error');
}
}
return toolMessages;
}
/**
* Handles the process of calling tools based on the model's response
* and making a subsequent API call with the tool responses.
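* @param toolCompletionOpts - The initial completion response, query parameters, accumulated messages, tool implementations, and history options.
* @returns {Promise<CompletionResult>} The combined result of the initial and follow-up completions.
* @throws {FunctionCallError | ToolCompletionError} If executing a tool or the follow-up completion fails.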
*/
async handleToolCompletion(toolCompletionOpts) {
const { response, queryParams, newMessages, toolFunctions, historyOptions, } = toolCompletionOpts;
if (!queryParams?.messages?.length)
queryParams.messages = [];
const responseMessage = response.choices[0].message;
queryParams.messages.push(responseMessage);
try {
const toolMessages = await this.callChosenFunctions(responseMessage, toolFunctions);
queryParams.messages.push(...toolMessages);
newMessages.push(...toolMessages);
const secondResponse = await this.chat.completions.create(queryParams);
const secondResponseMessage = secondResponse.choices[0].message;
newMessages.push(secondResponseMessage);
if (this.storage) {
await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions);
}
const responses = handleNResponses(secondResponse, queryParams);
return {
choices: responses,
total_usage: getCompletionsUsage(response, secondResponse),
completion_messages: newMessages,
completions: [response, secondResponse],
};
}
catch (error) {
if (error instanceof FunctionCallError)
throw error;
throw new ToolCompletionError(queryParams, error instanceof Error ? error : undefined);
}
}
/**
* Creates a chat completion, handles tool calls (if any), and manages conversation history.
*
* @param {string} message - The user's message.
* @param {CreateChatCompletionOptions} [completionOptions] - Options for the chat completion, including custom parameters, tool choices, and history management.
* @returns {Promise<CompletionResult>} A promise that resolves to the chat completion result.
* @throws {ChatCompletionError | StorageError | RedisConnectionError | RedisKeyValidationError | MessageValidationError | DirectoryAccessError | FileReadError | FileImportError | InvalidToolError | ToolNotFoundError | ToolCompletionError | FunctionCallError | ValidationError} If an error occurs during the completion process.
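* @example
* // Sketch using tools and per-request overrides; the tool name and values are illustrative.
* const result = await agent.createChatCompletion('What is the weather in Paris?', {
*     tool_choices: ['get_weather'],           // names of tools loaded via loadToolFuctions
*     custom_params: { temperature: 0 },       // merged over the agent's default parameters
*     system_instruction: 'Answer concisely.', // overrides the default system instruction
* });
* console.log(result.choices[0]);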
*/
async createChatCompletion(message, completionOptions = {}) {
const queryParams = {
...this.completionParams,
...completionOptions.custom_params,
};
const historyOptions = {
...this.storage?.historyOptions,
...completionOptions.history,
};
let storedMessages = [];
try {
storedMessages = await this.handleContextMessages(queryParams, historyOptions);
if (this.storage &&
completionOptions.tool_choices &&
historyOptions?.appended_messages)
storedMessages =
this.storage.removeOrphanedToolMessages(storedMessages);
const systemInstructionMessage = this.handleSystemInstructionMessage(this.system_instruction, completionOptions.system_instruction);
if (systemInstructionMessage.content) {
// Overwrites the default instruction if there is a new instruction in the current request
if (storedMessages[0]?.role === 'system')
storedMessages.shift();
storedMessages.unshift(systemInstructionMessage);
}
const newMessages = [
{ role: 'user', content: message },
];
storedMessages.push(...newMessages);
queryParams.messages = storedMessages;
let toolFunctions;
if (completionOptions.tool_choices?.length) {
const toolChoices = await importToolFunctions(completionOptions.tool_choices);
queryParams.tools = toolChoices.toolChoices;
toolFunctions = toolChoices.toolFunctions;
}
const response = await this.chat.completions.create(queryParams);
const responseMessage = response.choices[0].message;
newMessages.push(responseMessage);
if (responseMessage.tool_calls && toolFunctions) {
return await this.handleToolCompletion({
response,
queryParams,
newMessages,
toolFunctions,
historyOptions,
});
}
else {
if (this.storage) {
await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions);
}
const responses = handleNResponses(response, queryParams);
return {
choices: responses,
total_usage: getCompletionsUsage(response),
completion_messages: newMessages,
completions: [response],
};
}
}
catch (error) {
if (error instanceof StorageError ||
error instanceof RedisConnectionError ||
error instanceof RedisKeyValidationError ||
error instanceof MessageValidationError ||
error instanceof DirectoryAccessError ||
error instanceof FileReadError ||
error instanceof FileImportError ||
error instanceof InvalidToolError ||
error instanceof ToolNotFoundError ||
error instanceof ToolCompletionError ||
error instanceof FunctionCallError ||
error instanceof ValidationError)
throw error;
throw new ChatCompletionError(queryParams, error instanceof Error ? error : undefined);
}
}
/**
* Loads tool functions from the specified directory.
*
* @param {string} toolsDirPath - The path to the directory containing the tool functions.
* @returns {Promise<boolean>} A promise that resolves to true if the tools are loaded successfully.
* @throws {ValidationError} If the tools directory path is not provided or invalid.
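* @example
* // The directory path is illustrative; it should contain the agent's tool function modules.
* await agent.loadToolFuctions('./tools');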
*/
async loadToolFuctions(toolsDirPath) {
if (!toolsDirPath)
throw new ValidationError('Tools directory path required.');
await loadToolsDirFunctions(toolsDirPath);
ToolsRegistry.toolsDirPath = toolsDirPath;
return true;
}
/**
* Sets up the storage using a Redis client.
*
* @param {RedisClientType} client - The Redis client instance.
* @param {{ history: HistoryOptions }} [options] - Options for configuring history storage.
* @returns {Promise<boolean>} A promise that resolves to true if the storage is set up successfully.
* @throws {RedisConnectionError} If the Redis client is not provided.
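* @example
* // Sketch assuming the `redis` (node-redis v4) client; connection details are illustrative.
* import { createClient } from 'redis';
* const client = createClient({ url: 'redis://localhost:6379' });
* await client.connect();
* await agent.setStorage(client);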
*/
async setStorage(client, options) {
if (!client)
throw new RedisConnectionError('Instance of Redis is required.');
this.storage = new AgentStorage(client);
if (options?.history)
this.storage.historyOptions = options.history;
return true;
}
/**
* Deletes the chat history for a given user.
*
* @param {string} userId - The ID of the user whose history should be deleted.
* @returns {Promise<boolean>} A promise that resolves to true if the history is deleted successfully.
* @throws {RedisConnectionError} If the storage is not initialized.
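* @example
* await agent.deleteChatHistory('user-123'); // the user ID is illustrative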
*/
async deleteChatHistory(userId) {
if (!this.storage) {
throw new RedisConnectionError('Agent storage is not initialized.');
}
await this.storage.deleteHistory(userId);
return true;
}
/**
* Retrieves the chat history for a given user.
*
* @param {string} userId - The ID of the user whose history should be retrieved.
* @param {HistoryOptions} [options] - Options for retrieving history.
* @returns {Promise<ChatCompletionMessageParam[]>} A promise that resolves to an array of chat messages.
* @throws {RedisConnectionError} If the storage is not initialized.
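* @example
* const history = await agent.getChatHistory('user-123'); // the user ID is illustrative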
*/
async getChatHistory(userId, options) {
if (!this.storage)
throw new RedisConnectionError('Agent storage is not initialized.');
const messages = await this.storage.getChatHistory(userId, options);
return messages;
}
}