// openai
// Version:
// The official TypeScript library for the OpenAI API
// 330 lines (277 loc) • 9.08 kB
// text/typescript
import { ReadableStream, type Response } from './_shims/index';
import { OpenAIError } from './error';
import { findDoubleNewlineIndex, LineDecoder } from './internal/decoders/line';
import { ReadableStreamToAsyncIterable } from './internal/stream-utils';
import { createResponseHeaders } from './core';
import { APIError } from './error';
/**
 * The chunk shapes the stream readers in this file accept from a body iterator:
 * text, raw binary buffers, or an absent value (`null` / `undefined`).
 */
type Bytes = string | ArrayBuffer | Uint8Array | Buffer | null | undefined;
/**
 * A single decoded Server-Sent Event message.
 */
export type ServerSentEvent = {
/** Value of the message's `event:` field, or `null` when no event name was set. */
event: string | null;
/** Payload of the message: the values of its `data:` fields joined with newlines. */
data: string;
/** The lines collected while assembling this message; logged for diagnostics when the payload fails to parse. */
raw: string[];
};
/**
 * An async-iterable view over a streamed response.
 *
 * A `Stream` holds a factory that creates the underlying async iterator plus
 * the `AbortController` associated with the request, so callers can stop the
 * stream with `stream.controller.abort()`.
 *
 * Streams built by `fromSSEResponse` / `fromReadableStream` can be iterated only
 * once; use `tee()` to obtain two independently readable copies.
 */
export class Stream<Item> implements AsyncIterable<Item> {
  controller: AbortController;

  /**
   * @param iterator - Factory invoked each time the stream is iterated.
   * @param controller - Controller used to abort the ongoing request.
   */
  constructor(
    private iterator: () => AsyncIterator<Item>,
    controller: AbortController,
  ) {
    this.controller = controller;
  }

  /**
   * Builds a `Stream` from a Server-Sent Events response.
   *
   * Every message payload is parsed as JSON. Messages with no event name, or
   * whose event name starts with `response.` or `transcript.`, yield the parsed
   * payload directly; any other event yields `{ event, data }`.
   *
   * @param response - The response whose body carries the SSE messages.
   * @param controller - Controller aborted when iteration stops early.
   * @returns A single-use stream of decoded items.
   * @throws Error if the stream is iterated a second time.
   * @throws APIError if a payload carries an `error` field, or the event is named `error`.
   * @throws SyntaxError (from `JSON.parse`) if a payload is not valid JSON.
   */
  static fromSSEResponse<Item>(response: Response, controller: AbortController): Stream<Item> {
    let consumed = false;

    async function* iterator(): AsyncIterator<Item, any, undefined> {
      if (consumed) {
        throw new Error('Cannot iterate over a consumed stream, use `.tee()` to split the stream.');
      }
      consumed = true;
      let done = false;
      try {
        for await (const sse of _iterSSEMessages(response, controller)) {
          // After the terminator we keep draining the body but ignore its content.
          if (done) continue;

          if (sse.data.startsWith('[DONE]')) {
            done = true;
            continue;
          }

          // Both kinds of message carry a JSON payload, so parse once up front.
          let data;
          try {
            data = JSON.parse(sse.data);
          } catch (e) {
            console.error(`Could not parse message into JSON:`, sse.data);
            console.error(`From chunk:`, sse.raw);
            throw e;
          }

          if (
            sse.event === null ||
            sse.event.startsWith('response.') ||
            sse.event.startsWith('transcript.')
          ) {
            if (data && data.error) {
              throw new APIError(undefined, data.error, undefined, createResponseHeaders(response.headers));
            }
            yield data;
          } else {
            // TODO: Is this where the error should be thrown?
            if (sse.event === 'error') {
              // `data` may legitimately parse to `null`; still surface an APIError in that case.
              throw new APIError(undefined, data?.error, data?.message, undefined);
            }
            yield { event: sse.event, data: data } as any;
          }
        }
        done = true;
      } catch (e) {
        // If the user calls `stream.controller.abort()`, we should exit without throwing.
        if (e instanceof Error && e.name === 'AbortError') return;
        throw e;
      } finally {
        // If the user `break`s, abort the ongoing request.
        if (!done) controller.abort();
      }
    }

    return new Stream(iterator, controller);
  }

  /**
   * Generates a Stream from a newline-separated ReadableStream
   * where each item is a JSON value.
   *
   * Empty lines are skipped; every other line is passed to `JSON.parse`.
   *
   * @param readableStream - Source of newline-delimited JSON.
   * @param controller - Controller aborted when iteration stops early.
   * @returns A single-use stream of parsed items.
   * @throws Error if the stream is iterated a second time.
   */
  static fromReadableStream<Item>(readableStream: ReadableStream, controller: AbortController): Stream<Item> {
    let consumed = false;

    // Splits the incoming chunks into complete lines, including any trailing partial line.
    async function* iterLines(): AsyncGenerator<string, void, unknown> {
      const lineDecoder = new LineDecoder();

      const iter = ReadableStreamToAsyncIterable<Bytes>(readableStream);
      for await (const chunk of iter) {
        for (const line of lineDecoder.decode(chunk)) {
          yield line;
        }
      }

      for (const line of lineDecoder.flush()) {
        yield line;
      }
    }

    async function* iterator(): AsyncIterator<Item, any, undefined> {
      if (consumed) {
        throw new Error('Cannot iterate over a consumed stream, use `.tee()` to split the stream.');
      }
      consumed = true;
      let done = false;
      try {
        for await (const line of iterLines()) {
          if (line) yield JSON.parse(line);
        }
        done = true;
      } catch (e) {
        // If the user calls `stream.controller.abort()`, we should exit without throwing.
        if (e instanceof Error && e.name === 'AbortError') return;
        throw e;
      } finally {
        // If the user `break`s, abort the ongoing request.
        if (!done) controller.abort();
      }
    }

    return new Stream(iterator, controller);
  }

  /** Creates the underlying iterator by invoking the factory given to the constructor. */
  [Symbol.asyncIterator](): AsyncIterator<Item> {
    return this.iterator();
  }

  /**
   * Splits the stream into two streams which can be
   * independently read from at different speeds.
   *
   * Each result pulled from the source is queued for both branches, so the
   * slower branch's queue grows until it catches up. Both branches share this
   * stream's controller.
   *
   * NOTE: the branch iterators only implement `next`, so `break`ing out of a
   * branch does not call `return` on the source iterator.
   */
  tee(): [Stream<Item>, Stream<Item>] {
    const left: Array<Promise<IteratorResult<Item>>> = [];
    const right: Array<Promise<IteratorResult<Item>>> = [];
    const iterator = this.iterator();

    const teeIterator = (queue: Array<Promise<IteratorResult<Item>>>): AsyncIterator<Item> => {
      return {
        next: () => {
          if (queue.length === 0) {
            // Pull once from the source and hand the same pending result to both branches.
            const result = iterator.next();
            left.push(result);
            right.push(result);
          }
          return queue.shift()!;
        },
      };
    };

    return [
      new Stream(() => teeIterator(left), this.controller),
      new Stream(() => teeIterator(right), this.controller),
    ];
  }

  /**
   * Converts this stream to a newline-separated ReadableStream of
   * JSON stringified values in the stream
   * which can be turned back into a Stream with `Stream.fromReadableStream()`.
   *
   * Cancelling the returned stream calls `return` on the underlying iterator
   * when it provides one.
   */
  toReadableStream(): ReadableStream {
    const self = this;
    let iter: AsyncIterator<Item>;
    const encoder = new TextEncoder();

    return new ReadableStream({
      async start() {
        iter = self[Symbol.asyncIterator]();
      },
      async pull(ctrl: any) {
        try {
          const { value, done } = await iter.next();
          if (done) return ctrl.close();

          // One JSON document per line.
          const bytes = encoder.encode(JSON.stringify(value) + '\n');

          ctrl.enqueue(bytes);
        } catch (err) {
          ctrl.error(err);
        }
      },
      async cancel() {
        await iter.return?.();
      },
    });
  }
}
/**
 * Reads a response body and yields each Server-Sent Event it contains.
 *
 * The body is cut into SSE chunks, each chunk is split into lines, and the
 * lines are fed to an `SSEDecoder`; every completed event is yielded in order.
 * Lines still buffered when the body ends are flushed through the decoder too.
 *
 * @param response - The response whose body should be decoded.
 * @param controller - Aborted before throwing when the response has no body.
 * @throws OpenAIError if the response has no body.
 */
export async function* _iterSSEMessages(
  response: Response,
  controller: AbortController,
): AsyncGenerator<ServerSentEvent, void, unknown> {
  if (!response.body) {
    controller.abort();
    throw new OpenAIError(`Attempted to iterate over a response with no body`);
  }

  const eventDecoder = new SSEDecoder();
  const lines = new LineDecoder();

  // Runs a batch of decoded lines through the SSE decoder, emitting each finished event.
  function* eventsFrom(decodedLines: Iterable<string>): Generator<ServerSentEvent, void, unknown> {
    for (const line of decodedLines) {
      const event = eventDecoder.decode(line);
      if (event) yield event;
    }
  }

  const chunks = iterSSEChunks(ReadableStreamToAsyncIterable<Bytes>(response.body));
  for await (const chunk of chunks) {
    yield* eventsFrom(lines.decode(chunk));
  }

  // The body has ended: emit whatever the line decoder was still holding on to.
  yield* eventsFrom(lines.flush());
}
/**
 * Re-chunks an async byte source so that each yielded `Uint8Array` ends where
 * `findDoubleNewlineIndex` reports a double-newline boundary, i.e. one full SSE
 * chunk at a time. Bytes left over when the source ends are yielded as a final
 * chunk.
 *
 * `null` / `undefined` chunks are ignored; strings are UTF-8 encoded and
 * `ArrayBuffer`s are wrapped before being appended to the pending buffer.
 */
async function* iterSSEChunks(iterator: AsyncIterableIterator<Bytes>): AsyncGenerator<Uint8Array> {
  let pending = new Uint8Array();

  for await (const chunk of iterator) {
    if (chunk == null) continue;

    // Normalise every accepted input shape to a byte array.
    let incoming: Uint8Array;
    if (chunk instanceof ArrayBuffer) {
      incoming = new Uint8Array(chunk);
    } else if (typeof chunk === 'string') {
      incoming = new TextEncoder().encode(chunk);
    } else {
      incoming = chunk;
    }

    // Append the new bytes to whatever was left over from earlier chunks.
    const merged = new Uint8Array(pending.length + incoming.length);
    merged.set(pending);
    merged.set(incoming, pending.length);
    pending = merged;

    // Emit every complete SSE chunk currently sitting in the buffer.
    while (true) {
      const boundary = findDoubleNewlineIndex(pending);
      if (boundary === -1) break;
      yield pending.slice(0, boundary);
      pending = pending.slice(boundary);
    }
  }

  if (pending.length > 0) {
    yield pending;
  }
}
/**
 * Incrementally assembles Server-Sent Events from individual lines.
 *
 * Feed one line at a time to `decode`. `event:` and `data:` fields are
 * accumulated until a blank line arrives, at which point the buffered message
 * is returned and the decoder resets for the next one.
 */
class SSEDecoder {
  private data: string[];
  private event: string | null;
  private chunks: string[];

  constructor() {
    this.event = null;
    this.data = [];
    this.chunks = [];
  }

  /**
   * Consumes a single line of the event stream.
   *
   * @param line - One line without its line feed; a trailing `\r` is stripped.
   * @returns The completed event when `line` is blank and an `event` or `data`
   *   field was buffered; otherwise `null`.
   */
  decode(line: string): ServerSentEvent | null {
    if (line.endsWith('\r')) {
      line = line.substring(0, line.length - 1);
    }

    if (!line) {
      // A blank line ends the current block, so the collected raw lines always
      // belong to this block only. Without this reset, comment-only blocks
      // (e.g. `: keep-alive`) would pile up and leak into the next event's `raw`.
      const raw = this.chunks;
      this.chunks = [];

      // empty line and we didn't previously encounter any messages
      if (!this.event && !this.data.length) return null;

      const sse: ServerSentEvent = {
        event: this.event,
        data: this.data.join('\n'),
        raw,
      };

      this.event = null;
      this.data = [];

      return sse;
    }

    this.chunks.push(line);

    // Lines starting with a colon are comments: recorded in `raw`, otherwise ignored.
    if (line.startsWith(':')) {
      return null;
    }

    let [fieldname, _, value] = partition(line, ':');

    // A single space after the colon is a separator, not part of the value.
    if (value.startsWith(' ')) {
      value = value.substring(1);
    }

    if (fieldname === 'event') {
      this.event = value;
    } else if (fieldname === 'data') {
      this.data.push(value);
    }

    return null;
  }
}
/**
 * Splits `str` around the first occurrence of `delimiter`.
 *
 * @returns `[before, delimiter, after]` when the delimiter is present,
 *   or `[str, '', '']` when it is not.
 */
function partition(str: string, delimiter: string): [string, string, string] {
  const at = str.indexOf(delimiter);
  if (at === -1) {
    return [str, '', ''];
  }

  const head = str.slice(0, at);
  const tail = str.slice(at + delimiter.length);
  return [head, delimiter, tail];
}