// openai
// Version:
// The official TypeScript library for the OpenAI API
// 780 lines (670 loc) • 24.5 kB
// text/typescript
import {
TextContentBlock,
ImageFileContentBlock,
Message,
MessageContentDelta,
Text,
ImageFile,
TextDelta,
MessageDelta,
MessageContent,
} from '../resources/beta/threads/messages';
import * as Core from '../core';
import { RequestOptions } from '../core';
import {
Run,
RunCreateParamsBase,
RunCreateParamsStreaming,
Runs,
RunSubmitToolOutputsParamsBase,
RunSubmitToolOutputsParamsStreaming,
} from '../resources/beta/threads/runs/runs';
import { type ReadableStream } from '../_shims/index';
import { Stream } from '../streaming';
import { APIUserAbortError, OpenAIError } from '../error';
import {
AssistantStreamEvent,
MessageStreamEvent,
RunStepStreamEvent,
RunStreamEvent,
} from '../resources/beta/assistants';
import { RunStep, RunStepDelta, ToolCall, ToolCallDelta } from '../resources/beta/threads/runs/steps';
import { ThreadCreateAndRunParamsBase, Threads } from '../resources/beta/threads/threads';
import { BaseEvents, EventStream } from './EventStream';
/**
 * Listener signatures for the events an `AssistantStream` emits, on top of
 * whatever `BaseEvents` already declares.
 */
export interface AssistantStreamEvents extends BaseEvents {
// NOTE(review): no `_emit('run', ...)` call is visible in this file — confirm where this event is emitted.
run: (run: Run) => void;
//New event structure
// Emitted with the message carried by a `thread.message.created` event.
messageCreated: (message: Message) => void;
// Emitted for each `thread.message.delta` with the raw delta and the accumulated message.
messageDelta: (message: MessageDelta, snapshot: Message) => void;
// Emitted with the event's message on `thread.message.completed` / `thread.message.incomplete`.
messageDone: (message: Message) => void;
// Emitted with the run step carried by a `thread.run.step.created` event.
runStepCreated: (runStep: RunStep) => void;
// Emitted for each `thread.run.step.delta` with the raw delta and the accumulated run step.
runStepDelta: (delta: RunStepDelta, snapshot: Runs.RunStep) => void;
// Emitted when a run step completes, fails, is cancelled or expires: the event's step plus the accumulated step.
runStepDone: (runStep: Runs.RunStep, snapshot: Runs.RunStep) => void;
// Emitted when a step delta references a tool call index different from the one being tracked.
toolCallCreated: (toolCall: ToolCall) => void;
// Emitted when a step delta targets the tool call index currently being tracked.
toolCallDelta: (delta: ToolCallDelta, snapshot: ToolCall) => void;
// Emitted when the tracked tool call is superseded, its step finishes, or the run reaches a final state.
toolCallDone: (toolCall: ToolCall) => void;
// Emitted when a message delta introduces a new content index whose accumulated content is text.
textCreated: (content: Text) => void;
// Emitted for each text content delta with the delta and the accumulated text.
textDelta: (delta: TextDelta, snapshot: Text) => void;
// Emitted when text content is finished (a different content index starts, or the message ends).
textDone: (content: Text, snapshot: Message) => void;
//No created or delta as this is not streamed
imageFileDone: (content: ImageFile, snapshot: Message) => void;
// Emitted for every raw stream event before any of the typed events above.
event: (event: AssistantStreamEvent) => void;
}
/** `ThreadCreateAndRunParamsBase` with `stream` narrowed to an optional literal `true`. */
export type ThreadCreateAndRunParamsBaseStream = Omit<ThreadCreateAndRunParamsBase, 'stream'> & {
stream?: true;
};
/** `RunCreateParamsBase` with `stream` narrowed to an optional literal `true`. */
export type RunCreateParamsBaseStream = Omit<RunCreateParamsBase, 'stream'> & {
stream?: true;
};
/** `RunSubmitToolOutputsParamsBase` with `stream` narrowed to an optional literal `true`. */
export type RunSubmitToolOutputsParamsStream = Omit<RunSubmitToolOutputsParamsBase, 'stream'> & {
stream?: true;
};
/**
 * Consumes Assistants API stream events, accumulates message and run-step
 * snapshots from their deltas, and re-emits them as the typed events declared
 * in `AssistantStreamEvents`. It can also be iterated with `for await` to
 * receive the raw events.
 */
export class AssistantStream
extends EventStream<AssistantStreamEvents>
implements AsyncIterable<AssistantStreamEvent>
{
//Track all events in a single list for reference
#events: AssistantStreamEvent[] = [];
//Used to accumulate deltas
//We are accumulating many types so the value here is not strict
// Latest accumulated run step, keyed by run step id.
#runStepSnapshots: { [id: string]: Runs.RunStep } = {};
// Latest accumulated message, keyed by message id.
#messageSnapshots: { [id: string]: Message } = {};
// Message currently being accumulated; cleared once the message completes or is incomplete.
#messageSnapshot: Message | undefined;
// Run carried by the most recent run event that this class treats as final.
#finalRun: Run | undefined;
// Index and accumulated value of the message content part most recently touched by a delta.
#currentContentIndex: number | undefined;
#currentContent: MessageContent | undefined;
// Index and accumulated value of the tool call most recently touched by a run step delta.
#currentToolCallIndex: number | undefined;
#currentToolCall: ToolCall | undefined;
//For current snapshot methods
#currentEvent: AssistantStreamEvent | undefined;
#currentRunSnapshot: Run | undefined;
#currentRunStepSnapshot: Runs.RunStep | undefined;
/**
 * Returns an async iterator over every raw `AssistantStreamEvent`.
 *
 * Events that arrive while nobody is waiting are buffered; calls to `next()`
 * that arrive while the buffer is empty are parked until the next event, the
 * end of the stream, an abort or an error. Calling `return()` (e.g. by
 * breaking out of a `for await` loop) aborts the stream.
 */
[Symbol.asyncIterator](): AsyncIterator<AssistantStreamEvent> {
  type PendingRead = {
    resolve: (chunk: AssistantStreamEvent | undefined) => void;
    reject: (err: unknown) => void;
  };

  const buffered: AssistantStreamEvent[] = [];
  const waiting: PendingRead[] = [];
  let finished = false;

  // Marks the iterator finished, settles every parked read, then empties the wait list.
  const settleWaiting = (settle: (reader: PendingRead) => void): void => {
    finished = true;
    for (const reader of waiting) {
      settle(reader);
    }
    waiting.length = 0;
  };

  //Catch all for passing along all events
  this.on('event', (event) => {
    const reader = waiting.shift();
    if (reader === undefined) {
      buffered.push(event);
      return;
    }
    reader.resolve(event);
  });

  // A resolved `undefined` tells a parked read that the stream is over.
  this.on('end', () => {
    settleWaiting((reader) => reader.resolve(undefined));
  });

  this.on('abort', (err) => {
    settleWaiting((reader) => reader.reject(err));
  });

  this.on('error', (err) => {
    settleWaiting((reader) => reader.reject(err));
  });

  const next = async (): Promise<IteratorResult<AssistantStreamEvent>> => {
    if (buffered.length) {
      const chunk = buffered.shift()!;
      return { value: chunk, done: false };
    }
    if (finished) {
      return { value: undefined, done: true };
    }
    const chunk = await new Promise<AssistantStreamEvent | undefined>((resolve, reject) => {
      waiting.push({ resolve, reject });
    });
    return chunk ? { value: chunk, done: false } : { value: undefined, done: true };
  };

  return {
    next,
    return: async () => {
      this.abort();
      return { value: undefined, done: true };
    },
  };
}
/**
 * Builds an `AssistantStream` that is fed from an existing `ReadableStream`.
 * Consumption is started through `_run`, and the new instance is returned
 * straight away.
 */
static fromReadableStream(stream: ReadableStream): AssistantStream {
  const assistantStream = new AssistantStream();
  const consume = () => assistantStream._fromReadableStream(stream);
  assistantStream._run(consume);
  return assistantStream;
}
/**
 * Reads assistant events from `readableStream` and processes each one.
 *
 * @param readableStream stream to decode into `AssistantStreamEvent`s
 * @param options only `signal` is read here; aborting it aborts this stream's controller
 * @returns the final run, passed through `_addRun`
 * @throws APIUserAbortError if the controller's signal is aborted once reading stops
 */
protected async _fromReadableStream(
  readableStream: ReadableStream,
  options?: Core.RequestOptions,
): Promise<Run> {
  const abortSelf = () => this.controller.abort();
  const externalSignal = options?.signal;
  if (externalSignal) {
    // Mirror the caller's signal onto our own controller, now and in the future.
    if (externalSignal.aborted) abortSelf();
    externalSignal.addEventListener('abort', abortSelf);
  }

  this._connected();

  const events = Stream.fromReadableStream<AssistantStreamEvent>(readableStream, this.controller);
  for await (const event of events) {
    this.#addEvent(event);
  }

  if (events.controller.signal?.aborted) {
    throw new APIUserAbortError();
  }

  const finalRun = this.#endRequest();
  return this._addRun(finalRun);
}
/**
 * Exposes this stream's events as a `ReadableStream`, by wrapping this
 * object's async iterator in a `Stream` that shares the same controller.
 */
toReadableStream(): ReadableStream {
  const iterate = this[Symbol.asyncIterator].bind(this);
  return new Stream(iterate, this.controller).toReadableStream();
}
/**
 * Creates an `AssistantStream` that submits tool outputs for a run and streams
 * the resulting events. The request carries the
 * `X-Stainless-Helper-Method: stream` header in addition to any caller headers.
 */
static createToolAssistantStream(
  threadId: string,
  runId: string,
  runs: Runs,
  params: RunSubmitToolOutputsParamsStream,
  options: RequestOptions | undefined,
): AssistantStream {
  const assistantStream = new AssistantStream();
  assistantStream._run(() => {
    const requestOptions = {
      ...options,
      headers: { ...options?.headers, 'X-Stainless-Helper-Method': 'stream' },
    };
    return assistantStream._runToolAssistantStream(threadId, runId, runs, params, requestOptions);
  });
  return assistantStream;
}
/**
 * Submits tool outputs with `stream: true` and processes every event of the
 * response.
 *
 * @param run resource used to call `submitToolOutputs`
 * @param threadId thread the run belongs to
 * @param runId run receiving the tool outputs
 * @param params request body; `stream` is forced to `true`
 * @param options request options; `signal` is replaced by this stream's own controller signal
 * @returns the final run, passed through `_addRun`
 * @throws APIUserAbortError if the response stream's signal is aborted once reading stops
 */
protected async _createToolAssistantStream(
  run: Runs,
  threadId: string,
  runId: string,
  params: RunSubmitToolOutputsParamsStream,
  options?: Core.RequestOptions,
): Promise<Run> {
  const abortSelf = () => this.controller.abort();
  const externalSignal = options?.signal;
  if (externalSignal) {
    // Mirror the caller's signal onto our own controller, now and in the future.
    if (externalSignal.aborted) abortSelf();
    externalSignal.addEventListener('abort', abortSelf);
  }

  const body: RunSubmitToolOutputsParamsStreaming = { ...params, stream: true };
  const events = await run.submitToolOutputs(threadId, runId, body, {
    ...options,
    signal: this.controller.signal,
  });

  this._connected();

  for await (const event of events) {
    this.#addEvent(event);
  }

  if (events.controller.signal?.aborted) {
    throw new APIUserAbortError();
  }

  const finalRun = this.#endRequest();
  return this._addRun(finalRun);
}
/**
 * Creates an `AssistantStream` that creates a thread, starts a run on it and
 * streams the resulting events. The request carries the
 * `X-Stainless-Helper-Method: stream` header in addition to any caller headers.
 */
static createThreadAssistantStream(
  params: ThreadCreateAndRunParamsBaseStream,
  thread: Threads,
  options?: RequestOptions,
): AssistantStream {
  const assistantStream = new AssistantStream();
  assistantStream._run(() => {
    const requestOptions = {
      ...options,
      headers: { ...options?.headers, 'X-Stainless-Helper-Method': 'stream' },
    };
    return assistantStream._threadAssistantStream(params, thread, requestOptions);
  });
  return assistantStream;
}
/**
 * Creates an `AssistantStream` that starts a run on an existing thread and
 * streams the resulting events. The request carries the
 * `X-Stainless-Helper-Method: stream` header in addition to any caller headers.
 */
static createAssistantStream(
  threadId: string,
  runs: Runs,
  params: RunCreateParamsBaseStream,
  options?: RequestOptions,
): AssistantStream {
  const assistantStream = new AssistantStream();
  assistantStream._run(() => {
    const requestOptions = {
      ...options,
      headers: { ...options?.headers, 'X-Stainless-Helper-Method': 'stream' },
    };
    return assistantStream._runAssistantStream(threadId, runs, params, requestOptions);
  });
  return assistantStream;
}
/** Returns the most recent event passed to event processing, or `undefined` if none has been processed yet. */
currentEvent(): AssistantStreamEvent | undefined {
return this.#currentEvent;
}
/** Returns the run carried by the most recent run event, or `undefined` if no run event has been processed yet. */
currentRun(): Run | undefined {
return this.#currentRunSnapshot;
}
/**
 * Returns the message currently being accumulated. This is `undefined` before
 * the first message event and again after a message completes or is incomplete.
 */
currentMessageSnapshot(): Message | undefined {
return this.#messageSnapshot;
}
/**
 * Returns the run step currently being accumulated. This is `undefined` before
 * the first run step event and again after a step completes, fails, is
 * cancelled or expires.
 */
currentRunStepSnapshot(): Runs.RunStep | undefined {
return this.#currentRunStepSnapshot;
}
/** Awaits `this.done()` and then returns every accumulated run step snapshot. */
async finalRunSteps(): Promise<Runs.RunStep[]> {
await this.done();
return Object.values(this.#runStepSnapshots);
}
/** Awaits `this.done()` and then returns every accumulated message snapshot. */
async finalMessages(): Promise<Message[]> {
await this.done();
return Object.values(this.#messageSnapshots);
}
/**
 * Awaits `this.done()` and then returns the final run.
 *
 * @throws Error if no final run was recorded while processing events
 */
async finalRun(): Promise<Run> {
  await this.done();
  const run = this.#finalRun;
  if (run) {
    return run;
  }
  throw Error('Final run was not received.');
}
/**
 * Calls `thread.createAndRun` with `stream: true` and processes every event of
 * the response.
 *
 * @param thread resource used to call `createAndRun`
 * @param params request body; `stream` is forced to `true`
 * @param options request options; `signal` is replaced by this stream's own controller signal
 * @returns the final run, passed through `_addRun`
 * @throws APIUserAbortError if the response stream's signal is aborted once reading stops
 */
protected async _createThreadAssistantStream(
  thread: Threads,
  params: ThreadCreateAndRunParamsBase,
  options?: Core.RequestOptions,
): Promise<Run> {
  const abortSelf = () => this.controller.abort();
  const externalSignal = options?.signal;
  if (externalSignal) {
    // Mirror the caller's signal onto our own controller, now and in the future.
    if (externalSignal.aborted) abortSelf();
    externalSignal.addEventListener('abort', abortSelf);
  }

  const body: RunCreateParamsStreaming = { ...params, stream: true };
  const events = await thread.createAndRun(body, { ...options, signal: this.controller.signal });

  this._connected();

  for await (const event of events) {
    this.#addEvent(event);
  }

  if (events.controller.signal?.aborted) {
    throw new APIUserAbortError();
  }

  const finalRun = this.#endRequest();
  return this._addRun(finalRun);
}
/**
 * Calls `run.create` with `stream: true` and processes every event of the
 * response.
 *
 * @param run resource used to call `create`
 * @param threadId thread on which the run is started
 * @param params request body; `stream` is forced to `true`
 * @param options request options; `signal` is replaced by this stream's own controller signal
 * @returns the final run, passed through `_addRun`
 * @throws APIUserAbortError if the response stream's signal is aborted once reading stops
 */
protected async _createAssistantStream(
  run: Runs,
  threadId: string,
  params: RunCreateParamsBase,
  options?: Core.RequestOptions,
): Promise<Run> {
  const abortSelf = () => this.controller.abort();
  const externalSignal = options?.signal;
  if (externalSignal) {
    // Mirror the caller's signal onto our own controller, now and in the future.
    if (externalSignal.aborted) abortSelf();
    externalSignal.addEventListener('abort', abortSelf);
  }

  const body: RunCreateParamsStreaming = { ...params, stream: true };
  const events = await run.create(threadId, body, { ...options, signal: this.controller.signal });

  this._connected();

  for await (const event of events) {
    this.#addEvent(event);
  }

  if (events.controller.signal?.aborted) {
    throw new APIUserAbortError();
  }

  const finalRun = this.#endRequest();
  return this._addRun(finalRun);
}
/**
 * Entry point for every event read from the underlying stream.
 *
 * Does nothing once the stream has ended. Otherwise it records the event as
 * the current one, emits the generic `event` notification, and dispatches to
 * the run, run-step or message handler based on the event name.
 *
 * @throws Error if an `error` event reaches this method
 */
#addEvent(event: AssistantStreamEvent) {
if (this.ended) return;
this.#currentEvent = event;
// The generic 'event' emission happens before any of the typed handlers run.
this.#handleEvent(event);
switch (event.event) {
case 'thread.created':
//No action on this event.
break;
case 'thread.run.created':
case 'thread.run.queued':
case 'thread.run.in_progress':
case 'thread.run.requires_action':
case 'thread.run.completed':
case 'thread.run.incomplete':
case 'thread.run.failed':
case 'thread.run.cancelling':
case 'thread.run.cancelled':
case 'thread.run.expired':
this.#handleRun(event);
break;
case 'thread.run.step.created':
case 'thread.run.step.in_progress':
case 'thread.run.step.delta':
case 'thread.run.step.completed':
case 'thread.run.step.failed':
case 'thread.run.step.cancelled':
case 'thread.run.step.expired':
this.#handleRunStep(event);
break;
case 'thread.message.created':
case 'thread.message.in_progress':
case 'thread.message.delta':
case 'thread.message.completed':
case 'thread.message.incomplete':
this.#handleMessage(event);
break;
case 'error':
//This is included for completeness, but errors are processed in the SSE event processing so this should not occur
throw new Error(
'Encountered an error event in event processing - errors should be processed earlier',
);
default:
// Compile-time exhaustiveness check only: `assertNever` has an empty body, so an
// unrecognised event name is ignored at runtime.
assertNever(event);
}
}
/**
 * Returns the final run once the event loop has finished.
 *
 * @throws OpenAIError if the stream has already ended
 * @throws Error if no final run was recorded while processing events
 */
#endRequest(): Run {
  if (this.ended) {
    throw new OpenAIError(`stream has ended, this shouldn't happen`);
  }
  const run = this.#finalRun;
  if (!run) {
    throw Error('Final run has not been received');
  }
  return run;
}
/**
 * Processes a message event: folds it into the current message snapshot and
 * emits the message-, text- and image-level events derived from it.
 *
 * - `textCreated` fires for every content index first introduced by this
 *   event whose accumulated content is text.
 * - On a delta, `messageDelta` fires, then `textDelta` for each text part.
 *   When a part's index differs from the one being tracked, the previously
 *   tracked part is reported as done (`textDone` / `imageFileDone`).
 * - On `completed` / `incomplete`, the part being tracked is reported as done,
 *   `messageDone` fires, and all per-message tracking state is cleared.
 *
 * @throws Error if a text delta targets content that is missing or not text
 */
#handleMessage(this: AssistantStream, event: MessageStreamEvent) {
  const [accumulatedMessage, newContent] = this.#accumulateMessage(event, this.#messageSnapshot);
  this.#messageSnapshot = accumulatedMessage;
  this.#messageSnapshots[accumulatedMessage.id] = accumulatedMessage;

  for (const content of newContent) {
    const snapshotContent = accumulatedMessage.content[content.index];
    if (snapshotContent?.type === 'text') {
      this._emit('textCreated', snapshotContent.text);
    }
  }

  switch (event.event) {
    case 'thread.message.created':
      this._emit('messageCreated', event.data);
      break;

    case 'thread.message.in_progress':
      break;

    case 'thread.message.delta':
      this._emit('messageDelta', event.data.delta, accumulatedMessage);
      if (event.data.delta.content) {
        for (const content of event.data.delta.content) {
          //If it is text delta, emit a text delta event
          if (content.type === 'text' && content.text) {
            const snapshot = accumulatedMessage.content[content.index];
            if (snapshot && snapshot.type === 'text') {
              this._emit('textDelta', content.text, snapshot.text);
            } else {
              throw Error('The snapshot associated with this text delta is not text or missing');
            }
          }

          if (content.index != this.#currentContentIndex) {
            //See if we have in progress content
            if (this.#currentContent) {
              switch (this.#currentContent.type) {
                case 'text':
                  this._emit('textDone', this.#currentContent.text, accumulatedMessage);
                  break;
                case 'image_file':
                  this._emit('imageFileDone', this.#currentContent.image_file, accumulatedMessage);
                  break;
              }
            }
            this.#currentContentIndex = content.index;
          }

          this.#currentContent = accumulatedMessage.content[content.index];
        }
      }
      break;

    case 'thread.message.completed':
    case 'thread.message.incomplete':
      //We emit the latest content we were working on on completion (including incomplete)
      if (this.#currentContentIndex !== undefined) {
        const currentContent = event.data.content[this.#currentContentIndex];
        if (currentContent) {
          switch (currentContent.type) {
            case 'image_file':
              this._emit('imageFileDone', currentContent.image_file, accumulatedMessage);
              break;
            case 'text':
              this._emit('textDone', currentContent.text, accumulatedMessage);
              break;
          }
        }
      }

      this._emit('messageDone', event.data);

      // Content tracking is per message. Without this reset, the next message's
      // first delta could be compared against the previous message's last index
      // and re-emit a `textDone` / `imageFileDone` for content that was already
      // reported as done.
      this.#currentContentIndex = undefined;
      this.#currentContent = undefined;
      this.#messageSnapshot = undefined;
      break;
  }
}
/**
 * Processes a run step event: folds it into the step snapshot and emits the
 * step- and tool-call-level events derived from it.
 *
 * - `created` emits `runStepCreated`.
 * - `delta` walks the tool call deltas. A delta for the tool call index being
 *   tracked emits `toolCallDelta`. A delta for any other index first reports
 *   the tracked tool call as done, then starts tracking the new one and emits
 *   `toolCallCreated`. `runStepDelta` is emitted afterwards.
 * - `completed` / `failed` / `cancelled` / `expired` clear the current step
 *   snapshot, report any tracked tool call as done, reset tool call tracking,
 *   and emit `runStepDone`.
 */
#handleRunStep(this: AssistantStream, event: RunStepStreamEvent) {
  const accumulatedRunStep = this.#accumulateRunStep(event);
  this.#currentRunStepSnapshot = accumulatedRunStep;

  switch (event.event) {
    case 'thread.run.step.created':
      this._emit('runStepCreated', event.data);
      break;

    case 'thread.run.step.delta': {
      const delta = event.data.delta;
      if (
        delta.step_details &&
        delta.step_details.type === 'tool_calls' &&
        delta.step_details.tool_calls &&
        accumulatedRunStep.step_details.type === 'tool_calls'
      ) {
        for (const toolCall of delta.step_details.tool_calls) {
          if (toolCall.index == this.#currentToolCallIndex) {
            this._emit(
              'toolCallDelta',
              toolCall,
              accumulatedRunStep.step_details.tool_calls[toolCall.index] as ToolCall,
            );
          } else {
            if (this.#currentToolCall) {
              this._emit('toolCallDone', this.#currentToolCall);
            }
            this.#currentToolCallIndex = toolCall.index;
            this.#currentToolCall = accumulatedRunStep.step_details.tool_calls[toolCall.index];
            if (this.#currentToolCall) this._emit('toolCallCreated', this.#currentToolCall);
          }
        }
      }
      this._emit('runStepDelta', event.data.delta, accumulatedRunStep);
      break;
    }

    case 'thread.run.step.completed':
    case 'thread.run.step.failed':
    case 'thread.run.step.cancelled':
    case 'thread.run.step.expired': {
      this.#currentRunStepSnapshot = undefined;
      const details = event.data.step_details;
      if (details.type === 'tool_calls') {
        if (this.#currentToolCall) {
          this._emit('toolCallDone', this.#currentToolCall as ToolCall);
          this.#currentToolCall = undefined;
        }
        // Tool call indices restart with every step. If the index were left in
        // place, the first tool call of a later step with the same index would
        // be treated as a continuation: no `toolCallCreated`, and no
        // `toolCallDone` afterwards because `#currentToolCall` was cleared.
        this.#currentToolCallIndex = undefined;
      }
      this._emit('runStepDone', event.data, accumulatedRunStep);
      break;
    }

    case 'thread.run.step.in_progress':
      break;
  }
}
/** Appends the event to the internal event list and emits it as the generic `event` notification. */
#handleEvent(this: AssistantStream, event: AssistantStreamEvent) {
this.#events.push(event);
this._emit('event', event);
}
/**
 * Updates the stored snapshot for the run step referenced by `event` and
 * returns it.
 *
 * `created` stores the event's step as the initial snapshot. `delta` merges
 * the delta into the existing snapshot. Every other handled step event
 * replaces the snapshot with the event's step.
 *
 * @throws Error if a delta arrives for a step that has no snapshot yet
 * @throws Error if no snapshot exists for the step after processing
 */
#accumulateRunStep(event: RunStepStreamEvent): Runs.RunStep {
  const stepId = event.data.id;

  if (event.event === 'thread.run.step.created') {
    this.#runStepSnapshots[stepId] = event.data;
    return event.data;
  }

  if (event.event === 'thread.run.step.delta') {
    const existing = this.#runStepSnapshots[stepId] as Runs.RunStep;
    if (!existing) {
      throw Error('Received a RunStepDelta before creation of a snapshot');
    }
    const payload = event.data;
    if (payload.delta) {
      this.#runStepSnapshots[stepId] = AssistantStream.accumulateDelta(
        existing,
        payload.delta,
      ) as Runs.RunStep;
    }
    return this.#runStepSnapshots[stepId] as Runs.RunStep;
  }

  if (
    event.event === 'thread.run.step.completed' ||
    event.event === 'thread.run.step.failed' ||
    event.event === 'thread.run.step.cancelled' ||
    event.event === 'thread.run.step.expired' ||
    event.event === 'thread.run.step.in_progress'
  ) {
    // These events carry the full step, so it replaces the snapshot.
    this.#runStepSnapshots[stepId] = event.data;
  }

  const latest = this.#runStepSnapshots[stepId];
  if (latest) return latest as Runs.RunStep;
  throw new Error('No snapshot available');
}
/**
 * Folds a message event into `snapshot`.
 *
 * @param event the event to apply; must be one of the `thread.message.*` events
 * @param snapshot the message accumulated so far, if any
 * @returns the message snapshot after the event, plus the delta content parts
 *   whose index was not yet present in the snapshot
 * @throws Error if a non-`created` message event arrives without a snapshot
 * @throws Error if `event` is not a message event
 */
#accumulateMessage(
  event: AssistantStreamEvent,
  snapshot: Message | undefined,
): [Message, MessageContentDelta[]] {
  const addedContent: MessageContentDelta[] = [];

  //On creation the snapshot is just the initial message
  if (event.event === 'thread.message.created') {
    return [event.data, addedContent];
  }

  if (event.event === 'thread.message.delta') {
    if (!snapshot) {
      throw Error(
        'Received a delta with no existing snapshot (there should be one from message creation)',
      );
    }
    const parts = event.data.delta.content;
    //If this delta does not have content, nothing to process
    if (parts) {
      for (const part of parts) {
        const slot = part.index;
        if (slot in snapshot.content) {
          const existing = snapshot.content[slot];
          snapshot.content[slot] = this.#accumulateContent(part, existing);
        } else {
          // First time this index is seen: the delta becomes the initial content.
          snapshot.content[slot] = part as MessageContent;
          addedContent.push(part);
        }
      }
    }
    return [snapshot, addedContent];
  }

  if (
    event.event === 'thread.message.in_progress' ||
    event.event === 'thread.message.completed' ||
    event.event === 'thread.message.incomplete'
  ) {
    //No changes on other thread events
    if (!snapshot) {
      throw Error('Received thread message event with no existing snapshot');
    }
    return [snapshot, addedContent];
  }

  throw Error('Tried to accumulate a non-message event');
}
/**
 * Merges a content delta into the content accumulated so far by delegating to
 * `AssistantStream.accumulateDelta`, and returns the result typed as a content
 * block.
 */
#accumulateContent(
  contentElement: MessageContentDelta,
  currentContent: MessageContent | undefined,
): TextContentBlock | ImageFileContentBlock {
  const target = currentContent as unknown as Record<any, any>;
  const merged = AssistantStream.accumulateDelta(target, contentElement);
  return merged as TextContentBlock | ImageFileContentBlock;
}
/**
 * Recursively merges `delta` into `acc`, mutating and returning `acc`.
 *
 * Per key:
 * - missing, `null` or `undefined` in `acc`: the delta value is copied over;
 * - `index` and `type`: overwritten, never accumulated;
 * - string + string: concatenated; number + number: added;
 * - object + object: merged recursively;
 * - array + array: if every accumulated element is a string or number, the
 *   delta elements are appended. Otherwise each delta element must be an
 *   object with a numeric `index`; it is merged into the accumulated element
 *   at that index, or stored at that index when nothing is there yet.
 *
 * @param acc value accumulated so far (mutated in place)
 * @param delta incremental update to apply
 * @returns `acc`
 * @throws Error if an array delta entry is not an object or lacks a numeric `index`
 * @throws Error if the accumulated and delta values for a key have incompatible types
 */
static accumulateDelta(acc: Record<string, any>, delta: Record<string, any>): Record<string, any> {
  for (const [key, deltaValue] of Object.entries(delta)) {
    // Go through Object.prototype so a payload key named `hasOwnProperty`
    // (or an object without a prototype) cannot break the check.
    if (!Object.prototype.hasOwnProperty.call(acc, key)) {
      acc[key] = deltaValue;
      continue;
    }

    let accValue = acc[key];
    if (accValue === null || accValue === undefined) {
      acc[key] = deltaValue;
      continue;
    }

    // We don't accumulate these special properties
    if (key === 'index' || key === 'type') {
      acc[key] = deltaValue;
      continue;
    }

    // Type-specific accumulation logic
    if (typeof accValue === 'string' && typeof deltaValue === 'string') {
      accValue += deltaValue;
    } else if (typeof accValue === 'number' && typeof deltaValue === 'number') {
      accValue += deltaValue;
    } else if (Core.isObj(accValue) && Core.isObj(deltaValue)) {
      accValue = this.accumulateDelta(accValue as Record<string, any>, deltaValue as Record<string, any>);
    } else if (Array.isArray(accValue) && Array.isArray(deltaValue)) {
      if (accValue.every((x) => typeof x === 'string' || typeof x === 'number')) {
        accValue.push(...deltaValue); // Use spread syntax for efficient addition
        continue;
      }

      for (const deltaEntry of deltaValue) {
        if (!Core.isObj(deltaEntry)) {
          throw new Error(`Expected array delta entry to be an object but got: ${deltaEntry}`);
        }

        const index = deltaEntry['index'];
        if (index == null) {
          console.error(deltaEntry);
          throw new Error('Expected array delta entry to have an `index` property');
        }

        if (typeof index !== 'number') {
          throw new Error(`Expected array delta entry \`index\` property to be a number but got ${index}`);
        }

        const accEntry = accValue[index];
        if (accEntry == null) {
          // Store at the declared index rather than appending. Appending puts
          // the entry in the wrong slot whenever `index` is not the current
          // length, and later deltas for that index would then miss it.
          accValue[index] = deltaEntry;
        } else {
          accValue[index] = this.accumulateDelta(accEntry, deltaEntry);
        }
      }
      continue;
    } else {
      throw Error(`Unhandled record type: ${key}, deltaValue: ${deltaValue}, accValue: ${accValue}`);
    }
    acc[key] = accValue;
  }

  return acc;
}
/**
 * Processes a run event: records the run as the current run snapshot and, for
 * events that end this stream's run, stores it as the final run.
 *
 * `requires_action`, `cancelled`, `failed`, `completed`, `expired` and
 * `incomplete` are treated as final. For those, any tool call still being
 * tracked is reported as done.
 */
#handleRun(this: AssistantStream, event: RunStreamEvent) {
  this.#currentRunSnapshot = event.data;

  switch (event.event) {
    case 'thread.run.created':
    case 'thread.run.queued':
    case 'thread.run.in_progress':
    case 'thread.run.cancelling':
      // Intermediate states: only the current run snapshot is updated.
      break;

    case 'thread.run.requires_action':
    case 'thread.run.cancelled':
    case 'thread.run.failed':
    case 'thread.run.completed':
    case 'thread.run.expired':
    // `incomplete` is routed to this handler as well. It has to record the
    // final run, otherwise `#endRequest` throws once the event loop finishes.
    case 'thread.run.incomplete':
      this.#finalRun = event.data;
      if (this.#currentToolCall) {
        this._emit('toolCallDone', this.#currentToolCall);
        this.#currentToolCall = undefined;
      }
      break;
  }
}
/**
 * Called with the final run before it is returned from the streaming methods.
 * This implementation returns `run` unchanged.
 */
protected _addRun(run: Run): Run {
return run;
}
/**
 * Runs the create-thread-and-run stream. Delegates to
 * `_createThreadAssistantStream`, which takes `thread` before `params`.
 */
protected async _threadAssistantStream(
  params: ThreadCreateAndRunParamsBase,
  thread: Threads,
  options?: Core.RequestOptions,
): Promise<Run> {
  const finalRun = await this._createThreadAssistantStream(thread, params, options);
  return finalRun;
}
/**
 * Runs the create-run stream. Delegates to `_createAssistantStream`, which
 * takes `runs` before `threadId`.
 */
protected async _runAssistantStream(
  threadId: string,
  runs: Runs,
  params: RunCreateParamsBase,
  options?: Core.RequestOptions,
): Promise<Run> {
  const finalRun = await this._createAssistantStream(runs, threadId, params, options);
  return finalRun;
}
/**
 * Runs the submit-tool-outputs stream. Delegates to
 * `_createToolAssistantStream`, which takes `runs` before the ids.
 */
protected async _runToolAssistantStream(
  threadId: string,
  runId: string,
  runs: Runs,
  params: RunSubmitToolOutputsParamsStream,
  options?: Core.RequestOptions,
): Promise<Run> {
  const finalRun = await this._createToolAssistantStream(runs, threadId, runId, params, options);
  return finalRun;
}
}
/**
 * Compile-time exhaustiveness helper: the `never` parameter makes the type
 * checker reject a call whose argument could still be some value. The body is
 * empty, so at runtime a call does nothing.
 */
function assertNever(_x: never) {}