/*
 * openai-agents
 * Version:
 * A TypeScript library extending the OpenAI Node.js SDK for building highly customizable agents and simplifying 'function calling'. Easily create and manage tools to extend LLM capabilities.
 * 1,017 lines (1,010 loc) • 39 kB
 * JavaScript
 */
import OpenAI__default from 'openai';
export * from 'openai';
import * as fs from 'fs/promises';
import path from 'path';
import { pathToFileURL } from 'url';
// Base Error
/**
 * Root of this library's error hierarchy.
 *
 * Keeps an optional underlying `cause`, names each instance after the concrete
 * subclass that was constructed, and drops the constructor frames from the
 * captured stack trace.
 *
 * @param {string} message - Human-readable description.
 * @param {Error} [cause] - Underlying error, if any.
 */
class BaseError extends Error {
    cause;
    constructor(message, cause) {
        super(message);
        // new.target is the most-derived class being instantiated.
        Object.assign(this, { cause, name: new.target.name });
        Error.captureStackTrace(this, new.target);
    }
}
// Validation Errors
/**
 * Raised when caller-supplied input or configuration is invalid.
 * Uses BaseError's `(message, cause)` constructor unchanged.
 */
class ValidationError extends BaseError {
}
/**
 * Raised when a chat message cannot be serialized for, or parsed from, storage.
 * Uses BaseError's `(message, cause)` constructor unchanged.
 */
class MessageValidationError extends ValidationError {
}
// File System Errors
/**
 * Base class for file-system failures. The offending path is appended to the
 * message and exposed as `path`.
 */
class FileSystemError extends BaseError {
    path;
    constructor(description, targetPath, reason) {
        super(`${description}: ${targetPath}`, reason);
        this.path = targetPath;
    }
}
/** Raised when a directory cannot be accessed. */
class DirectoryAccessError extends FileSystemError {
    constructor(directory, reason) {
        super('Unable to access directory', directory, reason);
    }
}
/** Raised when a file or directory cannot be read or stat'ed. */
class FileReadError extends FileSystemError {
    constructor(file, reason) {
        super('Error reading file', file, reason);
    }
}
/** Raised when a file cannot be loaded with dynamic `import()`. */
class FileImportError extends FileSystemError {
    constructor(file, reason) {
        super('Error importing file', file, reason);
    }
}
// Tool Errors
/**
 * Base class for tool-related failures. The tool name is appended to the
 * message and exposed as `toolName`.
 */
class ToolError extends BaseError {
    toolName;
    constructor(description, name, reason) {
        super(`${description}: ${name}`, reason);
        this.toolName = name;
    }
}
/** Raised when a tool definition fails validation; `details` names the broken rule. */
class InvalidToolError extends ToolError {
    constructor(name, details, reason) {
        const description = `Invalid tool found: ${details}`;
        super(description, name, reason);
    }
}
/** Raised when a requested tool (definition or implementation) cannot be found. */
class ToolNotFoundError extends ToolError {
    constructor(name, reason) {
        super('Tool not found', name, reason);
    }
}
/** Raised when executing a model-requested tool function fails. */
class FunctionCallError extends ToolError {
    constructor(functionName, details, reason) {
        const description = `Error calling function: ${details}`;
        super(description, functionName, reason);
    }
}
// API Errors
/**
 * Base class for failed OpenAI API interactions. The request payload is kept in
 * `payload` and JSON-encoded into the message.
 */
class APIError extends BaseError {
    payload;
    constructor(message, payload, cause) {
        super(`${message}. Payload: ${APIError.describePayload(payload)}`, cause);
        this.payload = payload;
    }
    /**
     * JSON-encodes a payload for an error message without ever throwing.
     * JSON.stringify throws on BigInt values and circular references. A throwing
     * error constructor would hide the failure being reported, so fall back to a
     * placeholder instead.
     *
     * @param {unknown} payload
     * @returns {string | undefined}
     */
    static describePayload(payload) {
        try {
            return JSON.stringify(payload);
        }
        catch {
            return '[unserializable payload]';
        }
    }
}
/** Raised when the follow-up completion after tool execution fails. */
class ToolCompletionError extends APIError {
    constructor(payload, cause) {
        super('Tool completion failed', payload, cause);
    }
}
/** Raised when a chat completion request fails. */
class ChatCompletionError extends APIError {
    constructor(payload, cause) {
        super('Chat completion failed', payload, cause);
    }
}
// Storage Errors
/**
 * Base class for persistence failures.
 * Uses BaseError's `(message, cause)` constructor unchanged.
 */
class StorageError extends BaseError {
}
/** Redis-specific failure; the message is prefixed with "Redis error: ". */
class RedisError extends StorageError {
    constructor(message, cause) {
        super(`Redis error: ${message}`, cause);
    }
}
/**
 * Raised when no usable Redis client or agent storage is available.
 * Uses RedisError's `(message, cause)` constructor unchanged.
 */
class RedisConnectionError extends RedisError {
}
/** Raised when a Redis key component (such as the user ID) is invalid. */
class RedisKeyValidationError extends RedisError {
    constructor(message, cause) {
        super(`Key validation failed: ${message}`, cause);
    }
}
/**
 * @class ToolsRegistry
 * @description Process-wide holder for the loaded tools and the directory they
 * were loaded from. All state is static; the class is never instantiated.
 */
class ToolsRegistry {
    /** Currently loaded tools, or null until something is registered. */
    static instance = null;
    /** Directory the tools were loaded from, or null if not set yet. */
    static toolsDirPath = null;
    /**
     * Returns the currently registered tools (null if none were registered).
     */
    static getInstance() {
        const registered = ToolsRegistry.instance;
        return registered;
    }
    /**
     * Replaces the registered tools with `tools`.
     */
    static setInstance(tools) {
        const registry = ToolsRegistry;
        registry.instance = tools;
    }
}
/**
 * Validates a tool's function name against OpenAI's requirements: a non-empty
 * string of at most 64 characters drawn from letters, digits, `_` and `-`.
 *
 * @param {unknown} name - Candidate function name.
 * @throws {InvalidToolError} When any rule is violated.
 */
const validateFunctionName = (name) => {
    const isNonEmptyString = Boolean(name) && typeof name === 'string';
    if (!isNonEmptyString) {
        throw new InvalidToolError(name, 'Function name must be a non-empty string');
    }
    const exceedsLimit = name.length > 64;
    if (exceedsLimit) {
        throw new InvalidToolError(name, 'Function name must not exceed 64 characters');
    }
    const allowedCharacters = /^[a-zA-Z0-9_-]+$/;
    if (!allowedCharacters.test(name)) {
        throw new InvalidToolError(name, 'Function name must contain only alphanumeric characters, underscores, and hyphens');
    }
};
/**
 * Ensures a tool's `parameters` value is a non-null object.
 *
 * @param {unknown} params - The `parameters` member of a function definition.
 * @param {string} name - Tool name, used in the error.
 * @throws {InvalidToolError} When `params` is null or not an object.
 */
const validateFunctionParameters = (params, name) => {
    const isNonNullObject = typeof params === 'object' && params !== null;
    if (isNonNullObject) {
        return;
    }
    throw new InvalidToolError(name, 'Function parameters must be a non-null object');
};
/**
 * Validates a function definition: the object itself, its name, the optional
 * description (string), the optional parameters (non-null object) and the
 * optional strict flag (boolean or null).
 *
 * @param {object} func - The `function` member of a chat-completion tool.
 * @throws {InvalidToolError} When any part of the definition is invalid.
 */
const validateFunctionDefinition = (func) => {
    // Check the container before destructuring it. Destructuring null/undefined
    // would throw a TypeError instead of the intended InvalidToolError.
    if (!func || typeof func !== 'object') {
        throw new InvalidToolError(undefined, 'Function definition must be a non-null object');
    }
    const { name, description, parameters, strict } = func;
    validateFunctionName(name);
    if (description !== undefined && typeof description !== 'string') {
        throw new InvalidToolError(name, 'Function description must be a string when provided');
    }
    if (parameters !== undefined) {
        validateFunctionParameters(parameters, name);
    }
    if (strict !== undefined &&
        strict !== null &&
        typeof strict !== 'boolean') {
        throw new InvalidToolError(name, 'Function strict flag must be a boolean when provided');
    }
};
/**
 * Validates a chat completion tool definition: it must be a non-null object
 * whose `type` is "function" and whose `function` member is a valid function
 * definition.
 *
 * @param {unknown} tool - Candidate tool definition (e.g. a module export).
 * @throws {InvalidToolError} When the tool is malformed.
 */
const validateChatCompletionTool = (tool) => {
    // Validate before reading any members, so malformed exports (numbers,
    // strings, objects without `function`) raise InvalidToolError, not TypeError.
    if (!tool || typeof tool !== 'object') {
        throw new InvalidToolError(undefined, 'Chat completion tool must be a non-null object');
    }
    const { type, function: functionDefinition } = tool;
    const name = functionDefinition?.name;
    if (type !== 'function') {
        throw new InvalidToolError(name, 'Chat completion tool type must be "function"');
    }
    if (!functionDefinition || typeof functionDefinition !== 'object') {
        throw new InvalidToolError(name, 'Function definition must be a non-null object');
    }
    validateFunctionDefinition(functionDefinition);
};
/**
 * Ensures every tool definition has a matching function implementation.
 *
 * @param {ChatCompletionTool[]} fnDefinitions - Loaded tool definitions.
 * @param {Record<string, Function>} functions - Implementations keyed by name.
 * @throws {ToolNotFoundError} For the first definition without an implementation;
 *   `toolName` is that definition's function name.
 */
const validateToolConfiguration = (fnDefinitions, functions) => {
    for (const def of fnDefinitions) {
        const functionName = def.function.name;
        // Own-property lookup, so names such as "toString" do not resolve to
        // inherited Object.prototype members and pass as implementations.
        const implementation = Object.hasOwn(functions, functionName)
            ? functions[functionName]
            : undefined;
        if (typeof implementation !== 'function') {
            throw new ToolNotFoundError(functionName);
        }
    }
};
/**
 * Loads tool files (both definitions and implementations) from a directory and
 * registers them in ToolsRegistry.
 *
 * Every `.js`/`.ts` file (excluding `.d.ts` declaration files) directly inside
 * `dirPath` is imported. If a module has a truthy `default` export, only that
 * export's members are inspected; otherwise the module's named exports are.
 * Exported functions become implementations keyed by export name. Every other
 * export must be a valid chat-completion tool definition.
 *
 * @param {string} dirPath - Directory containing tool files. Relative paths are
 *   resolved against the current working directory.
 * @returns {Promise<AgentTools>} A promise that resolves to the loaded agent tools.
 * @throws {DirectoryAccessError | FileReadError | FileImportError | InvalidToolError | ToolNotFoundError} If an error occurs during loading.
 */
const loadToolsDirFunctions = async (dirPath) => {
    try {
        // Resolve once, so fs calls and dynamic import() agree on the location.
        // fs resolves relative paths against process.cwd(), while import() would
        // treat them as module-relative or bare specifiers.
        let absoluteDir;
        try {
            absoluteDir = path.resolve(dirPath);
            await fs.access(absoluteDir);
        }
        catch (error) {
            throw new DirectoryAccessError(dirPath, error instanceof Error ? error : undefined);
        }
        let files;
        try {
            files = await fs.readdir(absoluteDir);
        }
        catch (error) {
            throw new FileReadError(dirPath, error instanceof Error ? error : undefined);
        }
        const toolDefinitions = [];
        const toolFunctions = {};
        for (const file of files) {
            // Only script sources; `.d.ts` files hold type declarations and cannot be imported.
            const isScript = file.endsWith('.js') || file.endsWith('.ts');
            if (!isScript || file.endsWith('.d.ts'))
                continue;
            const fullPath = path.join(absoluteDir, file);
            try {
                const stat = await fs.stat(fullPath);
                if (!stat.isFile())
                    continue;
            }
            catch (error) {
                throw new FileReadError(fullPath, error instanceof Error ? error : undefined);
            }
            let fileFunctions;
            try {
                // import() expects a URL specifier. A file URL is valid for absolute
                // paths on every platform, including Windows drive-letter paths.
                fileFunctions = await import(pathToFileURL(fullPath).href);
            }
            catch (error) {
                throw new FileImportError(fullPath, error instanceof Error ? error : undefined);
            }
            const funcs = fileFunctions.default || fileFunctions;
            for (const [fnName, fn] of Object.entries(funcs)) {
                try {
                    if (typeof fn === 'function') {
                        toolFunctions[fnName] = fn;
                    }
                    else {
                        validateChatCompletionTool(fn);
                        toolDefinitions.push(fn);
                    }
                }
                catch (error) {
                    if (error instanceof InvalidToolError)
                        throw error;
                    throw new InvalidToolError(fnName, `Unexpected error validating tool: ${error instanceof Error
                        ? error.message
                        : 'Unknown error'}`, error instanceof Error ? error : undefined);
                }
            }
        }
        validateToolConfiguration(toolDefinitions, toolFunctions);
        const tools = { toolDefinitions, toolFunctions };
        ToolsRegistry.setInstance(tools);
        return tools;
    }
    catch (error) {
        if (error instanceof DirectoryAccessError ||
            error instanceof FileReadError ||
            error instanceof FileImportError ||
            error instanceof InvalidToolError ||
            error instanceof ToolNotFoundError)
            throw error;
        // Preserve the original error for debugging.
        throw new Error(`Unexpected error loading tools: ${error instanceof Error ? error.message : 'Unknown error'}`, { cause: error });
    }
};
/**
 * Imports and returns specific tool functions based on their names.
 * Loads tools from the registered directory if they have not been loaded yet.
 *
 * @param {string[]} toolNames - An array of tool names to import.
 * @returns {Promise<ToolChoices>} A promise that resolves to all loaded tool
 *   implementations (`toolFunctions`) and the definitions of the requested tools
 *   in request order (`toolChoices`).
 * @throws {ValidationError | ToolNotFoundError | InvalidToolError} If the tools directory path is not set or if any requested tools are missing.
 *   For missing tools, `toolName` lists them comma-separated.
 */
const importToolFunctions = async (toolNames) => {
    try {
        if (!ToolsRegistry.toolsDirPath) {
            throw new ValidationError('Tools directory path not set. Call loadToolsDirFunctions with your tools directory path first.');
        }
        const tools = ToolsRegistry.getInstance() ??
            (await loadToolsDirFunctions(ToolsRegistry.toolsDirPath));
        const toolChoices = toolNames
            .map((toolName) => tools.toolDefinitions.find((tool) => tool.function.name === toolName))
            .filter((tool) => tool !== undefined);
        const missingTools = toolNames.filter((name) => !toolChoices.some((tool) => tool.function.name === name));
        if (missingTools.length > 0) {
            // ToolNotFoundError formats "Tool not found: <names>" itself; pass only the names.
            throw new ToolNotFoundError(missingTools.join(', '));
        }
        return {
            toolFunctions: tools.toolFunctions,
            toolChoices,
        };
    }
    catch (error) {
        if (error instanceof DirectoryAccessError ||
            error instanceof FileReadError ||
            error instanceof FileImportError ||
            error instanceof InvalidToolError ||
            error instanceof ValidationError ||
            error instanceof ToolNotFoundError)
            throw error;
        throw new Error(`Failed to import tool functions: ${error instanceof Error ? error.message : 'Unknown error'}`, { cause: error });
    }
};
/**
 * Adds up prompt, completion and total token counts across any number of
 * `CompletionUsage` objects. A null/undefined usage, or a missing field, counts as 0.
 *
 * @param {...CompletionUsage} usages
 * @returns {{prompt_tokens: number, completion_tokens: number, total_tokens: number}}
 */
const getTokensSum = (...usages) => {
    const totals = {
        prompt_tokens: 0,
        completion_tokens: 0,
        total_tokens: 0,
    };
    for (const usage of usages) {
        totals.prompt_tokens += usage?.prompt_tokens ?? 0;
        totals.completion_tokens += usage?.completion_tokens ?? 0;
        totals.total_tokens += usage?.total_tokens ?? 0;
    }
    return totals;
};
/**
 * Extracts and aggregates the `usage` information from multiple `ChatCompletion` objects.
 *
 * @param {...ChatCompletion} completions - One or more `ChatCompletion` objects.
 * @returns {CompletionUsage} A `CompletionUsage` object representing the aggregated usage data.
 * Returns an object with all properties set to 0 if no completions are provided or if none of them have a usage property.
 */
const getCompletionsUsage = (...completions) => {
    // Completions without a (truthy) usage are skipped.
    const reported = completions
        .filter((completion) => completion.usage)
        .map((completion) => completion.usage);
    return getTokensSum(...reported);
};
/**
 * Extracts text replies from a chat completion.
 *
 * When `queryParams.n` is truthy, returns the content of every choice, skipping
 * choices whose content is empty or null. Otherwise returns a single-element
 * array with the first choice's content, or 'Response not received.' when
 * that content is null/undefined.
 *
 * @param {ChatCompletion} response
 * @param {object} queryParams - Request parameters; only `n` is read.
 * @returns {string[]}
 */
const handleNResponses = (response, queryParams) => {
    // Read the first message up front; an empty `choices` array throws here.
    const { message: firstMessage } = response.choices[0];
    if (!queryParams.n) {
        return [firstMessage.content ?? 'Response not received.'];
    }
    return response.choices
        .map(({ message }) => message.content)
        .filter((content) => content);
};
/**
 * Storage configuration constants:
 * - USER_ID_MAX_LENGTH: maximum accepted user-ID length.
 * - DEFAULT_CHAT_MAX_LENGTH: history length cap used when no max_length is given.
 * - KEY_PREFIX: prefix of the Redis key holding a user's history.
 * - DEFAULT_USER: user ID used when none is supplied.
 * Frozen so that shared constants cannot be changed at runtime.
 */
const CONFIG = Object.freeze({
    USER_ID_MAX_LENGTH: 64,
    DEFAULT_CHAT_MAX_LENGTH: 100,
    KEY_PREFIX: 'chat:',
    DEFAULT_USER: 'default',
});
/**
 * @class AgentStorage
 * @description Persists per-user chat history in Redis.
 *
 * Each user's history is a Redis list under `chat:<userId>`. Messages are stored
 * JSON-encoded via LPUSH, so index 0 holds the newest message; readers reverse
 * the list to restore chronological order.
 */
class AgentStorage {
    redisClient;
    /** Default HistoryOptions; OpenAIAgent#setStorage assigns this when options are given. */
    historyOptions;
    /**
     * @param {RedisClientType} client - Connected Redis client.
     * @throws {ValidationError} If no client is provided.
     */
    constructor(client) {
        if (!client) {
            throw new ValidationError('Redis client must be provided');
        }
        this.redisClient = client;
    }
    /**
     * Validates the user ID. Returns CONFIG.DEFAULT_USER for a falsy ID. Otherwise
     * the ID must be a string of at most CONFIG.USER_ID_MAX_LENGTH characters
     * using only letters, digits, `_` and `-`.
     *
     * @throws {RedisKeyValidationError} If the ID is not acceptable.
     */
    validateUserId(userId) {
        if (!userId)
            return CONFIG.DEFAULT_USER;
        if (typeof userId !== 'string') {
            throw new RedisKeyValidationError('User ID must be a string');
        }
        if (userId.length > CONFIG.USER_ID_MAX_LENGTH) {
            throw new RedisKeyValidationError(`User ID exceeds maximum length of ${CONFIG.USER_ID_MAX_LENGTH} characters`);
        }
        if (!/^[a-zA-Z0-9_-]+$/.test(userId)) {
            throw new RedisKeyValidationError('User ID contains invalid characters. Only alphanumeric, underscore, and hyphen are allowed');
        }
        return userId;
    }
    /**
     * Generates the Redis key (`CONFIG.KEY_PREFIX` + validated user ID) for a user.
     */
    getRedisKey(userId) {
        return `${CONFIG.KEY_PREFIX}${this.validateUserId(userId)}`;
    }
    /**
     * Keeps only user messages and assistant messages without a `tool_calls`
     * value. System, tool and tool-calling assistant messages are dropped.
     *
     * @param {ChatCompletionMessageParam[]} messages - The chat history messages.
     * @returns {ChatCompletionMessageParam[]} The filtered messages.
     */
    removeToolMessages(messages) {
        return messages.filter((message) => message.role === 'user' ||
            (message.role === 'assistant' && !message.tool_calls));
    }
    /**
     * Removes tool-call pairs whose other half is missing. Tool messages without
     * a matching assistant tool call are dropped. Assistant tool calls without a
     * tool response are pruned, and an assistant message left with no answered
     * calls is dropped entirely.
     *
     * The input array and its message objects are not modified. Pruned assistant
     * messages are returned as shallow copies.
     *
     * @param {ChatCompletionMessageParam[]} messages - The chat history messages.
     * @returns {ChatCompletionMessageParam[]} The filtered messages.
     */
    removeOrphanedToolMessages(messages) {
        const answeredCallIds = new Set();
        const requestedCallIds = new Set();
        for (const message of messages) {
            if ('tool_call_id' in message) {
                answeredCallIds.add(message.tool_call_id);
            }
            else if ('tool_calls' in message && message.tool_calls) {
                for (const toolCall of message.tool_calls) {
                    requestedCallIds.add(toolCall.id);
                }
            }
        }
        const result = [];
        for (const message of messages) {
            if ('tool_call_id' in message) {
                if (requestedCallIds.has(message.tool_call_id))
                    result.push(message);
                continue;
            }
            if ('tool_calls' in message && message.tool_calls) {
                const answered = message.tool_calls.filter((toolCall) => answeredCallIds.has(toolCall.id));
                if (answered.length === 0)
                    continue;
                // Copy instead of mutating: callers may pass shared objects
                // (for example an agent's default history messages).
                result.push(answered.length === message.tool_calls.length
                    ? message
                    : { ...message, tool_calls: answered });
                continue;
            }
            result.push(message);
        }
        return result;
    }
    /**
     * Prepares messages for storage. Drops a leading system message (OpenAIAgent
     * re-adds its system instruction on every request) and, when
     * `options.remove_tool_messages` is set, all tool-related messages.
     * Returns a new array; an empty input yields an empty array.
     */
    filterMessages(messages, options) {
        let filteredMessages = [...messages];
        if (filteredMessages[0]?.role === 'system')
            filteredMessages.shift();
        if (options.remove_tool_messages) {
            filteredMessages = this.removeToolMessages(filteredMessages);
        }
        return filteredMessages;
    }
    /**
     * Returns the last list index to keep (LTRIM end index) after new messages
     * have been pushed. The list then holds at most the newest
     * `options.max_length` entries, or CONFIG.DEFAULT_CHAT_MAX_LENGTH when
     * max_length is not a positive number.
     *
     * `redisKey` and `messages` are no longer needed; they stay in the signature
     * for compatibility.
     *
     * @returns {Promise<number>}
     */
    async calculateHistoryLength(redisKey, messages, options) {
        const maxLength = options?.max_length > 0
            ? options.max_length
            : CONFIG.DEFAULT_CHAT_MAX_LENGTH;
        return maxLength - 1;
    }
    /**
     * Appends messages to a user's history in one MULTI transaction and trims the
     * list to the configured maximum length. Optionally sets a TTL (`options.ttl`).
     *
     * @throws {MessageValidationError} If a message cannot be JSON-serialized.
     * @throws {RedisKeyValidationError} If the user ID is invalid.
     * @throws {StorageError} For any other failure.
     */
    async saveChatHistory(messages, userId, options = {}) {
        try {
            const redisKey = this.getRedisKey(userId);
            const filteredMessages = this.filterMessages(messages, options);
            // Serialize before building the transaction, so a bad message never
            // leaves a half-built pipeline behind.
            let serialized;
            try {
                serialized = filteredMessages.map((message) => JSON.stringify(message));
            }
            catch (error) {
                throw new MessageValidationError(`Invalid message in storage: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined);
            }
            const lastIndex = await this.calculateHistoryLength(redisKey, filteredMessages, options);
            const multi = this.redisClient.multi();
            for (const entry of serialized) {
                multi.lPush(redisKey, entry);
            }
            // Trim after pushing. LPUSH puts the newest entry at index 0, so keeping
            // 0..lastIndex retains the newest entries. This holds even when the new
            // batch alone exceeds the cap.
            multi.lTrim(redisKey, 0, lastIndex);
            if (options.ttl)
                multi.expire(redisKey, options.ttl);
            await multi.exec();
        }
        catch (error) {
            if (error instanceof MessageValidationError ||
                error instanceof RedisKeyValidationError)
                throw error;
            throw new StorageError(`Failed to save chat history: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined);
        }
    }
    /**
     * Retrieves the chat history from Redis in chronological order.
     *
     * `appended_messages` limits the result to the newest N entries (0 returns []).
     * `remove_tool_messages` (or `send_tool_messages === false`) strips tool
     * traffic. Whenever the history may be truncated (`max_length` or
     * `appended_messages`), unpaired tool calls/responses are removed.
     *
     * @param {string} userId - The user ID.
     * @param {HistoryOptions} options - Options for retrieving the history.
     * @returns {Promise<ChatCompletionMessageParam[]>} The retrieved chat history.
     * @throws {StorageError} If retrieving the chat history fails.
     * @throws {MessageValidationError} If a message is invalid.
     * @throws {RedisKeyValidationError} If the user ID is invalid.
     */
    async getChatHistory(userId, options = {}) {
        try {
            const key = this.getRedisKey(userId);
            const { appended_messages, remove_tool_messages, send_tool_messages, max_length, } = options;
            if (appended_messages === 0)
                return [];
            const messages = await this.redisClient.lRange(key, 0, appended_messages ? appended_messages - 1 : -1);
            if (!messages.length)
                return [];
            let parsedMessages = [];
            try {
                // Stored newest-first (LPUSH); reverse to chronological order.
                parsedMessages = messages
                    .map((message) => JSON.parse(message))
                    .reverse();
            }
            catch (error) {
                throw new MessageValidationError(`Invalid message in storage: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined);
            }
            if (remove_tool_messages || send_tool_messages === false)
                return this.removeToolMessages(parsedMessages);
            // Truncation can split an assistant tool call from its tool response.
            if (max_length || appended_messages)
                return this.removeOrphanedToolMessages(parsedMessages);
            return parsedMessages;
        }
        catch (error) {
            if (error instanceof MessageValidationError ||
                error instanceof RedisKeyValidationError)
                throw error;
            throw new StorageError(`Failed to retrieve stored messages: ${error instanceof Error ? error.message : 'Unknown error'}`, error instanceof Error ? error : undefined);
        }
    }
    /**
     * Deletes the chat history from Redis for a given user ID.
     *
     * @param {string} userId - The user ID.
     * @returns {Promise<boolean>} True if a history key existed and was deleted.
     * @throws {ValidationError} If the user ID is missing, blank or not a string.
     * @throws {StorageError} If deleting the chat history fails.
     * @throws {RedisKeyValidationError} If the user ID is invalid.
     */
    async deleteHistory(userId) {
        if (typeof userId !== 'string' || !userId.trim()) {
            throw new ValidationError('User ID is required');
        }
        try {
            const key = this.getRedisKey(userId);
            const result = await this.redisClient.del(key);
            return result > 0;
        }
        catch (error) {
            if (error instanceof RedisKeyValidationError) {
                throw error;
            }
            throw new StorageError(`Failed to delete chat history: ${error instanceof Error
                ? error.message
                : 'Unknown error occurred'}`, error instanceof Error ? error : undefined);
        }
    }
}
/**
 * @class OpenAIAgent
 * @description Extends the OpenAI API client to manage chat completions, tool interactions, and persistent storage of conversation history. Provides methods for creating chat completions, managing tools, and interacting with a Redis storage.
 */
class OpenAIAgent extends OpenAI__default {
    static REQUIRED_ENV_VARS = ['OPENAI_API_KEY'];
    /** Default request parameters: the AgentOptions minus system_instruction and messages. */
    completionParams;
    /** Messages placed at the start of every request's context, or null. */
    defaultHistoryMessages = null;
    /** Redis-backed history store, or null until setStorage() is called. */
    storage = null;
    /** Default system instruction, or null. */
    system_instruction = null;
    /**
     * @constructor
     * @param {AgentOptions} agentOptions - Options for configuring the agent, including the model, system instruction, initial messages template, etc. The object is not modified.
     * @param {ClientOptions} [options] - Optional OpenAI client options.
     * @throws {ValidationError} If the model is not specified in the agent options, or if
     *   `options.apiKey` is not given and a variable in REQUIRED_ENV_VARS is unset.
     */
    constructor(agentOptions, options) {
        // An explicit apiKey makes the environment variable unnecessary, so only
        // enforce the environment check when none was passed.
        if (!options?.apiKey)
            OpenAIAgent.validateEnvironment();
        super(options);
        if (!agentOptions?.model) {
            throw new ValidationError('Model is required to initialize the agent instance');
        }
        // Split off agent-only settings without mutating the caller's object.
        const { system_instruction, messages, ...completionParams } = agentOptions;
        if (system_instruction)
            this.system_instruction = system_instruction;
        if (messages)
            this.defaultHistoryMessages = messages;
        this.completionParams = completionParams;
    }
    get model() {
        return this.completionParams.model;
    }
    set model(value) {
        this.completionParams.model = value;
    }
    get temperature() {
        return this.completionParams.temperature;
    }
    set temperature(value) {
        this.completionParams.temperature = value;
    }
    get top_p() {
        return this.completionParams.top_p;
    }
    set top_p(value) {
        this.completionParams.top_p = value;
    }
    get max_completion_tokens() {
        return this.completionParams.max_completion_tokens;
    }
    set max_completion_tokens(value) {
        this.completionParams.max_completion_tokens = value;
    }
    get max_tokens() {
        return this.completionParams.max_tokens;
    }
    set max_tokens(value) {
        this.completionParams.max_tokens = value;
    }
    get n() {
        return this.completionParams.n;
    }
    set n(value) {
        this.completionParams.n = value;
    }
    get frequency_penalty() {
        return this.completionParams.frequency_penalty;
    }
    set frequency_penalty(value) {
        this.completionParams.frequency_penalty = value;
    }
    get presence_penalty() {
        return this.completionParams.presence_penalty;
    }
    set presence_penalty(value) {
        this.completionParams.presence_penalty = value;
    }
    get tool_choice() {
        return this.completionParams.tool_choice;
    }
    set tool_choice(value) {
        this.completionParams.tool_choice = value;
    }
    get parallel_tool_calls() {
        return this.completionParams.parallel_tool_calls;
    }
    set parallel_tool_calls(value) {
        this.completionParams.parallel_tool_calls = value;
    }
    get audioParams() {
        return this.completionParams.audio;
    }
    set audioParams(value) {
        this.completionParams.audio = value;
    }
    get response_format() {
        return this.completionParams.response_format;
    }
    set response_format(value) {
        this.completionParams.response_format = value;
    }
    get logit_bias() {
        return this.completionParams.logit_bias;
    }
    set logit_bias(value) {
        this.completionParams.logit_bias = value;
    }
    get logprobs() {
        return this.completionParams.logprobs;
    }
    set logprobs(value) {
        this.completionParams.logprobs = value;
    }
    get top_logprobs() {
        return this.completionParams.top_logprobs;
    }
    set top_logprobs(value) {
        this.completionParams.top_logprobs = value;
    }
    get metadata() {
        return this.completionParams.metadata;
    }
    set metadata(value) {
        this.completionParams.metadata = value;
    }
    get stop() {
        return this.completionParams.stop;
    }
    set stop(value) {
        this.completionParams.stop = value;
    }
    get modalities() {
        return this.completionParams.modalities;
    }
    set modalities(value) {
        this.completionParams.modalities = value;
    }
    get prediction() {
        return this.completionParams.prediction;
    }
    set prediction(value) {
        this.completionParams.prediction = value;
    }
    get seed() {
        return this.completionParams.seed;
    }
    set seed(value) {
        this.completionParams.seed = value;
    }
    get service_tier() {
        return this.completionParams.service_tier;
    }
    set service_tier(value) {
        this.completionParams.service_tier = value;
    }
    get store() {
        return this.completionParams.store;
    }
    set store(value) {
        this.completionParams.store = value;
    }
    /**
     * Validates that required environment variables are set.
     *
     * @throws {ValidationError} Listing every missing variable.
     */
    static validateEnvironment() {
        const missingVars = OpenAIAgent.REQUIRED_ENV_VARS.filter((varName) => !process.env[varName]);
        if (missingVars.length > 0) {
            throw new ValidationError(`Missing required environment variables: ${missingVars.join(', ')}`);
        }
    }
    /**
     * Builds the system message for a request. A non-empty custom instruction
     * wins over the default; when neither is set, `content` is ''.
     */
    handleSystemInstructionMessage(defaultInstruction, customInstruction) {
        return {
            role: 'system',
            content: customInstruction || defaultInstruction || '',
        };
    }
    /**
     * Builds the context for a request: the agent's default history messages
     * followed by the user's stored history (when storage is configured).
     * Always returns a new array, which callers may mutate freely.
     */
    async handleContextMessages(queryParams, historyOptions) {
        const userId = queryParams.user ? queryParams.user : 'default';
        // Copy the defaults. Callers push and unshift onto the returned array, and
        // mutating defaultHistoryMessages would leak messages into every later request.
        const contextMessages = this.defaultHistoryMessages
            ? [...this.defaultHistoryMessages]
            : [];
        if (this.storage) {
            const storedMessages = await this.storage.getChatHistory(userId, historyOptions);
            contextMessages.push(...storedMessages);
        }
        return contextMessages;
    }
    /**
     * Executes the functions called by the model, in order, and returns one
     * `tool` message per call containing the JSON-encoded result.
     *
     * @throws {Error} If the response message contains no tool calls.
     * @throws {FunctionCallError} On the first call whose function is unknown,
     *   whose arguments are not valid JSON, that throws, or that returns undefined.
     */
    async callChosenFunctions(responseMessage, functions) {
        if (!responseMessage.tool_calls?.length) {
            throw new Error('No tool calls found in the response message');
        }
        const toolMessages = [];
        for (const tool of responseMessage.tool_calls) {
            const { id, function: { name, arguments: args }, } = tool;
            try {
                // Own-property lookup, so names such as "toString" do not resolve
                // to inherited Object.prototype members.
                const currentFunction = Object.hasOwn(functions, name)
                    ? functions[name]
                    : undefined;
                if (typeof currentFunction !== 'function') {
                    throw new Error(`Function '${name}' not found`);
                }
                let parsedArgs;
                try {
                    parsedArgs = JSON.parse(args);
                }
                catch (error) {
                    throw new Error(`Invalid arguments format for function '${name}': ${args}`, { cause: error });
                }
                const functionResponse = await currentFunction(parsedArgs);
                if (functionResponse === undefined) {
                    throw new Error(`Function '${name}' returned no response`);
                }
                toolMessages.push({
                    tool_call_id: id,
                    role: 'tool',
                    content: JSON.stringify(functionResponse),
                });
            }
            catch (error) {
                // Abort the whole tool round and surface the failure to the caller.
                throw new FunctionCallError(name, error instanceof Error ? error.message : 'Unknown error', error instanceof Error ? error : undefined);
            }
        }
        return toolMessages;
    }
    /**
     * Handles a single round of tool calls. It runs the functions requested in the
     * first response, sends their results back in a second request, and saves the
     * new messages. If the second response requests tools again, those calls are
     * not executed.
     *
     * @throws {FunctionCallError | StorageError | ValidationError} Passed through unchanged.
     * @throws {ToolCompletionError} For any other failure (e.g. the API call).
     */
    async handleToolCompletion(toolCompletionOpts) {
        const { response, queryParams, newMessages, toolFunctions, historyOptions, } = toolCompletionOpts;
        if (!queryParams?.messages?.length)
            queryParams.messages = [];
        const responseMessage = response.choices[0].message;
        queryParams.messages.push(responseMessage);
        try {
            const toolMessages = await this.callChosenFunctions(responseMessage, toolFunctions);
            queryParams.messages.push(...toolMessages);
            newMessages.push(...toolMessages);
            const secondResponse = await this.chat.completions.create(queryParams);
            const secondResponseMessage = secondResponse.choices[0].message;
            newMessages.push(secondResponseMessage);
            if (this.storage) {
                await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions);
            }
            const responses = handleNResponses(secondResponse, queryParams);
            return {
                choices: responses,
                total_usage: getCompletionsUsage(response, secondResponse),
                completion_messages: newMessages,
                completions: [response, secondResponse],
            };
        }
        catch (error) {
            // The library's own typed errors (function calls, storage, validation)
            // keep their type; only foreign errors are wrapped.
            if (error instanceof BaseError)
                throw error;
            throw new ToolCompletionError(queryParams, error instanceof Error ? error : undefined);
        }
    }
    /**
     * Creates a chat completion, handles tool calls (if any), and manages conversation history.
     *
     * @param {string} message - The user's message.
     * @param {CreateChatCompletionOptions} [completionOptions] - Options for the chat completion, including custom parameters, tool choices, and history management.
     * @returns {Promise<CompletionResult>} A promise that resolves to the chat completion result.
     * @throws {ChatCompletionError | StorageError | RedisConnectionError | RedisKeyValidationError | MessageValidationError | DirectoryAccessError | FileReadError | FileImportError | InvalidToolError | ToolNotFoundError | ToolCompletionError | FunctionCallError | ValidationError} If an error occurs during the completion process.
     */
    async createChatCompletion(message, completionOptions = {}) {
        const queryParams = {
            ...this.completionParams,
            ...completionOptions.custom_params,
        };
        const historyOptions = {
            ...this.storage?.historyOptions,
            ...completionOptions.history,
        };
        let storedMessages = [];
        try {
            storedMessages = await this.handleContextMessages(queryParams, historyOptions);
            if (this.storage &&
                completionOptions.tool_choices &&
                historyOptions?.appended_messages)
                storedMessages =
                    this.storage.removeOrphanedToolMessages(storedMessages);
            const systemInstructionMessage = this.handleSystemInstructionMessage(this.system_instruction, completionOptions.system_instruction);
            if (systemInstructionMessage.content) {
                // Overwrites the default instruction if there is a new instruction in the current request
                if (storedMessages[0]?.role === 'system')
                    storedMessages.shift();
                storedMessages.unshift(systemInstructionMessage);
            }
            const newMessages = [
                { role: 'user', content: message },
            ];
            storedMessages.push(...newMessages);
            queryParams.messages = storedMessages;
            let toolFunctions;
            if (completionOptions.tool_choices?.length) {
                const toolChoices = await importToolFunctions(completionOptions.tool_choices);
                queryParams.tools = toolChoices.toolChoices;
                toolFunctions = toolChoices.toolFunctions;
            }
            const response = await this.chat.completions.create(queryParams);
            const responseMessage = response.choices[0].message;
            newMessages.push(responseMessage);
            if (responseMessage.tool_calls && toolFunctions) {
                return await this.handleToolCompletion({
                    response,
                    queryParams,
                    newMessages,
                    toolFunctions,
                    historyOptions,
                });
            }
            if (this.storage) {
                await this.storage.saveChatHistory(newMessages, queryParams.user, historyOptions);
            }
            const responses = handleNResponses(response, queryParams);
            return {
                choices: responses,
                total_usage: getCompletionsUsage(response),
                completion_messages: newMessages,
                completions: [response],
            };
        }
        catch (error) {
            // Every typed error listed in @throws derives from BaseError; only
            // foreign errors (e.g. OpenAI API failures) are wrapped.
            if (error instanceof BaseError)
                throw error;
            throw new ChatCompletionError(queryParams, error instanceof Error ? error : undefined);
        }
    }
    /**
     * Loads tool functions from the specified directory.
     * (Historical spelling; loadToolFunctions is an alias.)
     *
     * @param {string} toolsDirPath - The path to the directory containing the tool functions.
     * @returns {Promise<boolean>} A promise that resolves to true if the tools are loaded successfully.
     * @throws {ValidationError} If the tools directory path is not provided or invalid.
     */
    async loadToolFuctions(toolsDirPath) {
        if (!toolsDirPath)
            throw new ValidationError('Tools directory path required.');
        await loadToolsDirFunctions(toolsDirPath);
        ToolsRegistry.toolsDirPath = toolsDirPath;
        return true;
    }
    /**
     * Correctly spelled alias of loadToolFuctions.
     *
     * @param {string} toolsDirPath - The path to the directory containing the tool functions.
     * @returns {Promise<boolean>} Resolves to true if the tools are loaded successfully.
     */
    async loadToolFunctions(toolsDirPath) {
        return this.loadToolFuctions(toolsDirPath);
    }
    /**
     * Sets up the storage using a Redis client.
     *
     * @param {RedisClientType} client - The Redis client instance.
     * @param {{ history: HistoryOptions }} [options] - Options for configuring history storage.
     * @returns {Promise<boolean>} A promise that resolves to true if the storage is set up successfully.
     * @throws {RedisConnectionError} If the Redis client is not provided.
     */
    async setStorage(client, options) {
        if (!client)
            throw new RedisConnectionError('Instance of Redis is required.');
        this.storage = new AgentStorage(client);
        if (options?.history)
            this.storage.historyOptions = options.history;
        return true;
    }
    /**
     * Deletes the chat history for a given user.
     *
     * @param {string} userId - The ID of the user whose history should be deleted.
     * @returns {Promise<boolean>} Resolves to true once the delete request completes, whether or not a history existed.
     * @throws {RedisConnectionError} If the storage is not initialized.
     */
    async deleteChatHistory(userId) {
        if (!this.storage) {
            throw new RedisConnectionError('Agent storage is not initialized.');
        }
        await this.storage.deleteHistory(userId);
        return true;
    }
    /**
     * Retrieves the chat history for a given user.
     *
     * @param {string} userId - The ID of the user whose history should be retrieved.
     * @param {HistoryOptions} [options] - Options for retrieving history.
     * @returns {Promise<ChatCompletionMessageParam[]>} A promise that resolves to an array of chat messages.
     * @throws {RedisConnectionError} If the storage is not initialized.
     */
    async getChatHistory(userId, options) {
        if (!this.storage)
            throw new RedisConnectionError('Agent storage is not initialized');
        const messages = await this.storage.getChatHistory(userId, options);
        return messages;
    }
}
export { AgentStorage, OpenAIAgent, getCompletionsUsage };