@assistant-ui/react
Version:
Typescript/React library for AI Chat
189 lines (165 loc) • 5.67 kB
text/typescript
import {
ChatModelAdapter,
ChatModelRunOptions,
} from "../local/ChatModelAdapter";
import { toCoreMessages } from "./converters/toCoreMessages";
import { toLanguageModelTools } from "./converters/toLanguageModelTools";
import { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions";
import { toolResultStream } from "./streams/toolResultStream";
import { toLanguageModelMessages } from "./converters";
import { ThreadMessage } from "../../types";
import { Tool } from "../../model-context";
import { z } from "zod";
import zodToJsonSchema from "zod-to-json-schema";
import { JSONSchema7 } from "json-schema";
import {
AssistantMessageAccumulator,
DataStreamDecoder,
} from "assistant-stream";
/**
 * Adapts a `ReadableStream` so it can be consumed with `for await`.
 *
 * Each call to `[Symbol.asyncIterator]` acquires a reader on `source`. The
 * reader lock is released once the stream reports completion or a read fails.
 * If the consumer stops early (`break`, `return`, or an exception inside the
 * loop), the iterator's `return()` cancels the stream and releases the lock so
 * the underlying source is not left running and locked.
 *
 * @param source Stream whose chunks should be iterated.
 * @returns An async iterable yielding the chunks of `source` in order.
 */
export function asAsyncIterable<T>(
  source: ReadableStream<T>,
): AsyncIterable<T> {
  return {
    [Symbol.asyncIterator]: () => {
      const reader = source.getReader();
      let released = false;

      // releaseLock() must run at most once; reading after release would throw.
      const release = (): void => {
        if (released) return;
        released = true;
        reader.releaseLock();
      };

      return {
        async next(): Promise<IteratorResult<T, undefined>> {
          // Keep reporting completion after the reader has been released.
          if (released) return { done: true, value: undefined };
          try {
            const { done, value } = await reader.read();
            if (done) {
              release();
              return { done: true, value: undefined };
            }
            return { done: false, value };
          } catch (error: unknown) {
            release();
            throw error;
          }
        },
        // Invoked by `for await` when the loop is left before completion.
        async return(): Promise<IteratorResult<T, undefined>> {
          if (!released) {
            try {
              await reader.cancel();
            } finally {
              release();
            }
          }
          return { done: true, value: undefined };
        },
      };
    },
  };
}
/** Header collection accepted by the adapter: a Fetch `Headers` instance or a plain string map. */
type HeadersValue = Headers | Record<string, string>;
/**
 * Configuration accepted by `EdgeChatAdapter`.
 */
export type EdgeChatAdapterOptions = {
  /** URL of the endpoint the adapter sends its POST request to. */
  api: string;

  // experimental_prepareRequestBody?: (options: {
  //   id: string;
  //   messages: UIMessage[];
  //   requestData?: JSONValue;
  //   requestBody?: object;
  // }) => unknown;

  // onToolCall?: ({
  //   toolCall,
  // }: {
  //   toolCall: UIMessageToolInvocation;
  // }) => void | Promise<unknown> | unknown;

  /**
   * Invoked with the raw `Response` once the API call has resolved.
   * A returned promise is awaited before the response is processed.
   */
  onResponse?: (response: Response) => void | Promise<void>;

  /**
   * Optional hook invoked with the final message once the assistant message
   * has finished streaming.
   */
  onFinish?: (message: ThreadMessage) => void;

  /**
   * Invoked when an error is encountered; the error is rethrown afterwards.
   */
  onError?: (error: Error) => void;

  /**
   * Credentials mode forwarded to `fetch`. The adapter falls back to
   * `"same-origin"` when this is omitted.
   */
  credentials?: RequestCredentials;

  /**
   * Headers to send with the request: either a static headers value or an
   * async function resolving to one.
   */
  headers?: HeadersValue | (() => Promise<HeadersValue>);

  /**
   * Extra fields merged into the JSON request body. They are spread last, so
   * they override fields of the same name set by the adapter.
   */
  body?: object;

  /**
   * @deprecated Renamed to `sendExtraMessageFields`.
   */
  unstable_sendMessageIds?: boolean;

  /**
   * When enabled, the adapter keeps the `id` of each message in the messages
   * array instead of stripping it.
   */
  sendExtraMessageFields?: boolean;

  /**
   * When enabled, messages are sent in the format expected by the Vercel AI
   * SDK Core. This feature will be removed in the future in favor of a better
   * solution.
   *
   * `v2` additionally sends frontend tools in a format that can be passed
   * directly to `streamText`.
   */
  unstable_AISDKInterop?: boolean | "v2" | undefined;
};
/**
 * Converts a map of frontend tools into a map of `{ description?, parameters }`
 * objects keyed by tool name.
 *
 * Zod schemas are converted with `zodToJsonSchema`; any other `parameters`
 * value is passed through unchanged and treated as a JSON Schema. The
 * `description` key is only present when the tool has a truthy description.
 *
 * @param tools Tools keyed by name.
 * @returns A new object with one converted entry per tool.
 */
const toAISDKTools = (tools: Record<string, Tool<any, any>>) => {
  const convertedEntries = Object.entries(tools).map(
    ([toolName, definition]) => {
      const schema = (
        definition.parameters instanceof z.ZodType
          ? zodToJsonSchema(definition.parameters)
          : definition.parameters
      ) as JSONSchema7;

      const converted = definition.description
        ? { description: definition.description, parameters: schema }
        : { parameters: schema };

      return [toolName, converted] as const;
    },
  );

  return Object.fromEntries(convertedEntries);
};
/**
 * Chat model adapter that POSTs the conversation to an HTTP endpoint and
 * yields the assistant message as it is accumulated from the response stream.
 */
export class EdgeChatAdapter implements ChatModelAdapter {
  constructor(private readonly options: EdgeChatAdapterOptions) {}

  /**
   * Sends the current thread to `options.api` and streams back the result.
   *
   * The request body contains the system prompt, the converted messages, the
   * converted tools, the assistant message id and the run config, followed by
   * `context.callSettings`, `context.config` and `options.body` (later spreads
   * override earlier fields). `Content-Type` is always set to
   * `application/json`.
   *
   * `onResponse` is awaited before the status is checked. Errors raised while
   * checking or streaming the response are reported to `onError` and then
   * rethrown. Failures of `fetch` itself, of the headers callback, or of
   * `onResponse` happen before that handler is installed and are thrown
   * without calling `onError`.
   *
   * @throws Error with `Status <code>: <body text>` when the response is not OK.
   * @throws Error when the response has no body.
   */
  async *run({
    messages,
    runConfig,
    abortSignal,
    context,
    unstable_assistantMessageId,
    unstable_getMessage,
  }: ChatModelRunOptions) {
    const { options } = this;

    const headersValue =
      typeof options.headers === "function"
        ? await options.headers()
        : options.headers;

    const headers = new Headers(headersValue);
    headers.set("Content-Type", "application/json");

    // Both message converters take the same flag; compute it once.
    const includeId =
      options.unstable_sendMessageIds || options.sendExtraMessageFields;

    const result = await fetch(options.api, {
      method: "POST",
      headers,
      credentials: options.credentials ?? "same-origin",
      body: JSON.stringify({
        system: context.system,
        messages: options.unstable_AISDKInterop
          ? (toLanguageModelMessages(messages, {
              unstable_includeId: includeId,
            }) as EdgeRuntimeRequestOptions["messages"]) // TODO figure out a better way to do this
          : toCoreMessages(messages, {
              unstable_includeId: includeId,
            }),
        tools: context.tools
          ? options.unstable_AISDKInterop === "v2"
            ? (toAISDKTools(context.tools) as any)
            : toLanguageModelTools(context.tools)
          : [],
        unstable_assistantMessageId,
        runConfig,
        ...context.callSettings,
        ...context.config,
        ...options.body,
      } satisfies EdgeRuntimeRequestOptions),
      signal: abortSignal,
    });

    await options.onResponse?.(result);

    try {
      if (!result.ok) {
        throw new Error(`Status ${result.status}: ${await result.text()}`);
      }
      if (!result.body) {
        throw new Error("Response body is null");
      }

      const stream = result.body
        .pipeThrough(new DataStreamDecoder())
        .pipeThrough(toolResultStream(context.tools, abortSignal))
        .pipeThrough(new AssistantMessageAccumulator());

      for await (const update of asAsyncIterable(stream)) {
        yield update;
      }

      options.onFinish?.(unstable_getMessage());
    } catch (error: unknown) {
      // `onError` is typed to receive an Error; wrap other thrown values
      // instead of asserting. The original value is still what gets rethrown.
      options.onError?.(
        error instanceof Error ? error : new Error(String(error)),
      );
      throw error;
    }
  }
}