UNPKG

openai-agents

Version:

A TypeScript library extending the OpenAI Node.js SDK for building highly customizable agents and simplifying 'function calling'. Easily create and manage tools to extend LLM capabilities.

435 lines (434 loc) 17.7 kB
import OpenAI from 'openai';
import {
  ChatCompletionError,
  DirectoryAccessError,
  FileImportError,
  FileReadError,
  FunctionCallError,
  InvalidToolError,
  MessageValidationError,
  RedisConnectionError,
  StorageError,
  ToolCompletionError,
  ToolNotFoundError,
  RedisKeyValidationError,
  ValidationError,
} from './errors';
import {
  importToolFunctions,
  loadToolsDirFunctions,
  ToolsRegistry,
} from './modules/tools-registry';
import { getCompletionsUsage, handleNResponses } from './utils';
import { AgentStorage } from './storage';

/**
 * Error types that `createChatCompletion` rethrows unchanged.
 * Any other error raised during a completion is wrapped in a ChatCompletionError.
 */
const PASSTHROUGH_ERRORS = Object.freeze([
  StorageError,
  RedisConnectionError,
  RedisKeyValidationError,
  MessageValidationError,
  DirectoryAccessError,
  FileReadError,
  FileImportError,
  InvalidToolError,
  ToolNotFoundError,
  ToolCompletionError,
  FunctionCallError,
  ValidationError,
]);

/**
 * @class OpenAIAgent
 * @description Extends the OpenAI API client to manage chat completions, tool interactions,
 * and persistent storage of conversation history. Provides methods for creating chat
 * completions, managing tools, and interacting with a Redis storage.
 */
export class OpenAIAgent extends OpenAI {
  static REQUIRED_ENV_VARS = ['OPENAI_API_KEY'];

  /** Request parameters sent with every completion (agent options minus agent-only keys). */
  completionParams;

  /** Optional message template prepended to every conversation. */
  defaultHistoryMessages = null;

  /** AgentStorage instance, set via `setStorage`. */
  storage = null;

  /** Default system instruction, overridable per request. */
  system_instruction = null;

  /**
   * @constructor
   * @param {AgentOptions} agentOptions - Options for configuring the agent, including the model,
   *   system instruction, initial messages template, etc. The object is not modified.
   * @param {ClientOptions} [options] - Optional OpenAI client options. When `apiKey` is supplied,
   *   the OPENAI_API_KEY environment variable is not required.
   * @throws {ValidationError} If no API key is available, or if the model is not specified
   *   in the agent options.
   */
  constructor(agentOptions, options) {
    // Only insist on OPENAI_API_KEY when the caller did not pass a key explicitly.
    if (!options?.apiKey) OpenAIAgent.validateEnvironment();
    super(options);
    if (!agentOptions.model) {
      throw new ValidationError('Model is required to initialize the agent instance');
    }
    // Split agent-only settings from the request parameters without mutating
    // (or aliasing) the caller's options object.
    const { system_instruction, messages, ...completionParams } = agentOptions;
    if (system_instruction) this.system_instruction = system_instruction;
    if (messages) this.defaultHistoryMessages = messages;
    this.completionParams = completionParams;
  }

  // Accessors that read/write individual request parameters on `completionParams`.
  get model() { return this.completionParams.model; }
  set model(value) { this.completionParams.model = value; }
  get temperature() { return this.completionParams.temperature; }
  set temperature(value) { this.completionParams.temperature = value; }
  get top_p() { return this.completionParams.top_p; }
  set top_p(value) { this.completionParams.top_p = value; }
  get max_completion_tokens() { return this.completionParams.max_completion_tokens; }
  set max_completion_tokens(value) { this.completionParams.max_completion_tokens = value; }
  get max_tokens() { return this.completionParams.max_tokens; }
  set max_tokens(value) { this.completionParams.max_tokens = value; }
  get n() { return this.completionParams.n; }
  set n(value) { this.completionParams.n = value; }
  get frequency_penalty() { return this.completionParams.frequency_penalty; }
  set frequency_penalty(value) { this.completionParams.frequency_penalty = value; }
  get presence_penalty() { return this.completionParams.presence_penalty; }
  set presence_penalty(value) { this.completionParams.presence_penalty = value; }
  get tool_choice() { return this.completionParams.tool_choice; }
  set tool_choice(value) { this.completionParams.tool_choice = value; }
  get parallel_tool_calls() { return this.completionParams.parallel_tool_calls; }
  set parallel_tool_calls(value) { this.completionParams.parallel_tool_calls = value; }
  // Exposed as `audioParams`, stored under the request key `audio`.
  get audioParams() { return this.completionParams.audio; }
  set audioParams(value) { this.completionParams.audio = value; }
  get response_format() { return this.completionParams.response_format; }
  set response_format(value) { this.completionParams.response_format = value; }
  get logit_bias() { return this.completionParams.logit_bias; }
  set logit_bias(value) { this.completionParams.logit_bias = value; }
  get logprobs() { return this.completionParams.logprobs; }
  set logprobs(value) { this.completionParams.logprobs = value; }
  get top_logprobs() { return this.completionParams.top_logprobs; }
  set top_logprobs(value) { this.completionParams.top_logprobs = value; }
  get metadata() { return this.completionParams.metadata; }
  set metadata(value) { this.completionParams.metadata = value; }
  get stop() { return this.completionParams.stop; }
  set stop(value) { this.completionParams.stop = value; }
  get modalities() { return this.completionParams.modalities; }
  set modalities(value) { this.completionParams.modalities = value; }
  get prediction() { return this.completionParams.prediction; }
  set prediction(value) { this.completionParams.prediction = value; }
  get seed() { return this.completionParams.seed; }
  set seed(value) { this.completionParams.seed = value; }
  get service_tier() { return this.completionParams.service_tier; }
  set service_tier(value) { this.completionParams.service_tier = value; }
  get store() { return this.completionParams.store; }
  set store(value) { this.completionParams.store = value; }

  /**
   * Validates that required environment variables are set.
   *
   * @throws {ValidationError} If any variable in REQUIRED_ENV_VARS is unset or empty.
   */
  static validateEnvironment() {
    const missingVars = OpenAIAgent.REQUIRED_ENV_VARS.filter((varName) => !process.env[varName]);
    if (missingVars.length > 0) {
      throw new ValidationError(`Missing required environment variables: ${missingVars.join(', ')}`);
    }
  }

  /**
   * Builds the system message for a request.
   * A non-empty per-request instruction wins over the agent default.
   * The content is '' when neither is set.
   *
   * @param {string | null} defaultInstruction - The agent-level instruction.
   * @param {string} [customInstruction] - The request-level instruction.
   * @returns {{ role: 'system', content: string }}
   */
  handleSystemInstructionMessage(defaultInstruction, customInstruction) {
    return {
      role: 'system',
      content: customInstruction || defaultInstruction || '',
    };
  }

  /**
   * Retrieves context messages from the default history and/or from persistent storage.
   * Always returns a new array, so callers may mutate it freely.
   *
   * @param {object} queryParams - Request parameters; `user` selects the stored history ('default' if unset).
   * @param {HistoryOptions} historyOptions - Options forwarded to the storage lookup.
   * @returns {Promise<ChatCompletionMessageParam[]>}
   */
  async handleContextMessages(queryParams, historyOptions) {
    const userId = queryParams.user || 'default';
    // Copy the template: callers shift/unshift/push on the result, and that must
    // not leak into the default history shared across requests.
    const contextMessages = this.defaultHistoryMessages ? [...this.defaultHistoryMessages] : [];
    if (this.storage) {
      const storedMessages = await this.storage.getChatHistory(userId, historyOptions);
      contextMessages.push(...storedMessages);
    }
    return contextMessages;
  }

  /**
   * Executes the functions called by the model, sequentially and in call order.
   * Returns one `tool` role message per call.
   *
   * @param {ChatCompletionMessage} responseMessage - Assistant message containing `tool_calls`.
   * @param {Record<string, Function>} functions - Tool implementations keyed by function name.
   * @returns {Promise<ChatCompletionToolMessageParam[]>}
   * @throws {Error} If the message contains no tool calls.
   * @throws {FunctionCallError} On the first call that names an unknown function, has
   *   non-JSON arguments, throws, or returns undefined. Remaining calls are not executed.
   */
  async callChosenFunctions(responseMessage, functions) {
    if (!responseMessage.tool_calls?.length) {
      throw new Error('No tool calls found in the response message');
    }
    const toolMessages = [];
    for (const toolCall of responseMessage.tool_calls) {
      const {
        id,
        function: { name, arguments: args },
      } = toolCall;
      try {
        const currentFunction = functions[name];
        if (!currentFunction) {
          throw new Error(`Function '${name}' not found`);
        }
        let parsedArgs;
        try {
          parsedArgs = JSON.parse(args);
        } catch (error) {
          console.error(error);
          throw new Error(`Invalid arguments format for function '${name}': ${args}`);
        }
        // `await` accepts both synchronous and async tool implementations.
        const functionResponse = await currentFunction(parsedArgs);
        if (functionResponse === undefined) {
          throw new Error(`Function '${name}' returned no response`);
        }
        toolMessages.push({
          tool_call_id: id,
          role: 'tool',
          content: JSON.stringify(functionResponse),
        });
      } catch (error) {
        throw new FunctionCallError(name, error instanceof Error ? error.message : 'Unknown error');
      }
    }
    return toolMessages;
  }

  /**
   * Runs the tools requested in the first completion, appends their results to the
   * conversation, and makes a follow-up completion call with the tool responses.
   * Only one round of tools is executed: tool calls in the follow-up response are
   * not run here.
   *
   * @param {object} toolCompletionOpts
   * @param {ChatCompletion} toolCompletionOpts.response - First completion (contains tool calls).
   * @param {object} toolCompletionOpts.queryParams - Request params; `messages` is extended in place.
   * @param {ChatCompletionMessageParam[]} toolCompletionOpts.newMessages - Messages of this turn; extended in place.
   * @param {Record<string, Function>} toolCompletionOpts.toolFunctions - Tool implementations by name.
   * @param {HistoryOptions} toolCompletionOpts.historyOptions - Options for saving history.
   * @returns {Promise<CompletionResult>}
   * @throws {FunctionCallError} If a tool fails.
   * @throws {ToolCompletionError} For any other failure during this step.
   */
  async handleToolCompletion(toolCompletionOpts) {
    const { response, queryParams, newMessages, toolFunctions, historyOptions } = toolCompletionOpts;
    if (!queryParams?.messages?.length) queryParams.messages = [];
    const responseMessage = response.choices[0].message;
    queryParams.messages.push(responseMessage);
    try {
      const toolMessages = await this.callChosenFunctions(responseMessage, toolFunctions);
      queryParams.messages.push(...toolMessages);
      newMessages.push(...toolMessages);

      const secondResponse = await this.chat.completions.create(queryParams);
      const secondResponseMessage = secondResponse.choices[0].message;
      newMessages.push(secondResponseMessage);

      if (this.storage) {
        // NOTE(review): history is read under 'default' when `user` is unset
        // (handleContextMessages) but saved with `queryParams.user` as-is — confirm
        // AgentStorage.saveChatHistory applies the same fallback.
        await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions);
      }
      return {
        choices: handleNResponses(secondResponse, queryParams),
        total_usage: getCompletionsUsage(response, secondResponse),
        completion_messages: newMessages,
        completions: [response, secondResponse],
      };
    } catch (error) {
      if (error instanceof FunctionCallError) throw error;
      throw new ToolCompletionError(queryParams, error instanceof Error ? error : undefined);
    }
  }

  /**
   * Creates a chat completion, handles tool calls (if any), and manages conversation history.
   *
   * @param {string} message - The user's message.
   * @param {CreateChatCompletionOptions} [completionOptions] - Options for the chat completion,
   *   including custom parameters, tool choices, and history management.
   * @returns {Promise<CompletionResult>} A promise that resolves to the chat completion result.
   * @throws {ChatCompletionError | StorageError | RedisConnectionError | RedisKeyValidationError | MessageValidationError | DirectoryAccessError | FileReadError | FileImportError | InvalidToolError | ToolNotFoundError | ToolCompletionError | FunctionCallError | ValidationError}
   *   The listed domain errors are rethrown unchanged; anything else is wrapped in ChatCompletionError.
   */
  async createChatCompletion(message, completionOptions = {}) {
    const queryParams = {
      ...this.completionParams,
      ...completionOptions.custom_params,
    };
    const historyOptions = {
      ...this.storage?.historyOptions,
      ...completionOptions.history,
    };
    try {
      let storedMessages = await this.handleContextMessages(queryParams, historyOptions);
      if (this.storage && completionOptions.tool_choices && historyOptions?.appended_messages) {
        storedMessages = this.storage.removeOrphanedToolMessages(storedMessages);
      }

      const systemInstructionMessage = this.handleSystemInstructionMessage(
        this.system_instruction,
        completionOptions.system_instruction,
      );
      if (systemInstructionMessage.content) {
        // The resolved instruction replaces any system message already at the head of the context.
        if (storedMessages[0]?.role === 'system') storedMessages.shift();
        storedMessages.unshift(systemInstructionMessage);
      }

      const newMessages = [{ role: 'user', content: message }];
      storedMessages.push(...newMessages);
      queryParams.messages = storedMessages;

      let toolFunctions;
      if (completionOptions.tool_choices?.length) {
        const toolChoices = await importToolFunctions(completionOptions.tool_choices);
        queryParams.tools = toolChoices.toolChoices;
        toolFunctions = toolChoices.toolFunctions;
      }

      const response = await this.chat.completions.create(queryParams);
      const responseMessage = response.choices[0].message;
      newMessages.push(responseMessage);

      if (responseMessage.tool_calls && toolFunctions) {
        return await this.handleToolCompletion({
          response,
          queryParams,
          newMessages,
          toolFunctions,
          historyOptions,
        });
      }

      if (this.storage) {
        // NOTE(review): saved with `queryParams.user` as-is, while reads fall back to
        // 'default' — confirm AgentStorage.saveChatHistory applies the same fallback.
        await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions);
      }
      return {
        choices: handleNResponses(response, queryParams),
        total_usage: getCompletionsUsage(response),
        completion_messages: newMessages,
        completions: [response],
      };
    } catch (error) {
      if (PASSTHROUGH_ERRORS.some((ErrorType) => error instanceof ErrorType)) throw error;
      throw new ChatCompletionError(queryParams, error instanceof Error ? error : undefined);
    }
  }

  /**
   * Loads tool functions from the specified directory.
   * (Name kept for backward compatibility; see `loadToolFunctions`.)
   *
   * @param {string} toolsDirPath - The path to the directory containing the tool functions.
   * @returns {Promise<boolean>} A promise that resolves to true if the tools are loaded successfully.
   * @throws {ValidationError} If the tools directory path is not provided.
   */
  async loadToolFuctions(toolsDirPath) {
    if (!toolsDirPath) throw new ValidationError('Tools directory path required.');
    await loadToolsDirFunctions(toolsDirPath);
    ToolsRegistry.toolsDirPath = toolsDirPath;
    return true;
  }

  /**
   * Correctly spelled alias of `loadToolFuctions`.
   *
   * @param {string} toolsDirPath - The path to the directory containing the tool functions.
   * @returns {Promise<boolean>} Resolves to true if the tools are loaded successfully.
   * @throws {ValidationError} If the tools directory path is not provided.
   */
  async loadToolFunctions(toolsDirPath) {
    return this.loadToolFuctions(toolsDirPath);
  }

  /**
   * Sets up the storage using a Redis client.
   *
   * @param {RedisClientType} client - The Redis client instance.
   * @param {{ history: HistoryOptions }} [options] - Options for configuring history storage.
   * @returns {Promise<boolean>} A promise that resolves to true if the storage is set up successfully.
   * @throws {RedisConnectionError} If the Redis client is not provided.
   */
  async setStorage(client, options) {
    if (!client) throw new RedisConnectionError('Instance of Redis is required.');
    this.storage = new AgentStorage(client);
    if (options?.history) this.storage.historyOptions = options.history;
    return true;
  }

  /**
   * Deletes the chat history for a given user.
   *
   * @param {string} userId - The ID of the user whose history should be deleted.
   * @returns {Promise<boolean>} A promise that resolves to true if the history is deleted successfully.
   * @throws {RedisConnectionError} If the storage is not initialized.
   */
  async deleteChatHistory(userId) {
    if (!this.storage) {
      throw new RedisConnectionError('Agent storage is not initalized.');
    }
    await this.storage.deleteHistory(userId);
    return true;
  }

  /**
   * Retrieves the chat history for a given user.
   *
   * @param {string} userId - The ID of the user whose history should be retrieved.
   * @param {HistoryOptions} [options] - Options for retrieving history, forwarded to the storage.
   * @returns {Promise<ChatCompletionMessageParam[]>} A promise that resolves to an array of chat messages.
   * @throws {RedisConnectionError} If the storage is not initialized.
   */
  async getChatHistory(userId, options) {
    if (!this.storage) throw new RedisConnectionError('Agent storage is not initialized');
    return this.storage.getChatHistory(userId, options);
  }
}