UNPKG

@assistant-ui/react

Version:

TypeScript/React library for AI Chat

295 lines (269 loc) 9.1 kB
import { useEffect, useRef, useState } from "react";
import {
  createAssistantStreamController,
  type ToolCallStreamController,
  ToolResponse,
  unstable_toolResultStream,
  type Tool,
} from "assistant-stream";
import type {
  AssistantTransportCommand,
  AssistantTransportState,
} from "./types";
import {
  AssistantMetaTransformStream,
  type ReadonlyJSONValue,
} from "assistant-stream/utils";

/**
 * Returns `true` once the accumulated tool-call argument text parses as JSON.
 * Used to decide when a tool call's args-text stream can be closed.
 */
const isArgsTextComplete = (argsText: string) => {
  try {
    JSON.parse(argsText);
    return true;
  } catch {
    return false;
  }
};

type UseToolInvocationsParams = {
  /** Transport state whose messages are scanned for `tool-call` parts. */
  state: AssistantTransportState;
  /** Supplies the tool definitions handed to `unstable_toolResultStream`. */
  getTools: () => Record<string, Tool> | undefined;
  /**
   * Receives an `add-tool-result` command for each tool result produced by
   * the stream pipeline, unless the backend already supplied a result.
   */
  onResult: (command: AssistantTransportCommand) => void;
  /** Updates the per-tool-call execution status map. */
  setToolStatuses: (
    updater:
      | Record<string, ToolExecutionStatus>
      | ((
          prev: Record<string, ToolExecutionStatus>,
        ) => Record<string, ToolExecutionStatus>),
  ) => void;
};

export type ToolExecutionStatus =
  | { type: "executing" }
  | { type: "interrupt"; payload: { type: "human"; payload: unknown } };

/** Bookkeeping for one tool call that has been forwarded to the stream. */
type ToolCallState = {
  /** Args text already appended to the controller. */
  argsText: string;
  /** Whether a result (from `state`) has been set on the controller. */
  hasResult: boolean;
  /** Whether the controller's args-text stream has been closed. */
  argsComplete: boolean;
  controller: ToolCallStreamController;
};

/**
 * Mirrors `tool-call` parts found in `state.messages` into an assistant-stream
 * tool pipeline (`unstable_toolResultStream`).
 *
 * Tool calls present in the first processed state (or the first state after
 * `reset()`) are recorded as ignored and are never forwarded. Later tool calls
 * have their args text streamed incrementally and their backend results, if
 * any, forwarded. Results emitted by the pipeline are reported via `onResult`
 * unless the backend already provided one.
 *
 * @returns Controls:
 *  - `abort()` rejects pending human-input requests, aborts the current
 *    signal, installs a fresh `AbortController`, and resolves once no tool is
 *    executing.
 *  - `reset()` calls `abort()` without awaiting it and marks the next
 *    processed state as initial.
 *  - `resume(toolCallId, payload)` resolves a pending human-input request;
 *    throws if `toolCallId` is not waiting for input.
 */
export function useToolInvocations({
  state,
  getTools,
  onResult,
  setToolStatuses,
}: UseToolInvocationsParams) {
  const lastToolStates = useRef<Record<string, ToolCallState>>({});
  const humanInputRef = useRef<
    Map<
      string,
      {
        resolve: (payload: unknown) => void;
        reject: (reason: unknown) => void;
      }
    >
  >(new Map());

  const acRef = useRef<AbortController>(new AbortController());
  const executingCountRef = useRef(0);
  const settledResolversRef = useRef<Array<() => void>>([]);

  // The stream pipeline below is created exactly once (useState initializer),
  // so it must not close over the first render's callbacks. Keep the latest
  // ones in a ref and read through it at call time.
  const callbacksRef = useRef({ getTools, onResult, setToolStatuses });
  useEffect(() => {
    callbacksRef.current = { getTools, onResult, setToolStatuses };
  }, [getTools, onResult, setToolStatuses]);

  const [controller] = useState(() => {
    const [stream, controller] = createAssistantStreamController();
    const transform = unstable_toolResultStream(
      () => callbacksRef.current.getTools(),
      () => acRef.current?.signal ?? new AbortController().signal,
      (toolCallId: string, payload: unknown) => {
        return new Promise<unknown>((resolve, reject) => {
          // Reject previous human input request if it exists
          const previous = humanInputRef.current.get(toolCallId);
          if (previous) {
            previous.reject(
              new Error("Human input request was superseded by a new request"),
            );
          }

          humanInputRef.current.set(toolCallId, { resolve, reject });
          callbacksRef.current.setToolStatuses((prev) => ({
            ...prev,
            [toolCallId]: {
              type: "interrupt",
              payload: { type: "human", payload },
            },
          }));
        });
      },
      {
        onExecutionStart: (toolCallId: string) => {
          executingCountRef.current++;
          callbacksRef.current.setToolStatuses((prev) => ({
            ...prev,
            [toolCallId]: { type: "executing" },
          }));
        },
        onExecutionEnd: (toolCallId: string) => {
          executingCountRef.current--;
          callbacksRef.current.setToolStatuses((prev) => {
            const next = { ...prev };
            delete next[toolCallId];
            return next;
          });
          // Resolve any waiting abort promises when all tools have settled
          if (executingCountRef.current === 0) {
            settledResolversRef.current.forEach((resolve) => resolve());
            settledResolversRef.current = [];
          }
        },
      },
    );

    void stream
      .pipeThrough(transform)
      .pipeThrough(new AssistantMetaTransformStream())
      .pipeTo(
        new WritableStream({
          write(chunk) {
            if (chunk.type === "result") {
              // the tool call result was already set by the backend
              if (lastToolStates.current[chunk.meta.toolCallId]?.hasResult)
                return;

              callbacksRef.current.onResult({
                type: "add-tool-result",
                toolCallId: chunk.meta.toolCallId,
                toolName: chunk.meta.toolName,
                result: chunk.result,
                isError: chunk.isError,
                ...(chunk.artifact && { artifact: chunk.artifact }),
              });
            }
          },
        }),
      );

    return controller;
  });

  const ignoredToolIds = useRef<Set<string>>(new Set());
  const isInitialState = useRef(true);

  useEffect(() => {
    const processMessages = (
      messages: readonly (typeof state.messages)[number][],
    ) => {
      messages.forEach((message) => {
        message.content.forEach((content) => {
          if (content.type === "tool-call") {
            if (isInitialState.current) {
              // Tool calls already present when tracking starts are history:
              // remember them so they are never forwarded.
              ignoredToolIds.current.add(content.toolCallId);
            } else {
              if (ignoredToolIds.current.has(content.toolCallId)) {
                return;
              }

              let lastState = lastToolStates.current[content.toolCallId];
              if (!lastState) {
                const toolCallController = controller.addToolCallPart({
                  toolName: content.toolName,
                  toolCallId: content.toolCallId,
                });
                lastState = {
                  argsText: "",
                  hasResult: false,
                  argsComplete: false,
                  controller: toolCallController,
                };
                lastToolStates.current[content.toolCallId] = lastState;
              }

              if (content.argsText !== lastState.argsText) {
                if (lastState.argsComplete) {
                  if (process.env["NODE_ENV"] !== "production") {
                    console.warn(
                      "argsText updated after controller was closed:",
                      {
                        previous: lastState.argsText,
                        next: content.argsText,
                      },
                    );
                  }
                } else {
                  if (!content.argsText.startsWith(lastState.argsText)) {
                    throw new Error(
                      `Tool call argsText can only be appended, not updated: ${content.argsText} does not start with ${lastState.argsText}`,
                    );
                  }

                  // Only the newly appended suffix is streamed.
                  const argsTextDelta = content.argsText.slice(
                    lastState.argsText.length,
                  );
                  lastState.controller.argsText.append(argsTextDelta);

                  const shouldClose = isArgsTextComplete(content.argsText);
                  if (shouldClose) {
                    lastState.controller.argsText.close();
                  }

                  // Rebind `lastState` so the result handling below persists
                  // the updated argsText rather than the stale previous one
                  // (which would trigger spurious "updated after closed"
                  // warnings on the next pass).
                  lastState = {
                    argsText: content.argsText,
                    hasResult: lastState.hasResult,
                    argsComplete: shouldClose,
                    controller: lastState.controller,
                  };
                  lastToolStates.current[content.toolCallId] = lastState;
                }
              }

              if (content.result !== undefined && !lastState.hasResult) {
                lastState.controller.setResponse(
                  new ToolResponse({
                    result: content.result as ReadonlyJSONValue,
                    artifact: content.artifact as ReadonlyJSONValue | undefined,
                    isError: content.isError,
                  }),
                );
                lastState.controller.close();

                lastToolStates.current[content.toolCallId] = {
                  hasResult: true,
                  argsComplete: true,
                  argsText: lastState.argsText,
                  controller: lastState.controller,
                };
              }
            }

            // Recursively process nested messages
            if (content.messages) {
              processMessages(content.messages);
            }
          }
        });
      });
    };

    processMessages(state.messages);

    if (isInitialState.current) {
      isInitialState.current = false;
    }
  }, [state, controller]);

  const abort = (): Promise<void> => {
    humanInputRef.current.forEach(({ reject }) => {
      reject(new Error("Tool execution aborted"));
    });
    humanInputRef.current.clear();

    acRef.current.abort();
    acRef.current = new AbortController();

    // Return a promise that resolves when all executing tools have settled
    if (executingCountRef.current === 0) {
      return Promise.resolve();
    }

    return new Promise<void>((resolve) => {
      settledResolversRef.current.push(resolve);
    });
  };

  return {
    reset: () => {
      void abort();
      isInitialState.current = true;
    },
    abort,
    resume: (toolCallId: string, payload: unknown) => {
      const handlers = humanInputRef.current.get(toolCallId);
      if (handlers) {
        humanInputRef.current.delete(toolCallId);
        setToolStatuses((prev) => ({
          ...prev,
          [toolCallId]: { type: "executing" },
        }));
        handlers.resolve(payload);
      } else {
        throw new Error(
          `Tool call ${toolCallId} is not waiting for human input`,
        );
      }
    },
  };
}