UNPKG

openai-agents

Version:

A TypeScript library extending the OpenAI Node.js SDK for building highly customizable agents and simplifying 'function calling'. Easily create and manage tools to extend LLM capabilities.

1,017 lines (1,010 loc) 39 kB
import OpenAI__default from 'openai'; export * from 'openai'; import path from 'path'; import * as fs from 'fs/promises'; // Base Error class BaseError extends Error { cause; constructor(message, cause) { super(message); this.cause = cause; this.name = this.constructor.name; Error.captureStackTrace(this, this.constructor); } } // Validation Errors class ValidationError extends BaseError { constructor(message, cause) { super(message, cause); } } class MessageValidationError extends ValidationError { constructor(message, cause) { super(message, cause); } } // File System Errors class FileSystemError extends BaseError { path; constructor(message, path, cause) { super(`${message}: ${path}`, cause); this.path = path; } } class DirectoryAccessError extends FileSystemError { constructor(dirPath, cause) { super('Unable to access directory', dirPath, cause); } } class FileReadError extends FileSystemError { constructor(filePath, cause) { super('Error reading file', filePath, cause); } } class FileImportError extends FileSystemError { constructor(filePath, cause) { super('Error importing file', filePath, cause); } } // Tool Errors class ToolError extends BaseError { toolName; constructor(message, toolName, cause) { super(`${message}: ${toolName}`, cause); this.toolName = toolName; } } class InvalidToolError extends ToolError { constructor(toolName, details, cause) { super(`Invalid tool found: ${details}`, toolName, cause); } } class ToolNotFoundError extends ToolError { constructor(toolName, cause) { super('Tool not found', toolName, cause); } } class FunctionCallError extends ToolError { constructor(functionName, details, cause) { super(`Error calling function: ${details}`, functionName, cause); } } // API Errors class APIError extends BaseError { payload; constructor(message, payload, cause) { super(`${message}. Payload: ${JSON.stringify(payload)}`, cause); this.payload = payload; } } class ToolCompletionError extends APIError { constructor(payload, cause) { super('Tool completion failed', payload, cause); } } class ChatCompletionError extends APIError { constructor(payload, cause) { super('Chat completion failed', payload, cause); } } // Storage Errors class StorageError extends BaseError { constructor(message, cause) { super(message, cause); } } class RedisError extends StorageError { constructor(message, cause) { super(`Redis error: ${message}`, cause); } } class RedisConnectionError extends RedisError { constructor(message, cause) { super(message, cause); } } class RedisKeyValidationError extends RedisError { constructor(message, cause) { super(`Key validation failed: ${message}`, cause); } } /** * @class ToolsRegistry * @description Singleton class for managing the tools registry. Holds the currently loaded tools. */ class ToolsRegistry { static instance = null; static toolsDirPath = null; /** * Gets the current instance of the tools registry. */ static getInstance() { return ToolsRegistry.instance; } /** * Sets the instance of the tools registry. */ static setInstance(tools) { ToolsRegistry.instance = tools; } } /** * Validates the function name, ensuring it meets OpenAI's requirements. */ const validateFunctionName = (name) => { if (!name || typeof name !== 'string') { throw new InvalidToolError(name, 'Function name must be a non-empty string'); } if (name.length > 64) { throw new InvalidToolError(name, 'Function name must not exceed 64 characters'); } if (!/^[a-zA-Z0-9_-]+$/.test(name)) { throw new InvalidToolError(name, 'Function name must contain only alphanumeric characters, underscores, and hyphens'); } }; /** * Validates the function parameters, ensuring they are a non-null object. */ const validateFunctionParameters = (params, name) => { if (!params || typeof params !== 'object') { throw new InvalidToolError(name, 'Function parameters must be a non-null object'); } }; /** * Validates the function definition, * checking name, description, parameters, and strict flag. */ const validateFunctionDefinition = (func) => { const { name, description, parameters, strict } = func; if (!func || typeof func !== 'object') { throw new InvalidToolError(name, 'Function definition must be a non-null object'); } validateFunctionName(name); if (description !== undefined && typeof description !== 'string') { throw new InvalidToolError(name, 'Function description must be a string when provided'); } if (parameters !== undefined) { validateFunctionParameters(parameters, name); } if (strict !== undefined && strict !== null && typeof strict !== 'boolean') { throw new InvalidToolError(name, 'Function strict flag must be a boolean when provided'); } }; /** * Validates a chat completion tool definition, * ensuring it has the correct type and a valid function definition. */ const validateChatCompletionTool = (tool) => { const { function: functionDefinition, function: { name }, type, } = tool; if (!tool || typeof tool !== 'object') { throw new InvalidToolError(name, 'Chat completion tool must be a non-null object'); } if (type !== 'function') { throw new InvalidToolError(name, 'Chat completion tool type must be "function"'); } validateFunctionDefinition(functionDefinition); }; /** * Validates the configuration of tools, ensuring that all defined tools have corresponding implementations. */ const validateToolConfiguration = (fnDefinitions, functions) => { for (const def of fnDefinitions) { const functionName = def.function.name; if (!functions[functionName]) { throw new ToolNotFoundError(`Missing function implementation for tool: ${functionName}`); } } }; /** * Loads tool files (both definitions and implementations) from a specified directory. * * @param {string} dirPath - The path to the directory containing tool files. * @returns {Promise<AgentTools>} A promise that resolves to the loaded agent tools. * @throws {DirectoryAccessError | FileReadError | FileImportError | InvalidToolError | ToolNotFoundError} If an error occurs during loading. */ const loadToolsDirFunctions = async (dirPath) => { try { // Validate directory access try { await fs.access(dirPath); } catch (error) { throw new DirectoryAccessError(dirPath, error instanceof Error ? error : undefined); } // Read directory contents let files; try { files = await fs.readdir(dirPath); } catch (error) { throw new FileReadError(dirPath, error instanceof Error ? error : undefined); } const toolDefinitions = []; const toolFunctions = {}; // Process each file for (const file of files) { if (!file.endsWith('.js') && !file.endsWith('.ts')) continue; const fullPath = path.join(dirPath, file); // Validate file status try { const stat = await fs.stat(fullPath); if (!stat.isFile()) continue; } catch (error) { throw new FileReadError(fullPath, error instanceof Error ? error : undefined); } // Import file contents let fileFunctions; try { fileFunctions = await import(fullPath); } catch (error) { throw new FileImportError(fullPath, error instanceof Error ? error : undefined); } // Process functions const funcs = fileFunctions.default || fileFunctions; for (const [fnName, fn] of Object.entries(funcs)) { try { if (typeof fn === 'function') { toolFunctions[fnName] = fn; } else { // Validate as tool definition validateChatCompletionTool(fn); toolDefinitions.push(fn); } } catch (error) { if (error instanceof InvalidToolError) throw error; throw new InvalidToolError(fnName, `Unexpected error validating tool: ${error instanceof Error ? error.message : 'Unknown error'}`); } } } // Validate final configuration validateToolConfiguration(toolDefinitions, toolFunctions); const tools = { toolDefinitions, toolFunctions }; ToolsRegistry.setInstance(tools); return tools; } catch (error) { if (error instanceof DirectoryAccessError || error instanceof FileReadError || error instanceof FileImportError || error instanceof InvalidToolError || error instanceof ToolNotFoundError) throw error; throw new Error(`Unexpected error loading tools: ${error instanceof Error ? error.message : 'Unknown error'}`); } }; /** * Imports and returns specific tool functions based on their names. * Loads tools from the directory if they haven't been loaded yet. * * @param {string[]} toolNames - An array of tool names to import. * @returns {Promise<ToolChoices>} A promise that resolves to the imported tool functions and choices. * @throws {ValidationError | ToolNotFoundError | InvalidToolError} If the tools directory path is not set or if any requested tools are missing. */ const importToolFunctions = async (toolNames) => { try { if (!ToolsRegistry.toolsDirPath) { throw new ValidationError('Tools directory path not set. Call loadToolsDirFunctions with your tools directory path first.'); } const tools = ToolsRegistry.getInstance() ?? (await loadToolsDirFunctions(ToolsRegistry.toolsDirPath)); const toolChoices = toolNames .map((toolName) => tools.toolDefinitions.find((tool) => tool.function.name === toolName)) .filter((tool) => tool !== undefined); const missingTools = toolNames.filter((name) => !toolChoices.some((tool) => tool.function.name === name)); if (missingTools.length > 0) { throw new ToolNotFoundError(`The following tools were not found: ${missingTools.join(', ')}`); } return { toolFunctions: tools.toolFunctions, toolChoices, }; } catch (error) { if (error instanceof DirectoryAccessError || error instanceof FileReadError || error instanceof FileImportError || error instanceof InvalidToolError || error instanceof ValidationError || error instanceof ToolNotFoundError) throw error; throw new Error(`Failed to import tool functions: ${error instanceof Error ? error.message : 'Unknown error'}`); } }; /** * Calculates the sum of prompt tokens, completion tokens, and total tokens from multiple `CompletionUsage` objects. */ const getTokensSum = (...usages) => { const initialUsage = { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, }; return usages.reduce((accumulator, currentUsage) => { return { prompt_tokens: accumulator.prompt_tokens + (currentUsage?.prompt_tokens ?? 0), completion_tokens: accumulator.completion_tokens + (currentUsage?.completion_tokens ?? 0), total_tokens: accumulator.total_tokens + (currentUsage?.total_tokens ?? 0), }; }, initialUsage); }; /** * Extracts and aggregates the `usage` information from multiple `ChatCompletion` objects. * * @param {...ChatCompletion} completions - One or more `ChatCompletion` objects. * @returns {CompletionUsage} A `CompletionUsage` object representing the aggregated usage data. * Returns an object with all properties set to 0 if no completions are provided or if none of them have a usage property. */ const getCompletionsUsage = (...completions) => { const usages = []; for (const completion of completions) { if (completion.usage) { usages.push(completion.usage); } } return getTokensSum(...usages); }; const handleNResponses = (response, queryParams) => { let responses = []; const responseMessage = response.choices[0].message; if (queryParams.n) { for (const choice of response.choices) { if (choice.message.content) responses.push(choice.message.content); } } else { responses = [responseMessage.content ?? 'Response not received.']; } return responses; }; const CONFIG = { USER_ID_MAX_LENGTH: 64, DEFAULT_CHAT_MAX_LENGTH: 100, KEY_PREFIX: 'chat:', DEFAULT_USER: 'default', }; /** * @class AgentStorage * @description Manages chat history and session metadata persistence using Redis. */ class AgentStorage { redisClient; historyOptions; constructor(client) { if (!client) { throw new ValidationError('Redis client must be provided'); } this.redisClient = client; } /** * Validates the user ID, returning a 'default' value if undefined * and throwing errors for invalid formats. */ validateUserId(userId) { if (!userId) return CONFIG.DEFAULT_USER; if (typeof userId !== 'string') { throw new RedisKeyValidationError('User ID must be a string'); } if (userId.length > CONFIG.USER_ID_MAX_LENGTH) { throw new RedisKeyValidationError(`User ID exceeds maximum length of ${CONFIG.USER_ID_MAX_LENGTH} characters`); } if (!/^[a-zA-Z0-9_-]+$/.test(userId)) { throw new RedisKeyValidationError('User ID contains invalid characters. Only alphanumeric, underscore, and hyphen are allowed'); } return userId; } /** * Generates the Redis key for a given user ID. */ getRedisKey(userId) { return `${CONFIG.KEY_PREFIX}${this.validateUserId(userId)}`; } /** * Filters out tool-related messages from the chat history. * * @param {ChatCompletionMessageParam[]} messages - The chat history messages. * @returns {ChatCompletionMessageParam[]} The filtered messages. */ removeToolMessages(messages) { return messages.filter((message) => message.role === 'user' || (message.role === 'assistant' && !message.tool_calls)); } /** * Removes tool messages that don't have a corresponding assistant or tool call ID. * * @param {ChatCompletionMessageParam[]} messages - The chat history messages. * @returns {ChatCompletionMessageParam[]} The filtered messages. */ removeOrphanedToolMessages(messages) { const toolCallIds = new Set(); const assistantCallIds = new Set(); messages.forEach((message) => { if ('tool_call_id' in message) { toolCallIds.add(message.tool_call_id); } else if ('tool_calls' in message) { if (message.tool_calls) message.tool_calls.forEach((toolCall) => { assistantCallIds.add(toolCall.id); }); } }); return messages.filter((message) => { if ('tool_call_id' in message) { return assistantCallIds.has(message.tool_call_id); } else if ('tool_calls' in message) { if (message.tool_calls) { message.tool_calls = message.tool_calls.filter((toolCall) => toolCallIds.has(toolCall.id)); if (!message.tool_calls.length) return false; } } return true; }); } filterMessages(messages, options) { let filteredMessages = [...messages]; if (filteredMessages[0].role === 'system') filteredMessages.shift(); if (options.remove_tool_messages) { filteredMessages = this.removeToolMessages(filteredMessages); } return filteredMessages; } async calculateHistoryLength(redisKey, messages, options) { let savedHistoryLength = 0; try { savedHistoryLength = await this.redisClient.lLen(redisKey); } catch (error) { throw new StorageError('Error getting history length', error instanceof Error ? error : undefined); } if (options.max_length && savedHistoryLength + messages.length > options.max_length) { const length = messages.length > options.max_length ? 0 : options.max_length - messages.length - 1; return length; } return CONFIG.DEFAULT_CHAT_MAX_LENGTH - messages.length - 1; } async saveChatHistory(messages, userId, options = {}) { try { const redisKey = this.getRedisKey(userId); const filteredMessages = this.filterMessages(messages, options); const length = await this.calculateHistoryLength(redisKey, filteredMessages, options); const multi = this.redisClient.multi(); multi.lTrim(redisKey, 0, length); try { for (const message of filteredMessages) { multi.lPush(redisKey, JSON.stringify(message)); } } catch (error) { throw new MessageValidationError(`Invalid message in storage: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined); } if (options.ttl) multi.expire(redisKey, options.ttl); await multi.exec(); } catch (error) { if (error instanceof MessageValidationError || error instanceof RedisKeyValidationError) throw error; throw new StorageError(`Failed to save chat history: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined); } } /** * Retrieves the chat history from Redis. * * @param {string} userId - The user ID. * @param {HistoryOptions} options - Options for retrieving the history. * @returns {Promise<ChatCompletionMessageParam[]>} The retrieved chat history. * @throws {StorageError} If retrieving the chat history fails. * @throws {MessageValidationError} If a message is invalid. * @throws {RedisKeyValidationError} If the user ID is invalid. */ async getChatHistory(userId, options = {}) { try { const key = this.getRedisKey(userId); const { appended_messages, remove_tool_messages, send_tool_messages, max_length, } = options; if (appended_messages === 0) return []; const messages = await this.redisClient.lRange(key, 0, appended_messages ? appended_messages - 1 : -1); if (!messages.length) return []; let parsedMessages = []; try { parsedMessages = messages .map((message) => { return JSON.parse(message); }) .reverse(); } catch (error) { throw new MessageValidationError(`Invalid message in storage: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined); } if (remove_tool_messages || send_tool_messages === false) return this.removeToolMessages(parsedMessages); if (max_length) return this.removeOrphanedToolMessages(parsedMessages); return parsedMessages; } catch (error) { if (error instanceof MessageValidationError || error instanceof RedisKeyValidationError) throw error; throw new StorageError(`Failed to retrieve stored messages: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined); } } /** * Deletes the chat history from Redis for a given user ID. * * @param {string} userId - The user ID. * @returns {Promise<number>} * @throws {StorageError} If deleting the chat history fails. * @throws {RedisKeyValidationError} If the user ID is invalid. */ async deleteHistory(userId) { if (!userId.trim()) { throw new ValidationError('User ID is required'); } try { const key = this.getRedisKey(userId); const result = await this.redisClient.del(key); return result > 0; } catch (error) { if (error instanceof RedisKeyValidationError) { throw error; } throw new StorageError(`Failed to delete chat history: ${error instanceof Error ? error.message : 'Unknown error occurred'}`, error instanceof Error ? error : undefined); } } } /** * @class OpenAIAgent * @description Extends the OpenAI API client to manage chat completions, tool interactions, and persistent storage of conversation history. Provides methods for creating chat completions, managing tools, and interacting with a Redis storage. */ class OpenAIAgent extends OpenAI__default { static REQUIRED_ENV_VARS = ['OPENAI_API_KEY']; completionParams; defaultHistoryMessages = null; storage = null; system_instruction = null; /** * @constructor * @param {AgentOptions} agentOptions - Options for configuring the agent, including the model, system instruction, initial messages template, etc. * @param {ClientOptions} [options] - Optional OpenAI client options. * @throws {ValidationError} If the model is not specified in the agent options. */ constructor(agentOptions, options) { OpenAIAgent.validateEnvironment(); super(options); if (!agentOptions.model) { throw new ValidationError('Model is required to initialize the agent instance'); } if (agentOptions.system_instruction) this.system_instruction = agentOptions.system_instruction; delete agentOptions.system_instruction; if (agentOptions.messages) this.defaultHistoryMessages = agentOptions.messages; delete agentOptions.messages; this.completionParams = agentOptions; } get model() { return this.completionParams.model; } set model(value) { this.completionParams.model = value; } get temperature() { return this.completionParams.temperature; } set temperature(value) { this.completionParams.temperature = value; } get top_p() { return this.completionParams.top_p; } set top_p(value) { this.completionParams.top_p = value; } get max_completion_tokens() { return this.completionParams.max_completion_tokens; } set max_completion_tokens(value) { this.completionParams.max_completion_tokens = value; } get max_tokens() { return this.completionParams.max_tokens; } set max_tokens(value) { this.completionParams.max_tokens = value; } get n() { return this.completionParams.n; } set n(value) { this.completionParams.n = value; } get frequency_penalty() { return this.completionParams.frequency_penalty; } set frequency_penalty(value) { this.completionParams.frequency_penalty = value; } get presence_penalty() { return this.completionParams.presence_penalty; } set presence_penalty(value) { this.completionParams.presence_penalty = value; } get tool_choice() { return this.completionParams.tool_choice; } set tool_choice(value) { this.completionParams.tool_choice = value; } get parallel_tool_calls() { return this.completionParams.parallel_tool_calls; } set parallel_tool_calls(value) { this.completionParams.parallel_tool_calls = value; } get audioParams() { return this.completionParams.audio; } set audioParams(value) { this.completionParams.audio = value; } get response_format() { return this.completionParams.response_format; } set response_format(value) { this.completionParams.response_format = value; } get logit_bias() { return this.completionParams.logit_bias; } set logit_bias(value) { this.completionParams.logit_bias = value; } get logprobs() { return this.completionParams.logprobs; } set logprobs(value) { this.completionParams.logprobs = value; } get top_logprobs() { return this.completionParams.top_logprobs; } set top_logprobs(value) { this.completionParams.top_logprobs = value; } get metadata() { return this.completionParams.metadata; } set metadata(value) { this.completionParams.metadata = value; } get stop() { return this.completionParams.stop; } set stop(value) { this.completionParams.stop = value; } get modalities() { return this.completionParams.modalities; } set modalities(value) { this.completionParams.modalities = value; } get prediction() { return this.completionParams.prediction; } set prediction(value) { this.completionParams.prediction = value; } get seed() { return this.completionParams.seed; } set seed(value) { this.completionParams.seed = value; } get service_tier() { return this.completionParams.service_tier; } set service_tier(value) { this.completionParams.service_tier = value; } get store() { return this.completionParams.store; } set store(value) { this.completionParams.store = value; } /** * Validates that required environment variables are set. */ static validateEnvironment() { const missingVars = OpenAIAgent.REQUIRED_ENV_VARS.filter((varName) => !process.env[varName]); if (missingVars.length > 0) { throw new ValidationError(`Missing required environment variables: ${missingVars.join(', ')}`); } } /** * Determines the system instruction message to use based on default and custom instructions. */ handleSystemInstructionMessage(defaultInstruction, customInstruction) { const systemInstructionMessage = { role: 'system', content: '', }; if (defaultInstruction && !customInstruction) { systemInstructionMessage.content = defaultInstruction; } else if (customInstruction) { systemInstructionMessage.content = customInstruction; } return systemInstructionMessage; } /** * Retrieves context messages from the default history and/or from persistent storage. */ async handleContextMessages(queryParams, historyOptions) { const userId = queryParams.user ? queryParams.user : 'default'; let contextMessages = []; if (this.defaultHistoryMessages) { contextMessages = this.defaultHistoryMessages; } if (this.storage) { const storedMessages = await this.storage.getChatHistory(userId, historyOptions); contextMessages.push(...storedMessages); } console.log('Context length:', contextMessages.length); return contextMessages; } /** * Executes the functions called by the model and returns their responses. */ async callChosenFunctions(responseMessage, functions) { if (!responseMessage.tool_calls?.length) { throw new Error('No tool calls found in the response message'); } const toolMessages = []; for (const tool of responseMessage.tool_calls) { const { id, function: { name, arguments: args }, } = tool; try { const currentFunction = functions[name]; if (!currentFunction) { throw new Error(`Function '${name}' not found`); } let parsedArgs; try { parsedArgs = JSON.parse(args); } catch (error) { console.error(error); throw new Error(`Invalid arguments format for function '${name}': ${args}`); } const functionResponse = await Promise.resolve(currentFunction(parsedArgs)); if (functionResponse === undefined) { throw new Error(`Function '${name}' returned no response`); } toolMessages.push({ tool_call_id: id, role: 'tool', content: JSON.stringify(functionResponse), }); } catch (error) { toolMessages.push({ tool_call_id: id, role: 'tool', content: JSON.stringify({ error: error instanceof Error ? error.message : 'Unknown error', }), }); throw new FunctionCallError(name, error instanceof Error ? error.message : 'Unknown error'); } } return toolMessages; } /** * Handles the process of calling tools based on the model's response * and making a subsequent API call with the tool responses. */ async handleToolCompletion(toolCompletionOpts) { const { response, queryParams, newMessages, toolFunctions, historyOptions, } = toolCompletionOpts; if (!queryParams?.messages?.length) queryParams.messages = []; const responseMessage = response.choices[0].message; queryParams.messages.push(responseMessage); try { const toolMessages = await this.callChosenFunctions(responseMessage, toolFunctions); queryParams.messages.push(...toolMessages); newMessages.push(...toolMessages); const secondResponse = await this.chat.completions.create(queryParams); const secondResponseMessage = secondResponse.choices[0].message; newMessages.push(secondResponseMessage); if (this.storage) { await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions); } const responses = handleNResponses(secondResponse, queryParams); return { choices: responses, total_usage: getCompletionsUsage(response, secondResponse), completion_messages: newMessages, completions: [response, secondResponse], }; } catch (error) { if (error instanceof FunctionCallError) throw error; throw new ToolCompletionError(queryParams, error instanceof Error ? error : undefined); } } /** * Creates a chat completion, handles tool calls (if any), and manages conversation history. * * @param {string} message - The user's message. * @param {CreateChatCompletionOptions} [completionOptions] - Options for the chat completion, including custom parameters, tool choices, and history management. * @returns {Promise<CompletionResult>} A promise that resolves to the chat completion result. * @throws {ChatCompletionError | StorageError | RedisConnectionError | RedisKeyValidationError | MessageValidationError | DirectoryAccessError | FileReadError | FileImportError | InvalidToolError | ToolNotFoundError | ToolCompletionError | FunctionCallError | ValidationError} If an error occurs during the completion process. */ async createChatCompletion(message, completionOptions = {}) { const queryParams = { ...this.completionParams, ...completionOptions.custom_params, }; const historyOptions = { ...this.storage?.historyOptions, ...completionOptions.history, }; let storedMessages = []; try { storedMessages = await this.handleContextMessages(queryParams, historyOptions); if (this.storage && completionOptions.tool_choices && historyOptions?.appended_messages) storedMessages = this.storage.removeOrphanedToolMessages(storedMessages); const systemImstructionMessage = this.handleSystemInstructionMessage(this.system_instruction, completionOptions.system_instruction); if (systemImstructionMessage.content) { // Overwrites the default instruction if there is a new instruction in the current request if (storedMessages[0]?.role === 'system') storedMessages.shift(); storedMessages.unshift(systemImstructionMessage); } const newMessages = [ { role: 'user', content: message }, ]; storedMessages.push(...newMessages); queryParams.messages = storedMessages; let toolFunctions; if (completionOptions.tool_choices?.length) { const toolChoices = await importToolFunctions(completionOptions.tool_choices); queryParams.tools = toolChoices.toolChoices; toolFunctions = toolChoices.toolFunctions; } const response = await this.chat.completions.create(queryParams); const responseMessage = response.choices[0].message; newMessages.push(responseMessage); if (responseMessage.tool_calls && toolFunctions) { return await this.handleToolCompletion({ response, queryParams, newMessages, toolFunctions, historyOptions, }); } else { if (this.storage) { await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions); } const responses = handleNResponses(response, queryParams); return { choices: responses, total_usage: getCompletionsUsage(response), completion_messages: newMessages, completions: [response], }; } } catch (error) { if (error instanceof StorageError || error instanceof RedisConnectionError || error instanceof RedisKeyValidationError || error instanceof MessageValidationError || error instanceof DirectoryAccessError || error instanceof FileReadError || error instanceof FileImportError || error instanceof InvalidToolError || error instanceof ToolNotFoundError || error instanceof ToolCompletionError || error instanceof FunctionCallError || error instanceof ValidationError) throw error; throw new ChatCompletionError(queryParams, error instanceof Error ? error : undefined); } } /** * Loads tool functions from the specified directory. * * @param {string} toolsDirPath - The path to the directory containing the tool functions. * @returns {Promise<boolean>} A promise that resolves to true if the tools are loaded successfully. * @throws {ValidationError} If the tools directory path is not provided or invalid. */ async loadToolFuctions(toolsDirPath) { if (!toolsDirPath) throw new ValidationError('Tools directory path required.'); await loadToolsDirFunctions(toolsDirPath); ToolsRegistry.toolsDirPath = toolsDirPath; return true; } /** * Sets up the storage using a Redis client. * * @param {RedisClientType} client - The Redis client instance. * @param {{ history: HistoryOptions }} [options] - Options for configuring history storage. * @returns {Promise<boolean>} A promise that resolves to true if the storage is set up successfully. * @throws {RedisConnectionError} If the Redis client is not provided. */ async setStorage(client, options) { if (!client) throw new RedisConnectionError('Instance of Redis is required.'); this.storage = new AgentStorage(client); if (options?.history) this.storage.historyOptions = options.history; return true; } /** * Deletes the chat history for a given user. * * @param {string} userId - The ID of the user whose history should be deleted. * @returns {Promise<boolean>} A promise that resolves to true if the history is deleted successfully. * @throws {RedisConnectionError} If the storage is not initialized. */ async deleteChatHistory(userId) { if (!this.storage) { throw new RedisConnectionError('Agent storage is not initalized.'); } await this.storage.deleteHistory(userId); return true; } /** * Retrieves the chat history for a given user. * * @param {string} userId - The ID of the user whose history should be retrieved. * @param {HistoryOptions} [options] - Options for retrieving history. * @returns {Promise<ChatCompletionMessageParam[]>} A promise that resolves to an array of chat messages. * @throws {RedisConnectionError} If the storage is not initialized. */ async getChatHistory(userId) { if (!this.storage) throw new RedisConnectionError('Agent storage is not initialized'); const messages = await this.storage.getChatHistory(userId); return messages; } } export { AgentStorage, OpenAIAgent, getCompletionsUsage };