// ai — AI SDK by Vercel: The AI Toolkit for TypeScript and JavaScript
// (listing metadata preserved from source header: 737 lines (640 loc) • 20.9 kB, text/typescript)
import {
FlexibleSchema,
generateId as generateIdFunc,
IdGenerator,
InferSchema,
} from '@ai-sdk/provider-utils';
import { FinishReason } from '../types/language-model';
import { UIMessageChunk } from '../ui-message-stream/ui-message-chunks';
import { consumeStream } from '../util/consume-stream';
import { SerialJobExecutor } from '../util/serial-job-executor';
import { ChatTransport } from './chat-transport';
import { convertFileListToFileUIParts } from './convert-file-list-to-file-ui-parts';
import { DefaultChatTransport } from './default-chat-transport';
import {
createStreamingUIMessageState,
processUIMessageStream,
StreamingUIMessageState,
} from './process-ui-message-stream';
import {
InferUIMessageToolCall,
isToolUIPart,
UIMessagePart,
UITools,
type DataUIPart,
type FileUIPart,
type InferUIMessageData,
type InferUIMessageMetadata,
type InferUIMessageTools,
type UIDataTypes,
type UIMessage,
} from './ui-messages';
/**
 * Input shape for creating a new UI message: identical to `UI_MESSAGE`
 * except that `id` and `role` are optional — the chat fills in missing
 * values (a generated id and the `'user'` role) when the message is added.
 */
export type CreateUIMessage<UI_MESSAGE extends UIMessage> = Omit<
  UI_MESSAGE,
  'id' | 'role'
> & {
  id?: UI_MESSAGE['id'];
  role?: UI_MESSAGE['role'];
};
/** Map from data part name to the schema describing that part's payload. */
export type UIDataPartSchemas = Record<string, FlexibleSchema>;

/** Derives the per-key schema map for a set of UI data types. */
export type UIDataTypesToSchemas<T extends UIDataTypes> = {
  [K in keyof T]: FlexibleSchema<T[K]>;
};

/** Infers the concrete data part types from a schema map. */
export type InferUIDataParts<T extends UIDataPartSchemas> = {
  [K in keyof T]: InferSchema<T[K]>;
};
/** Per-request options forwarded to the chat transport. */
export type ChatRequestOptions = {
  /**
   * Additional headers that should be passed to the API endpoint.
   */
  headers?: Record<string, string> | Headers;

  /**
   * Additional body JSON properties that should be sent to the API endpoint.
   */
  body?: object; // TODO JSONStringifyable

  /** Arbitrary request metadata, passed through to the transport unchanged. */
  metadata?: unknown;
};
/**
 * Function that can be called to add a tool approval response to the chat.
 */
export type ChatAddToolApproveResponseFunction = ({
  id,
  approved,
  reason,
}: {
  /**
   * ID of the tool approval request being responded to (matched against the
   * tool part's `approval.id`).
   */
  id: string;

  /**
   * Flag indicating whether the approval was granted or denied.
   */
  approved: boolean;

  /**
   * Optional reason for the approval or denial.
   */
  reason?: string;
}) => void | PromiseLike<void>;
/**
 * Lifecycle status of a chat request:
 * - `submitted`: sent to the API, awaiting the start of the response stream
 * - `streaming`: response chunks are actively arriving
 * - `ready`: response fully processed; a new message can be submitted
 * - `error`: the request failed
 */
export type ChatStatus = 'submitted' | 'streaming' | 'ready' | 'error';

// Internal bookkeeping for an in-flight response: the streaming message state
// plus the controller used to abort the request.
type ActiveResponse<UI_MESSAGE extends UIMessage> = {
  state: StreamingUIMessageState<UI_MESSAGE>;
  abortController: AbortController;
};
/**
 * Mutable store for a chat's status, error, and messages. Implementations are
 * supplied per integration (see the `state` constructor option of
 * `AbstractChat`); presumably each one wires mutations into the host
 * framework's reactivity — not visible from this file.
 */
export interface ChatState<UI_MESSAGE extends UIMessage> {
  status: ChatStatus;
  error: Error | undefined;
  messages: UI_MESSAGE[];

  // Appends a message to the end of the list.
  pushMessage: (message: UI_MESSAGE) => void;
  // Removes the last message from the list.
  popMessage: () => void;
  // Replaces the message at the given index.
  replaceMessage: (index: number, message: UI_MESSAGE) => void;
  // Returns a snapshot/copy of the given value (implementation-defined).
  snapshot: <T>(thing: T) => T;
}

/** Callback invoked when an error is encountered during a request. */
export type ChatOnErrorCallback = (error: Error) => void;

/** Callback invoked when a tool call is received from the response stream. */
export type ChatOnToolCallCallback<UI_MESSAGE extends UIMessage = UIMessage> =
  (options: {
    toolCall: InferUIMessageToolCall<UI_MESSAGE>;
  }) => void | PromiseLike<void>;

/** Callback invoked when a data part is received from the response stream. */
export type ChatOnDataCallback<UI_MESSAGE extends UIMessage> = (
  dataPart: DataUIPart<InferUIMessageData<UI_MESSAGE>>,
) => void;
/**
 * Function that is called when the assistant response has finished streaming.
 *
 * @param message The assistant message that was streamed.
 * @param messages The full chat history, including the assistant message.
 *
 * @param isAbort Indicates whether the request has been aborted.
 * @param isDisconnect Indicates whether the request has been ended by a network error.
 * @param isError Indicates whether the request has been ended by an error.
 * @param finishReason The reason why the generation finished.
 */
export type ChatOnFinishCallback<UI_MESSAGE extends UIMessage> = (options: {
  message: UI_MESSAGE;
  messages: UI_MESSAGE[];
  isAbort: boolean;
  isDisconnect: boolean;
  isError: boolean;
  finishReason?: FinishReason;
}) => void;
/** Initialization options shared by `AbstractChat` subclasses. */
export interface ChatInit<UI_MESSAGE extends UIMessage> {
  /**
   * A unique identifier for the chat. If not provided, a random one will be
   * generated.
   */
  id?: string;

  // Optional schema for message metadata; passed to the UI message stream
  // processor.
  messageMetadataSchema?: FlexibleSchema<InferUIMessageMetadata<UI_MESSAGE>>;
  // Optional schemas for data parts; passed to the UI message stream processor.
  dataPartSchemas?: UIDataTypesToSchemas<InferUIMessageData<UI_MESSAGE>>;

  // Initial messages of the chat.
  messages?: UI_MESSAGE[];

  /**
   * A way to provide a function that is going to be used for ids for messages and the chat.
   * If not provided the default AI SDK `generateId` is used.
   */
  generateId?: IdGenerator;

  // Transport used to send messages / reconnect to streams. Defaults to
  // `DefaultChatTransport` in the `AbstractChat` constructor.
  transport?: ChatTransport<UI_MESSAGE>;

  /**
   * Callback function to be called when an error is encountered.
   */
  onError?: ChatOnErrorCallback;

  /**
   * Optional callback function that is invoked when a tool call is received.
   * Intended for automatic client-side tool execution.
   *
   * You can optionally return a result for the tool call,
   * either synchronously or asynchronously.
   */
  onToolCall?: ChatOnToolCallCallback<UI_MESSAGE>;

  /**
   * Function that is called when the assistant response has finished streaming.
   */
  onFinish?: ChatOnFinishCallback<UI_MESSAGE>;

  /**
   * Optional callback function that is called when a data part is received.
   *
   * @param data The data part that was received.
   */
  onData?: ChatOnDataCallback<UI_MESSAGE>;

  /**
   * When provided, this function will be called when the stream is finished or a tool call is added
   * to determine if the current messages should be resubmitted.
   */
  sendAutomaticallyWhen?: (options: {
    messages: UI_MESSAGE[];
  }) => boolean | PromiseLike<boolean>;
}
/**
 * Framework-agnostic core of a chat session. Subclasses supply a `ChatState`
 * implementation; this class owns message submission, response streaming,
 * tool output/approval handling, status transitions, and automatic
 * resubmission (`sendAutomaticallyWhen`).
 */
export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
  readonly id: string;
  readonly generateId: IdGenerator;

  // Backing store for status/error/messages; supplied by the subclass.
  protected state: ChatState<UI_MESSAGE>;

  // Optional schemas passed to the UI message stream processor.
  private messageMetadataSchema:
    | FlexibleSchema<InferUIMessageMetadata<UI_MESSAGE>>
    | undefined;
  private dataPartSchemas:
    | UIDataTypesToSchemas<InferUIMessageData<UI_MESSAGE>>
    | undefined;

  private readonly transport: ChatTransport<UI_MESSAGE>;
  private onError?: ChatInit<UI_MESSAGE>['onError'];
  private onToolCall?: ChatInit<UI_MESSAGE>['onToolCall'];
  private onFinish?: ChatInit<UI_MESSAGE>['onFinish'];
  private onData?: ChatInit<UI_MESSAGE>['onData'];
  private sendAutomaticallyWhen?: ChatInit<UI_MESSAGE>['sendAutomaticallyWhen'];

  // The response currently being streamed, if any; cleared when the request ends.
  private activeResponse: ActiveResponse<UI_MESSAGE> | undefined = undefined;

  // Serializes message-mutating jobs (stream updates, tool outputs/approvals)
  // so they cannot interleave.
  private jobExecutor = new SerialJobExecutor();

  constructor({
    generateId = generateIdFunc,
    // note: default chat id is produced by the (possibly user-provided)
    // generateId destructured just above
    id = generateId(),
    transport = new DefaultChatTransport(),
    messageMetadataSchema,
    dataPartSchemas,
    state,
    onError,
    onToolCall,
    onFinish,
    onData,
    sendAutomaticallyWhen,
  }: Omit<ChatInit<UI_MESSAGE>, 'messages'> & {
    state: ChatState<UI_MESSAGE>;
  }) {
    this.id = id;
    this.transport = transport;
    this.generateId = generateId;
    this.messageMetadataSchema = messageMetadataSchema;
    this.dataPartSchemas = dataPartSchemas;
    this.state = state;
    this.onError = onError;
    this.onToolCall = onToolCall;
    this.onFinish = onFinish;
    this.onData = onData;
    this.sendAutomaticallyWhen = sendAutomaticallyWhen;
  }

  /**
   * Hook status:
   *
   * - `submitted`: The message has been sent to the API and we're awaiting the start of the response stream.
   * - `streaming`: The response is actively streaming in from the API, receiving chunks of data.
   * - `ready`: The full response has been received and processed; a new user message can be submitted.
   * - `error`: An error occurred during the API request, preventing successful completion.
   */
  get status(): ChatStatus {
    return this.state.status;
  }

  /**
   * Updates the status (and optional error) in the state store.
   * NOTE(review): when the status is unchanged the whole update is skipped,
   * including a new `error` value — confirm that is intended.
   */
  protected setStatus({
    status,
    error,
  }: {
    status: ChatStatus;
    error?: Error;
  }) {
    if (this.status === status) return;

    this.state.status = status;
    this.state.error = error;
  }

  /** The current error, if any. */
  get error() {
    return this.state.error;
  }

  /** The current chat messages. */
  get messages(): UI_MESSAGE[] {
    return this.state.messages;
  }

  /** The last message in the chat, or `undefined` when there are none. */
  get lastMessage(): UI_MESSAGE | undefined {
    return this.state.messages[this.state.messages.length - 1];
  }

  /** Replaces the chat messages wholesale. */
  set messages(messages: UI_MESSAGE[]) {
    this.state.messages = messages;
  }

  /**
   * Appends or replaces a user message to the chat list. This triggers the API call to fetch
   * the assistant's response.
   *
   * If a messageId is provided, the message will be replaced.
   */
  sendMessage = async (
    message?:
      | (CreateUIMessage<UI_MESSAGE> & {
          text?: never;
          files?: never;
          messageId?: string;
        })
      | {
          text: string;
          files?: FileList | FileUIPart[];
          metadata?: InferUIMessageMetadata<UI_MESSAGE>;
          parts?: never;
          messageId?: string;
        }
      | {
          files: FileList | FileUIPart[];
          metadata?: InferUIMessageMetadata<UI_MESSAGE>;
          parts?: never;
          messageId?: string;
        },
    options?: ChatRequestOptions,
  ): Promise<void> => {
    // No message given: resubmit the current history as-is.
    if (message == null) {
      await this.makeRequest({
        trigger: 'submit-message',
        messageId: this.lastMessage?.id,
        ...options,
      });
      return;
    }

    let uiMessage: CreateUIMessage<UI_MESSAGE>;

    if ('text' in message || 'files' in message) {
      // Convenience form: build the parts array from files and/or text.
      // A FileList is first converted into file UI parts.
      const fileParts = Array.isArray(message.files)
        ? message.files
        : await convertFileListToFileUIParts(message.files);

      uiMessage = {
        parts: [
          ...fileParts,
          ...('text' in message && message.text != null
            ? [{ type: 'text' as const, text: message.text }]
            : []),
        ],
      } as UI_MESSAGE;
    } else {
      uiMessage = message;
    }

    if (message.messageId != null) {
      const messageIndex = this.state.messages.findIndex(
        m => m.id === message.messageId,
      );

      if (messageIndex === -1) {
        throw new Error(`message with id ${message.messageId} not found`);
      }

      // Only user messages may be replaced/edited via messageId.
      if (this.state.messages[messageIndex].role !== 'user') {
        throw new Error(
          `message with id ${message.messageId} is not a user message`,
        );
      }

      // remove all messages after the message with the given id
      this.state.messages = this.state.messages.slice(0, messageIndex + 1);

      // update the message with the new content
      this.state.replaceMessage(messageIndex, {
        ...uiMessage,
        id: message.messageId,
        role: uiMessage.role ?? 'user',
        metadata: message.metadata,
      } as UI_MESSAGE);
    } else {
      // Append as a new message, generating an id if none was provided.
      this.state.pushMessage({
        ...uiMessage,
        id: uiMessage.id ?? this.generateId(),
        role: uiMessage.role ?? 'user',
        metadata: message.metadata,
      } as UI_MESSAGE);
    }

    await this.makeRequest({
      trigger: 'submit-message',
      messageId: message.messageId,
      ...options,
    });
  };

  /**
   * Regenerate the assistant message with the provided message id.
   * If no message id is provided, the last assistant message will be regenerated.
   */
  regenerate = async ({
    messageId,
    ...options
  }: {
    messageId?: string;
  } & ChatRequestOptions = {}): Promise<void> => {
    // Without a messageId, target the last message (length - 1 is also -1 on
    // an empty chat, which falls into the not-found error below).
    const messageIndex =
      messageId == null
        ? this.state.messages.length - 1
        : this.state.messages.findIndex(message => message.id === messageId);

    if (messageIndex === -1) {
      throw new Error(`message ${messageId} not found`);
    }

    // set the messages to the message before the assistant message
    this.state.messages = this.state.messages.slice(
      0,
      // if the message is a user message, we need to include it in the request:
      this.messages[messageIndex].role === 'assistant'
        ? messageIndex
        : messageIndex + 1,
    );

    await this.makeRequest({
      trigger: 'regenerate-message',
      messageId,
      ...options,
    });
  };

  /**
   * Attempt to resume an ongoing streaming response.
   */
  resumeStream = async (options: ChatRequestOptions = {}): Promise<void> => {
    await this.makeRequest({ trigger: 'resume-stream', ...options });
  };

  /**
   * Clear the error state and set the status to ready if the chat is in an error state.
   */
  clearError = () => {
    if (this.status === 'error') {
      this.state.error = undefined;
      this.setStatus({ status: 'ready' });
    }
  };

  /**
   * Records an approval response on the matching tool part of the last
   * message, mirrors it into the active response stream state, and may
   * trigger an automatic resubmission (see `sendAutomaticallyWhen`).
   */
  addToolApprovalResponse: ChatAddToolApproveResponseFunction = async ({
    id,
    approved,
    reason,
  }) =>
    this.jobExecutor.run(async () => {
      const messages = this.state.messages;
      const lastMessage = messages[messages.length - 1];

      // Transitions the matching tool part from 'approval-requested' to
      // 'approval-responded', recording the response; other parts pass through.
      const updatePart = (
        part: UIMessagePart<UIDataTypes, UITools>,
      ): UIMessagePart<UIDataTypes, UITools> =>
        isToolUIPart(part) &&
        part.state === 'approval-requested' &&
        part.approval.id === id
          ? {
              ...part,
              state: 'approval-responded',
              approval: { id, approved, reason },
            }
          : part;

      // update the message to trigger an immediate UI update
      this.state.replaceMessage(messages.length - 1, {
        ...lastMessage,
        parts: lastMessage.parts.map(updatePart),
      });

      // update the active response if it exists
      if (this.activeResponse) {
        this.activeResponse.state.message.parts =
          this.activeResponse.state.message.parts.map(updatePart);
      }

      // automatically send the message if the sendAutomaticallyWhen function returns true
      if (
        this.status !== 'streaming' &&
        this.status !== 'submitted' &&
        this.sendAutomaticallyWhen
      ) {
        this.shouldSendAutomatically().then(shouldSend => {
          if (shouldSend) {
            // no await to avoid deadlocking
            this.makeRequest({
              trigger: 'submit-message',
              messageId: this.lastMessage?.id,
            });
          }
        });
      }
    });

  /**
   * Adds the output (or an error) of a client-executed tool call to the
   * matching tool part of the last message, mirrors it into the active
   * response stream state, and may trigger an automatic resubmission.
   */
  addToolOutput = async <TOOL extends keyof InferUIMessageTools<UI_MESSAGE>>({
    state = 'output-available',
    tool,
    toolCallId,
    output,
    errorText,
  }:
    | {
        state?: 'output-available';
        tool: TOOL;
        toolCallId: string;
        output: InferUIMessageTools<UI_MESSAGE>[TOOL]['output'];
        errorText?: never;
      }
    | {
        state: 'output-error';
        tool: TOOL;
        toolCallId: string;
        output?: never;
        errorText: string;
      }) =>
    this.jobExecutor.run(async () => {
      const messages = this.state.messages;
      const lastMessage = messages[messages.length - 1];

      // Writes state/output/errorText onto the tool part with the matching
      // call id; other parts pass through unchanged.
      const updatePart = (
        part: UIMessagePart<UIDataTypes, UITools>,
      ): UIMessagePart<UIDataTypes, UITools> =>
        isToolUIPart(part) && part.toolCallId === toolCallId
          ? ({ ...part, state, output, errorText } as typeof part)
          : part;

      // update the message to trigger an immediate UI update
      this.state.replaceMessage(messages.length - 1, {
        ...lastMessage,
        parts: lastMessage.parts.map(updatePart),
      });

      // update the active response if it exists
      if (this.activeResponse) {
        this.activeResponse.state.message.parts =
          this.activeResponse.state.message.parts.map(updatePart);
      }

      // automatically send the message if the sendAutomaticallyWhen function returns true
      if (
        this.status !== 'streaming' &&
        this.status !== 'submitted' &&
        this.sendAutomaticallyWhen
      ) {
        this.shouldSendAutomatically().then(shouldSend => {
          if (shouldSend) {
            // no await to avoid deadlocking
            this.makeRequest({
              trigger: 'submit-message',
              messageId: this.lastMessage?.id,
            });
          }
        });
      }
    });

  /** @deprecated Use addToolOutput */
  addToolResult = this.addToolOutput;

  /**
   * Abort the current request immediately, keep the generated tokens if any.
   */
  stop = async () => {
    // Only meaningful while a request is in flight.
    if (this.status !== 'streaming' && this.status !== 'submitted') return;

    if (this.activeResponse?.abortController) {
      this.activeResponse.abortController.abort();
    }
  };

  // Evaluates sendAutomaticallyWhen (sync or async); false when unset.
  private async shouldSendAutomatically(): Promise<boolean> {
    if (!this.sendAutomaticallyWhen) return false;

    const result = this.sendAutomaticallyWhen({
      messages: this.state.messages,
    });

    // Check if result is a promise
    if (result && typeof result === 'object' && 'then' in result) {
      return await result;
    }

    return result as boolean;
  }

  /**
   * Core request loop: sends the current messages (or resumes a stream) via
   * the transport, consumes the resulting UI message chunk stream, and
   * manages status transitions, the onFinish callback, and automatic
   * resubmission.
   */
  private async makeRequest({
    trigger,
    metadata,
    headers,
    body,
    messageId,
  }: {
    trigger: 'submit-message' | 'resume-stream' | 'regenerate-message';
    messageId?: string;
  } & ChatRequestOptions) {
    this.setStatus({ status: 'submitted', error: undefined });

    const lastMessage = this.lastMessage;

    // Outcome flags reported to onFinish in the finally block below.
    let isAbort = false;
    let isDisconnect = false;
    let isError = false;

    try {
      const activeResponse = {
        state: createStreamingUIMessageState({
          lastMessage: this.state.snapshot(lastMessage),
          messageId: this.generateId(),
        }),
        abortController: new AbortController(),
      } as ActiveResponse<UI_MESSAGE>;

      // Record aborts even if no AbortError surfaces from the stream.
      activeResponse.abortController.signal.addEventListener('abort', () => {
        isAbort = true;
      });

      this.activeResponse = activeResponse;

      let stream: ReadableStream<UIMessageChunk>;

      if (trigger === 'resume-stream') {
        const reconnect = await this.transport.reconnectToStream({
          chatId: this.id,
          metadata,
          headers,
          body,
        });

        if (reconnect == null) {
          this.setStatus({ status: 'ready' });
          return; // no active stream found, so we do not resume
        }

        stream = reconnect;
      } else {
        stream = await this.transport.sendMessages({
          chatId: this.id,
          messages: this.state.messages,
          abortSignal: activeResponse.abortController.signal,
          metadata,
          headers,
          body,
          trigger,
          messageId,
        });
      }

      const runUpdateMessageJob = (
        job: (options: {
          state: StreamingUIMessageState<UI_MESSAGE>;
          write: () => void;
        }) => Promise<void>,
      ) =>
        // serialize the job execution to avoid race conditions:
        this.jobExecutor.run(() =>
          job({
            state: activeResponse.state,
            write: () => {
              // streaming is set on first write (before it should be "submitted")
              this.setStatus({ status: 'streaming' });

              // Replace the streamed message in place when it is already the
              // last message; otherwise append it.
              const replaceLastMessage =
                activeResponse.state.message.id === this.lastMessage?.id;

              if (replaceLastMessage) {
                this.state.replaceMessage(
                  this.state.messages.length - 1,
                  activeResponse.state.message,
                );
              } else {
                this.state.pushMessage(activeResponse.state.message);
              }
            },
          }),
        );

      await consumeStream({
        stream: processUIMessageStream({
          stream,
          onToolCall: this.onToolCall,
          onData: this.onData,
          messageMetadataSchema: this.messageMetadataSchema,
          dataPartSchemas: this.dataPartSchemas,
          runUpdateMessageJob,
          // Rethrow so stream errors surface in the catch block below.
          onError: error => {
            throw error;
          },
        }),
        onError: error => {
          throw error;
        },
      });

      this.setStatus({ status: 'ready' });
    } catch (err) {
      // Ignore abort errors as they are expected.
      // NOTE(review): the `as any` duck-types err.name; non-Error throws with
      // name === 'AbortError' are also treated as aborts. This early return
      // also skips the automatic resubmission block at the bottom.
      if (isAbort || (err as any).name === 'AbortError') {
        isAbort = true;
        this.setStatus({ status: 'ready' });
        return null;
      }

      isError = true;

      // Network errors such as disconnected, timeout, etc.
      // Heuristic: fetch failures surface as TypeError with these substrings.
      if (
        err instanceof TypeError &&
        (err.message.toLowerCase().includes('fetch') ||
          err.message.toLowerCase().includes('network'))
      ) {
        isDisconnect = true;
      }

      if (this.onError && err instanceof Error) {
        this.onError(err);
      }

      this.setStatus({ status: 'error', error: err as Error });
    } finally {
      // Always notify onFinish; errors thrown by the callback (or by the
      // non-null assertion below) are logged rather than rethrown.
      try {
        this.onFinish?.({
          // NOTE(review): `!` assumes activeResponse was assigned before any
          // failure; if createStreamingUIMessageState threw, this dereference
          // itself throws and is swallowed by the catch below — confirm.
          message: this.activeResponse!.state.message,
          messages: this.state.messages,
          isAbort,
          isDisconnect,
          isError,
          finishReason: this.activeResponse?.state.finishReason,
        });
      } catch (err) {
        console.error(err);
      }

      this.activeResponse = undefined;
    }

    // automatically send the message if the sendAutomaticallyWhen function returns true
    if (!isError && (await this.shouldSendAutomatically())) {
      await this.makeRequest({
        trigger: 'submit-message',
        messageId: this.lastMessage?.id,
        metadata,
        headers,
        body,
      });
    }
  }
}