UNPKG

@assistant-ui/react

Version:

TypeScript/React library for AI Chat

200 lines 9.13 kB
"use client"; import { jsx as _jsx, jsxs as _jsxs, Fragment as _Fragment } from "react/jsx-runtime"; import { memo, useMemo, } from "react"; import { useAssistantState, PartByIndexProvider, useAssistantApi, TextMessagePartProvider, } from "../../context/index.js"; import { MessagePartPrimitiveText } from "../messagePart/MessagePartText.js"; import { MessagePartPrimitiveImage } from "../messagePart/MessagePartImage.js"; import { MessagePartPrimitiveInProgress } from "../messagePart/MessagePartInProgress.js"; /** * Groups message parts by their parent ID. * Parts without a parent ID appear in their chronological position as individual groups. * Parts with the same parent ID are grouped together at the position of their first occurrence. */ const groupMessagePartsByParentId = (parts) => { // Map maintains insertion order, so groups appear in order of first occurrence const groupMap = new Map(); // Process each part in order for (let i = 0; i < parts.length; i++) { const part = parts[i]; const parentId = part?.parentId; // For parts without parentId, assign a unique group ID to maintain their position const groupId = parentId ?? `__ungrouped_${i}`; // Get or create the indices array for this group const indices = groupMap.get(groupId) ?? []; indices.push(i); groupMap.set(groupId, indices); } // Convert map to array of groups const groups = []; for (const [groupId, indices] of groupMap) { // Extract parentId (undefined for ungrouped parts) const groupKey = groupId.startsWith("__ungrouped_") ? undefined : groupId; groups.push({ groupKey, indices }); } return groups; }; const useMessagePartsGrouped = (groupingFunction) => { const parts = useAssistantState(({ message }) => message.parts); return useMemo(() => { if (parts.length === 0) { return []; } return groupingFunction(parts); }, [parts, groupingFunction]); }; const ToolUIDisplay = ({ Fallback, ...props }) => { const Render = useAssistantState(({ tools }) => { const Render = tools.tools[props.toolName] ?? Fallback; if (Array.isArray(Render)) return Render[0] ?? Fallback; return Render; }); if (!Render) return null; return _jsx(Render, { ...props }); }; const defaultComponents = { Text: () => (_jsxs("p", { style: { whiteSpace: "pre-line" }, children: [_jsx(MessagePartPrimitiveText, {}), _jsx(MessagePartPrimitiveInProgress, { children: _jsx("span", { style: { fontFamily: "revert" }, children: " \u25CF" }) })] })), Reasoning: () => null, Source: () => null, Image: () => _jsx(MessagePartPrimitiveImage, {}), File: () => null, Unstable_Audio: () => null, Group: ({ children }) => children, }; const MessagePartComponent = ({ components: { Text = defaultComponents.Text, Reasoning = defaultComponents.Reasoning, Image = defaultComponents.Image, Source = defaultComponents.Source, File = defaultComponents.File, Unstable_Audio: Audio = defaultComponents.Unstable_Audio, tools = {}, } = {}, }) => { const api = useAssistantApi(); const part = useAssistantState(({ part }) => part); const type = part.type; if (type === "tool-call") { const addResult = (result) => api.part().addToolResult(result); const resume = api.part().resumeToolCall; if ("Override" in tools) return _jsx(tools.Override, { ...part, addResult: addResult, resume: resume }); const Tool = tools.by_name?.[part.toolName] ?? tools.Fallback; return (_jsx(ToolUIDisplay, { ...part, Fallback: Tool, addResult: addResult, resume: resume })); } if (part.status?.type === "requires-action") throw new Error("Encountered unexpected requires-action status"); switch (type) { case "text": return _jsx(Text, { ...part }); case "reasoning": return _jsx(Reasoning, { ...part }); case "source": return _jsx(Source, { ...part }); case "image": return _jsx(Image, { ...part }); case "file": return _jsx(File, { ...part }); case "audio": return _jsx(Audio, { ...part }); case "data": return null; default: const unhandledType = type; throw new Error(`Unknown message part type: ${unhandledType}`); } }; const MessagePartImpl = ({ partIndex, components }) => { return (_jsx(PartByIndexProvider, { index: partIndex, children: _jsx(MessagePartComponent, { components: components }) })); }; const MessagePart = memo(MessagePartImpl, (prev, next) => prev.partIndex === next.partIndex && prev.components?.Text === next.components?.Text && prev.components?.Reasoning === next.components?.Reasoning && prev.components?.Source === next.components?.Source && prev.components?.Image === next.components?.Image && prev.components?.File === next.components?.File && prev.components?.Unstable_Audio === next.components?.Unstable_Audio && prev.components?.tools === next.components?.tools && prev.components?.Group === next.components?.Group); const EmptyPartFallback = ({ status, component: Component }) => { return (_jsx(TextMessagePartProvider, { text: "", isRunning: status.type === "running", children: _jsx(Component, { type: "text", text: "", status: status }) })); }; const COMPLETE_STATUS = Object.freeze({ type: "complete", }); const EmptyPartsImpl = ({ components }) => { const status = useAssistantState((s) => (s.message.status ?? COMPLETE_STATUS)); if (components?.Empty) return _jsx(components.Empty, { status: status }); return (_jsx(EmptyPartFallback, { status: status, component: components?.Text ?? defaultComponents.Text })); }; const EmptyParts = memo(EmptyPartsImpl, (prev, next) => prev.components?.Empty === next.components?.Empty && prev.components?.Text === next.components?.Text); /** * Renders the parts of a message grouped by a custom grouping function. * * This component allows you to group message parts based on any criteria you define. * The grouping function receives all message parts and returns an array of groups, * where each group has a key and an array of part indices. * * @example * ```tsx * // Group by parent ID (default behavior) * <MessagePrimitive.Unstable_PartsGrouped * components={{ * Text: ({ text }) => <p className="message-text">{text}</p>, * Image: ({ image }) => <img src={image} alt="Message image" />, * Group: ({ groupKey, indices, children }) => { * if (!groupKey) return <>{children}</>; * return ( * <div className="parent-group border rounded p-4"> * <h4>Parent ID: {groupKey}</h4> * {children} * </div> * ); * } * }} * /> * ``` * * @example * ```tsx * // Group by tool name * import { groupMessagePartsByToolName } from "@assistant-ui/react"; * * <MessagePrimitive.Unstable_PartsGrouped * groupingFunction={groupMessagePartsByToolName} * components={{ * Group: ({ groupKey, indices, children }) => { * if (!groupKey) return <>{children}</>; * return ( * <div className="tool-group"> * <h4>Tool: {groupKey}</h4> * {children} * </div> * ); * } * }} * /> * ``` */ export const MessagePrimitiveUnstable_PartsGrouped = ({ groupingFunction, components }) => { const contentLength = useAssistantState(({ message }) => message.parts.length); const messageGroups = useMessagePartsGrouped(groupingFunction); const partsElements = useMemo(() => { if (contentLength === 0) { return _jsx(EmptyParts, { components: components }); } return messageGroups.map((group, groupIndex) => { const GroupComponent = components?.Group ?? defaultComponents.Group; return (_jsx(GroupComponent, { groupKey: group.groupKey, indices: group.indices, children: group.indices.map((partIndex) => (_jsx(MessagePart, { partIndex: partIndex, components: components }, partIndex))) }, `group-${groupIndex}-${group.groupKey ?? "ungrouped"}`)); }); }, [messageGroups, components, contentLength]); return _jsx(_Fragment, { children: partsElements }); }; MessagePrimitiveUnstable_PartsGrouped.displayName = "MessagePrimitive.Unstable_PartsGrouped"; /** * Renders the parts of a message grouped by their parent ID. * This is a convenience wrapper around Unstable_PartsGrouped with parent ID grouping. * * @deprecated Use MessagePrimitive.Unstable_PartsGrouped instead for more flexibility */ export const MessagePrimitiveUnstable_PartsGroupedByParentId = ({ components, ...props }) => { return (_jsx(MessagePrimitiveUnstable_PartsGrouped, { ...props, components: components, groupingFunction: groupMessagePartsByParentId })); }; MessagePrimitiveUnstable_PartsGroupedByParentId.displayName = "MessagePrimitive.Unstable_PartsGroupedByParentId"; //# sourceMappingURL=MessagePartsGrouped.js.map