/**
 * openai — the official TypeScript library for the OpenAI API.
 * (Web-viewer metadata for this file: 683 lines (599 loc) • 23.6 kB, text/typescript.)
 */
import * as Core from "../core";
import { type CompletionUsage } from "../resources/completions";
import {
type Completions,
type ChatCompletion,
type ChatCompletionMessage,
type ChatCompletionMessageParam,
type ChatCompletionCreateParams,
type ChatCompletionTool,
} from "../resources/chat/completions";
import { APIUserAbortError, OpenAIError } from "../error";
import {
type RunnableFunction,
isRunnableFunctionWithParse,
type BaseFunctionsArgs,
} from './RunnableFunction';
import { ChatCompletionFunctionRunnerParams, ChatCompletionToolRunnerParams } from './ChatCompletionRunner';
import {
ChatCompletionStreamingFunctionRunnerParams,
ChatCompletionStreamingToolRunnerParams,
} from './ChatCompletionStreamingRunner';
import { isAssistantMessage, isFunctionMessage, isToolMessage } from './chatCompletionUtils';
/** Default cap on how many sequential API requests a runner loop will make. */
const DEFAULT_MAX_CHAT_COMPLETIONS = 10;

/** Options accepted by the function/tool runner entry points, extending ordinary request options. */
export interface RunnerOptions extends Core.RequestOptions {
  /** How many requests to make before canceling. Default 10. */
  maxChatCompletions?: number;
}
/**
 * Base class for chat-completion "runner" helpers.
 *
 * Combines a small typed event emitter (`on`/`off`/`once`/`emitted`) with promise-based
 * accessors (`done`, `finalChatCompletion`, …) over a sequence of ChatCompletion requests,
 * including the automatic function/tool-calling loops in `_runFunctions` / `_runTools`.
 *
 * The `Events` type parameter lets subclasses extend the base event map with their own events.
 */
export abstract class AbstractChatCompletionRunner<
  Events extends CustomEvents<any> = AbstractChatCompletionRunnerEvents,
> {
  /** Used to abort in-flight requests; `abort()` triggers it. */
  controller: AbortController = new AbortController();

  // Settles when the first request connects (`_connected`); rejected on error/abort.
  #connectedPromise: Promise<void>;
  #resolveConnectedPromise: () => void = () => {};
  #rejectConnectedPromise: (error: OpenAIError) => void = () => {};

  // Settles when the 'end' event is emitted; rejected on error/abort.
  #endPromise: Promise<void>;
  #resolveEndPromise: () => void = () => {};
  #rejectEndPromise: (error: OpenAIError) => void = () => {};

  // Registered listeners, keyed by event name. Entries marked `once` are removed on first emit.
  #listeners: { [Event in keyof Events]?: ListenersForEvent<Events, Event> } = {};

  // Every ChatCompletion received so far, in order.
  protected _chatCompletions: ChatCompletion[] = [];
  // Full conversation so far: caller-supplied params plus assistant/function/tool messages.
  messages: ChatCompletionMessageParam[] = [];

  #ended = false;
  #errored = false;
  #aborted = false;
  // True once any promise-returning method has been called, meaning the user has a
  // rejection path — so `_emit` need not force an unhandled rejection on error/abort.
  #catchingPromiseCreated = false;

  constructor() {
    this.#connectedPromise = new Promise<void>((resolve, reject) => {
      this.#resolveConnectedPromise = resolve;
      this.#rejectConnectedPromise = reject;
    });

    this.#endPromise = new Promise<void>((resolve, reject) => {
      this.#resolveEndPromise = resolve;
      this.#rejectEndPromise = reject;
    });

    // Don't let these promises cause unhandled rejection errors.
    // we will manually cause an unhandled rejection error later
    // if the user hasn't registered any error listener or called
    // any promise-returning method.
    this.#connectedPromise.catch(() => {});
    this.#endPromise.catch(() => {});
  }

  /**
   * Schedules `executor` on the next tick and wires its completion to the
   * `finalMessage`/`end` events (or to error handling on rejection).
   */
  protected _run(executor: () => Promise<any>) {
    // Unfortunately if we call `executor()` immediately we get runtime errors about
    // references to `this` before the `super()` constructor call returns.
    setTimeout(() => {
      executor().then(() => {
        this._emitFinal();
        this._emit('end');
      }, this.#handleError);
    }, 0);
  }

  /**
   * Records a completed ChatCompletion, emits 'chatCompletion', and appends its
   * first-choice message (if any) to the conversation.
   */
  protected _addChatCompletion(chatCompletion: ChatCompletion): ChatCompletion {
    this._chatCompletions.push(chatCompletion);
    this._emit('chatCompletion', chatCompletion);
    const message = chatCompletion.choices[0]?.message;
    if (message) this._addMessage(message as ChatCompletionMessageParam);
    return chatCompletion;
  }

  /**
   * Appends a message to the conversation and (unless `emit` is false, as when
   * seeding from request params) fires 'message' plus the relevant
   * 'functionCall' / 'functionCallResult' events.
   */
  protected _addMessage(message: ChatCompletionMessageParam, emit = true) {
    // Normalize: the API requires an explicit `content` key, even when null.
    if (!('content' in message)) message.content = null;

    this.messages.push(message);

    if (emit) {
      this._emit('message', message);
      if ((isFunctionMessage(message) || isToolMessage(message)) && message.content) {
        // Note, this assumes that {role: 'tool', content: …} is always the result of a call of tool of type=function.
        this._emit('functionCallResult', message.content as string);
      } else if (isAssistantMessage(message) && message.function_call) {
        this._emit('functionCall', message.function_call);
      } else if (isAssistantMessage(message) && message.tool_calls) {
        for (const tool_call of message.tool_calls) {
          if (tool_call.type === 'function') {
            this._emit('functionCall', tool_call.function);
          }
        }
      }
    }
  }

  /** Marks the stream as connected (first request succeeded) and emits 'connect'. */
  protected _connected() {
    if (this.ended) return;
    this.#resolveConnectedPromise();
    this._emit('connect');
  }

  /** Whether the 'end' event has fired (no further events will be emitted). */
  get ended(): boolean {
    return this.#ended;
  }

  /** Whether an error was encountered (set before 'error'/'abort' is emitted). */
  get errored(): boolean {
    return this.#errored;
  }

  /** Whether the run was aborted by the user. */
  get aborted(): boolean {
    return this.#aborted;
  }

  /** Aborts any in-flight request via the shared AbortController. */
  abort() {
    this.controller.abort();
  }

  /**
   * Adds the listener function to the end of the listeners array for the event.
   * No checks are made to see if the listener has already been added. Multiple calls passing
   * the same combination of event and listener will result in the listener being added, and
   * called, multiple times.
   * @returns this ChatCompletionStream, so that calls can be chained
   */
  on<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
    const listeners: ListenersForEvent<Events, Event> =
      this.#listeners[event] || (this.#listeners[event] = []);
    listeners.push({ listener });
    return this;
  }

  /**
   * Removes the specified listener from the listener array for the event.
   * off() will remove, at most, one instance of a listener from the listener array. If any single
   * listener has been added multiple times to the listener array for the specified event, then
   * off() must be called multiple times to remove each instance.
   * @returns this ChatCompletionStream, so that calls can be chained
   */
  off<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
    const listeners = this.#listeners[event];
    if (!listeners) return this;
    const index = listeners.findIndex((l) => l.listener === listener);
    if (index >= 0) listeners.splice(index, 1);
    return this;
  }

  /**
   * Adds a one-time listener function for the event. The next time the event is triggered,
   * this listener is removed and then invoked.
   * @returns this ChatCompletionStream, so that calls can be chained
   */
  once<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
    const listeners: ListenersForEvent<Events, Event> =
      this.#listeners[event] || (this.#listeners[event] = []);
    listeners.push({ listener, once: true });
    return this;
  }

  /**
   * This is similar to `.once()`, but returns a Promise that resolves the next time
   * the event is triggered, instead of calling a listener callback.
   * @returns a Promise that resolves the next time given event is triggered,
   * or rejects if an error is emitted. (If you request the 'error' event,
   * returns a promise that resolves with the error).
   *
   * Example:
   *
   *   const message = await stream.emitted('message') // rejects if the stream errors
   */
  emitted<Event extends keyof Events>(
    event: Event,
  ): Promise<
    EventParameters<Events, Event> extends [infer Param] ? Param
    : EventParameters<Events, Event> extends [] ? void
    : EventParameters<Events, Event>
  > {
    return new Promise((resolve, reject) => {
      // The returned promise gives the user an error path, so suppress the
      // forced unhandled rejection in `_emit`.
      this.#catchingPromiseCreated = true;
      if (event !== 'error') this.once('error', reject);
      this.once(event, resolve as any);
    });
  }

  /** Resolves when the run ends; rejects if it errored or was aborted. */
  async done(): Promise<void> {
    this.#catchingPromiseCreated = true;
    await this.#endPromise;
  }

  /**
   * @returns a promise that resolves with the final ChatCompletion, or rejects
   * if an error occurred or the stream ended prematurely without producing a ChatCompletion.
   */
  async finalChatCompletion(): Promise<ChatCompletion> {
    await this.done();
    const completion = this._chatCompletions[this._chatCompletions.length - 1];
    if (!completion) throw new OpenAIError('stream ended without producing a ChatCompletion');
    return completion;
  }

  // Content of the last assistant message; throws (via #getFinalMessage) if there is none.
  #getFinalContent(): string | null {
    return this.#getFinalMessage().content ?? null;
  }

  /**
   * @returns a promise that resolves with the content of the final ChatCompletionMessage, or rejects
   * if an error occurred or the stream ended prematurely without producing a ChatCompletionMessage.
   */
  async finalContent(): Promise<string | null> {
    await this.done();
    return this.#getFinalContent();
  }

  // Scans the conversation backwards for the last assistant message.
  #getFinalMessage(): ChatCompletionMessage {
    let i = this.messages.length;
    while (i-- > 0) {
      const message = this.messages[i];
      if (isAssistantMessage(message)) {
        // Return a copy with `content` normalized to null when absent.
        return { ...message, content: message.content ?? null };
      }
    }
    throw new OpenAIError('stream ended without producing a ChatCompletionMessage with role=assistant');
  }

  /**
   * @returns a promise that resolves with the the final assistant ChatCompletionMessage response,
   * or rejects if an error occurred or the stream ended prematurely without producing a ChatCompletionMessage.
   */
  async finalMessage(): Promise<ChatCompletionMessage> {
    await this.done();
    return this.#getFinalMessage();
  }

  // Scans backwards for the last function call — either a legacy `function_call`
  // or the last entry of an assistant message's `tool_calls`.
  #getFinalFunctionCall(): ChatCompletionMessage.FunctionCall | undefined {
    for (let i = this.messages.length - 1; i >= 0; i--) {
      const message = this.messages[i];
      if (isAssistantMessage(message) && message?.function_call) {
        return message.function_call;
      }
      if (isAssistantMessage(message) && message?.tool_calls?.length) {
        return message.tool_calls.at(-1)?.function;
      }
    }

    return;
  }

  /**
   * @returns a promise that resolves with the content of the final FunctionCall, or rejects
   * if an error occurred or the stream ended prematurely without producing a ChatCompletionMessage.
   */
  async finalFunctionCall(): Promise<ChatCompletionMessage.FunctionCall | undefined> {
    await this.done();
    return this.#getFinalFunctionCall();
  }

  // Scans backwards for the last function result: either a legacy function message,
  // or a tool message whose tool_call_id matches a prior assistant function tool call.
  #getFinalFunctionCallResult(): string | undefined {
    for (let i = this.messages.length - 1; i >= 0; i--) {
      const message = this.messages[i];
      if (isFunctionMessage(message) && message.content != null) {
        return message.content;
      }
      if (
        isToolMessage(message) &&
        message.content != null &&
        this.messages.some(
          (x) =>
            x.role === 'assistant' &&
            x.tool_calls?.some((y) => y.type === 'function' && y.id === message.tool_call_id),
        )
      ) {
        return message.content;
      }
    }

    return;
  }

  /** Resolves with the content of the final function/tool result message, if any. */
  async finalFunctionCallResult(): Promise<string | undefined> {
    await this.done();
    return this.#getFinalFunctionCallResult();
  }

  // Sums token usage across every ChatCompletion received in this run.
  #calculateTotalUsage(): CompletionUsage {
    const total: CompletionUsage = {
      completion_tokens: 0,
      prompt_tokens: 0,
      total_tokens: 0,
    };

    for (const { usage } of this._chatCompletions) {
      if (usage) {
        total.completion_tokens += usage.completion_tokens;
        total.prompt_tokens += usage.prompt_tokens;
        total.total_tokens += usage.total_tokens;
      }
    }

    return total;
  }

  /** Resolves with the aggregated token usage across all requests in this run. */
  async totalUsage(): Promise<CompletionUsage> {
    await this.done();
    return this.#calculateTotalUsage();
  }

  /** @returns a shallow copy of every ChatCompletion received so far. */
  allChatCompletions(): ChatCompletion[] {
    return [...this._chatCompletions];
  }

  // Normalizes any thrown value into an OpenAIError / APIUserAbortError and emits
  // the matching 'error' or 'abort' event. Arrow property so it can be passed as a
  // rejection handler without losing `this`.
  #handleError = (error: unknown) => {
    this.#errored = true;
    if (error instanceof Error && error.name === 'AbortError') {
      error = new APIUserAbortError();
    }
    if (error instanceof APIUserAbortError) {
      this.#aborted = true;
      return this._emit('abort', error);
    }
    if (error instanceof OpenAIError) {
      return this._emit('error', error);
    }
    if (error instanceof Error) {
      const openAIError: OpenAIError = new OpenAIError(error.message);
      // @ts-ignore
      openAIError.cause = error;
      return this._emit('error', openAIError);
    }
    return this._emit('error', new OpenAIError(String(error)));
  };

  /**
   * Dispatches an event to registered listeners. 'abort' and 'error' additionally
   * reject the connected/end promises and recursively emit 'end'.
   */
  protected _emit<Event extends keyof Events>(event: Event, ...args: EventParameters<Events, Event>) {
    // make sure we don't emit any events after end
    if (this.#ended) {
      return;
    }

    if (event === 'end') {
      this.#ended = true;
      this.#resolveEndPromise();
    }

    const listeners: ListenersForEvent<Events, Event> | undefined = this.#listeners[event];
    if (listeners) {
      // Drop one-shot listeners before invoking, so a listener that re-emits
      // the same event cannot retrigger them.
      this.#listeners[event] = listeners.filter((l) => !l.once) as any;
      listeners.forEach(({ listener }: any) => listener(...args));
    }

    if (event === 'abort') {
      const error = args[0] as APIUserAbortError;
      if (!this.#catchingPromiseCreated && !listeners?.length) {
        // No user-visible error path exists — surface as an unhandled rejection.
        Promise.reject(error);
      }
      this.#rejectConnectedPromise(error);
      this.#rejectEndPromise(error);
      this._emit('end');
      return;
    }

    if (event === 'error') {
      // NOTE: _emit('error', error) should only be called from #handleError().

      const error = args[0] as OpenAIError;
      if (!this.#catchingPromiseCreated && !listeners?.length) {
        // Trigger an unhandled rejection if the user hasn't registered any error handlers.
        // If you are seeing stack traces here, make sure to handle errors via either:
        // - runner.on('error', () => ...)
        // - await runner.done()
        // - await runner.finalChatCompletion()
        // - etc.
        Promise.reject(error);
      }
      this.#rejectConnectedPromise(error);
      this.#rejectEndPromise(error);
      this._emit('end');
    }
  }

  /**
   * Emits the `final*` summary events for whatever the run produced
   * (completion, message, content, function call/result, total usage).
   */
  protected _emitFinal() {
    const completion = this._chatCompletions[this._chatCompletions.length - 1];
    if (completion) this._emit('finalChatCompletion', completion);
    const finalMessage = this.#getFinalMessage();
    if (finalMessage) this._emit('finalMessage', finalMessage);
    const finalContent = this.#getFinalContent();
    if (finalContent) this._emit('finalContent', finalContent);

    const finalFunctionCall = this.#getFinalFunctionCall();
    if (finalFunctionCall) this._emit('finalFunctionCall', finalFunctionCall);

    const finalFunctionCallResult = this.#getFinalFunctionCallResult();
    if (finalFunctionCallResult != null) this._emit('finalFunctionCallResult', finalFunctionCallResult);

    if (this._chatCompletions.some((c) => c.usage)) {
      this._emit('totalUsage', this.#calculateTotalUsage());
    }
  }

  // The helpers only track a single choice, so reject n > 1 up front.
  #validateParams(params: ChatCompletionCreateParams): void {
    if (params.n != null && params.n > 1) {
      throw new OpenAIError(
        'ChatCompletion convenience helpers only support n=1 at this time. To use n>1, please use chat.completions.create() directly.',
      );
    }
  }

  /**
   * Issues a single non-streaming chat.completions request, wiring the caller's
   * abort signal into this runner's controller, and records the result.
   */
  protected async _createChatCompletion(
    completions: Completions,
    params: ChatCompletionCreateParams,
    options?: Core.RequestOptions,
  ): Promise<ChatCompletion> {
    const signal = options?.signal;
    if (signal) {
      if (signal.aborted) this.controller.abort();
      signal.addEventListener('abort', () => this.controller.abort());
    }
    this.#validateParams(params);

    const chatCompletion = await completions.create(
      { ...params, stream: false },
      { ...options, signal: this.controller.signal },
    );
    this._connected();
    return this._addChatCompletion(chatCompletion);
  }

  /** Seeds the conversation from `params.messages`, then makes one request. */
  protected async _runChatCompletion(
    completions: Completions,
    params: ChatCompletionCreateParams,
    options?: Core.RequestOptions,
  ): Promise<ChatCompletion> {
    for (const message of params.messages) {
      this._addMessage(message, false);
    }
    return await this._createChatCompletion(completions, params, options);
  }

  /**
   * Legacy `function_call` loop: repeatedly requests completions, invoking the
   * matching RunnableFunction for each function_call and feeding the result back,
   * until the model stops calling functions or `maxChatCompletions` is reached.
   */
  protected async _runFunctions<FunctionsArgs extends BaseFunctionsArgs>(
    completions: Completions,
    params:
      | ChatCompletionFunctionRunnerParams<FunctionsArgs>
      | ChatCompletionStreamingFunctionRunnerParams<FunctionsArgs>,
    options?: RunnerOptions,
  ) {
    const role = 'function' as const;
    const { function_call = 'auto', stream, ...restParams } = params;
    // If the caller forced a specific function, remember its name so we stop after one call.
    const singleFunctionToCall = typeof function_call !== 'string' && function_call?.name;
    const { maxChatCompletions = DEFAULT_MAX_CHAT_COMPLETIONS } = options || {};

    const functionsByName: Record<string, RunnableFunction<any>> = {};
    for (const f of params.functions) {
      functionsByName[f.name || f.function.name] = f;
    }

    // Wire-format function specs sent to the API.
    const functions: ChatCompletionCreateParams.Function[] = params.functions.map(
      (f): ChatCompletionCreateParams.Function => ({
        name: f.name || f.function.name,
        parameters: f.parameters as Record<string, unknown>,
        description: f.description,
      }),
    );

    for (const message of params.messages) {
      this._addMessage(message, false);
    }

    for (let i = 0; i < maxChatCompletions; ++i) {
      const chatCompletion: ChatCompletion = await this._createChatCompletion(
        completions,
        {
          ...restParams,
          function_call,
          functions,
          messages: [...this.messages],
        },
        options,
      );
      const message = chatCompletion.choices[0]?.message;
      if (!message) {
        throw new OpenAIError(`missing message in ChatCompletion response`);
      }
      if (!message.function_call) return;
      const { name, arguments: args } = message.function_call;
      const fn = functionsByName[name];
      if (!fn) {
        // Unknown function name — tell the model its options and let it retry.
        const content = `Invalid function_call: ${JSON.stringify(name)}. Available options are: ${functions
          .map((f) => JSON.stringify(f.name))
          .join(', ')}. Please try again`;
        this._addMessage({ role, name, content });
        continue;
      } else if (singleFunctionToCall && singleFunctionToCall !== name) {
        // Model called a different function than the one the caller forced.
        const content = `Invalid function_call: ${JSON.stringify(name)}. ${JSON.stringify(
          singleFunctionToCall,
        )} requested. Please try again`;
        this._addMessage({ role, name, content });
        continue;
      }
      let parsed;
      try {
        parsed = isRunnableFunctionWithParse(fn) ? await fn.parse(args) : args;
      } catch (error) {
        // Surface the parse failure to the model as the function result so it can retry.
        this._addMessage({
          role,
          name,
          content: error instanceof Error ? error.message : String(error),
        });
        continue;
      }

      // @ts-expect-error it can't rule out `never` type.
      const rawContent = await fn.function(parsed, this);
      const content = this.#stringifyFunctionCallResult(rawContent);

      this._addMessage({ role, name, content });

      if (singleFunctionToCall) return;
    }
  }

  /**
   * Tool-calling loop: repeatedly requests completions, invoking the matching
   * RunnableFunction for each function-type tool_call and feeding results back,
   * until the model stops calling tools or `maxChatCompletions` is reached.
   */
  protected async _runTools<FunctionsArgs extends BaseFunctionsArgs>(
    completions: Completions,
    params:
      | ChatCompletionToolRunnerParams<FunctionsArgs>
      | ChatCompletionStreamingToolRunnerParams<FunctionsArgs>,
    options?: RunnerOptions,
  ) {
    const role = 'tool' as const;
    const { tool_choice = 'auto', stream, ...restParams } = params;
    // If the caller forced a specific tool, remember its name so we stop after one call.
    const singleFunctionToCall = typeof tool_choice !== 'string' && tool_choice?.function?.name;
    const { maxChatCompletions = DEFAULT_MAX_CHAT_COMPLETIONS } = options || {};

    const functionsByName: Record<string, RunnableFunction<any>> = {};
    for (const f of params.tools) {
      if (f.type === 'function') {
        functionsByName[f.function.name || f.function.function.name] = f.function;
      }
    }

    // Wire-format tool specs sent to the API (non-function tools pass through unchanged).
    const tools: ChatCompletionTool[] =
      'tools' in params ?
        params.tools.map((t) =>
          t.type === 'function' ?
            {
              type: 'function',
              function: {
                name: t.function.name || t.function.function.name,
                parameters: t.function.parameters as Record<string, unknown>,
                description: t.function.description,
              },
            }
          : (t as unknown as ChatCompletionTool),
        )
      : (undefined as any);

    for (const message of params.messages) {
      this._addMessage(message, false);
    }

    for (let i = 0; i < maxChatCompletions; ++i) {
      const chatCompletion: ChatCompletion = await this._createChatCompletion(
        completions,
        {
          ...restParams,
          tool_choice,
          tools,
          messages: [...this.messages],
        },
        options,
      );
      const message = chatCompletion.choices[0]?.message;
      if (!message) {
        throw new OpenAIError(`missing message in ChatCompletion response`);
      }
      if (!message.tool_calls) {
        return;
      }
      for (const tool_call of message.tool_calls) {
        if (tool_call.type !== 'function') continue;
        const tool_call_id = tool_call.id;
        const { name, arguments: args } = tool_call.function;
        const fn = functionsByName[name];
        if (!fn) {
          // Unknown tool name — tell the model its options and let it retry.
          const content = `Invalid tool_call: ${JSON.stringify(name)}. Available options are: ${tools
            .map((f) => JSON.stringify(f.function.name))
            .join(', ')}. Please try again`;
          this._addMessage({ role, tool_call_id, content });
          continue;
        } else if (singleFunctionToCall && singleFunctionToCall !== name) {
          // Model called a different tool than the one the caller forced.
          const content = `Invalid tool_call: ${JSON.stringify(name)}. ${JSON.stringify(
            singleFunctionToCall,
          )} requested. Please try again`;
          this._addMessage({ role, tool_call_id, content });
          continue;
        }
        let parsed;
        try {
          parsed = isRunnableFunctionWithParse(fn) ? await fn.parse(args) : args;
        } catch (error) {
          // Surface the parse failure to the model as the tool result so it can retry.
          const content = error instanceof Error ? error.message : String(error);
          this._addMessage({ role, tool_call_id, content });
          continue;
        }

        // @ts-expect-error it can't rule out `never` type.
        const rawContent = await fn.function(parsed, this);
        const content = this.#stringifyFunctionCallResult(rawContent);
        this._addMessage({ role, tool_call_id, content });

        if (singleFunctionToCall) {
          return;
        }
      }
    }

    return;
  }

  // Converts a function's return value into the string content the API expects:
  // strings pass through, `undefined` becomes the literal 'undefined', everything else is JSON.
  #stringifyFunctionCallResult(rawContent: unknown): string {
    return (
      typeof rawContent === 'string' ? rawContent
      : rawContent === undefined ? 'undefined'
      : JSON.stringify(rawContent)
    );
  }
}
/**
 * Given a union of event-name strings, builds the listener-signature map for them:
 * names declared in AbstractChatCompletionRunnerEvents keep their typed signatures,
 * while any additional names fall back to a variadic listener.
 */
type CustomEvents<Event extends string> = {
  [k in Event]: k extends keyof AbstractChatCompletionRunnerEvents ? AbstractChatCompletionRunnerEvents[k]
  : (...args: any[]) => void;
};

/** The listener function type for a single event in the map. */
type ListenerForEvent<Events extends CustomEvents<any>, Event extends keyof Events> = Event extends (
  keyof AbstractChatCompletionRunnerEvents
) ?
  AbstractChatCompletionRunnerEvents[Event]
: Events[Event];

/** Internal registration list for one event; `once` entries are removed after first dispatch. */
type ListenersForEvent<Events extends CustomEvents<any>, Event extends keyof Events> = Array<{
  listener: ListenerForEvent<Events, Event>;
  once?: boolean;
}>;

/** Tuple of argument types passed to listeners of a given event. */
type EventParameters<Events extends CustomEvents<any>, Event extends keyof Events> = Parameters<
  ListenerForEvent<Events, Event>
>;
/** Events emitted by every chat-completion runner, with their listener signatures. */
export interface AbstractChatCompletionRunnerEvents {
  /** Fired when the first API request has connected. */
  connect: () => void;
  /** Fired for each function call the assistant makes (legacy or tool-based). */
  functionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void;
  /** Fired for every message appended to the conversation. */
  message: (message: ChatCompletionMessageParam) => void;
  /** Fired for each ChatCompletion response received. */
  chatCompletion: (completion: ChatCompletion) => void;
  /** Fired at the end of the run with the final assistant message's content. */
  finalContent: (contentSnapshot: string) => void;
  /** Fired at the end of the run with the final assistant message. */
  finalMessage: (message: ChatCompletionMessageParam) => void;
  /** Fired at the end of the run with the last ChatCompletion received. */
  finalChatCompletion: (completion: ChatCompletion) => void;
  /** Fired at the end of the run with the last function call made, if any. */
  finalFunctionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void;
  /** Fired with the content of each function/tool result message. */
  functionCallResult: (content: string) => void;
  /** Fired at the end of the run with the last function/tool result, if any. */
  finalFunctionCallResult: (content: string) => void;
  /** Fired when the run fails; the error also rejects all pending promises. */
  error: (error: OpenAIError) => void;
  /** Fired when the run is aborted by the user. */
  abort: (error: APIUserAbortError) => void;
  /** Fired exactly once when the run finishes; no events follow it. */
  end: () => void;
  /** Fired at the end of the run with aggregated token usage, when usage was reported. */
  totalUsage: (usage: CompletionUsage) => void;
}