@assistant-ui/react
Version:
TypeScript/React library for AI Chat
200 lines • 9.13 kB
JavaScript
"use client";
import { jsx as _jsx, jsxs as _jsxs, Fragment as _Fragment } from "react/jsx-runtime";
import { memo, useMemo, } from "react";
import { useAssistantState, PartByIndexProvider, useAssistantApi, TextMessagePartProvider, } from "../../context/index.js";
import { MessagePartPrimitiveText } from "../messagePart/MessagePartText.js";
import { MessagePartPrimitiveImage } from "../messagePart/MessagePartImage.js";
import { MessagePartPrimitiveInProgress } from "../messagePart/MessagePartInProgress.js";
/**
* Groups message parts by their parent ID.
* Parts without a parent ID appear in their chronological position as individual groups.
* Parts with the same parent ID are grouped together at the position of their first occurrence.
*/
const groupMessagePartsByParentId = (parts) => {
// Map maintains insertion order, so groups appear in order of first occurrence
const groupMap = new Map();
// Process each part in order
for (let i = 0; i < parts.length; i++) {
const part = parts[i];
const parentId = part?.parentId;
// For parts without parentId, assign a unique group ID to maintain their position
const groupId = parentId ?? `__ungrouped_${i}`;
// Get or create the indices array for this group
const indices = groupMap.get(groupId) ?? [];
indices.push(i);
groupMap.set(groupId, indices);
}
// Convert map to array of groups
const groups = [];
for (const [groupId, indices] of groupMap) {
// Extract parentId (undefined for ungrouped parts)
const groupKey = groupId.startsWith("__ungrouped_") ? undefined : groupId;
groups.push({ groupKey, indices });
}
return groups;
};
const useMessagePartsGrouped = (groupingFunction) => {
const parts = useAssistantState(({ message }) => message.parts);
return useMemo(() => {
if (parts.length === 0) {
return [];
}
return groupingFunction(parts);
}, [parts, groupingFunction]);
};
const ToolUIDisplay = ({ Fallback, ...props }) => {
const Render = useAssistantState(({ tools }) => {
const Render = tools.tools[props.toolName] ?? Fallback;
if (Array.isArray(Render))
return Render[0] ?? Fallback;
return Render;
});
if (!Render)
return null;
return _jsx(Render, { ...props });
};
const defaultComponents = {
Text: () => (_jsxs("p", { style: { whiteSpace: "pre-line" }, children: [_jsx(MessagePartPrimitiveText, {}), _jsx(MessagePartPrimitiveInProgress, { children: _jsx("span", { style: { fontFamily: "revert" }, children: " \u25CF" }) })] })),
Reasoning: () => null,
Source: () => null,
Image: () => _jsx(MessagePartPrimitiveImage, {}),
File: () => null,
Unstable_Audio: () => null,
Group: ({ children }) => children,
};
const MessagePartComponent = ({ components: { Text = defaultComponents.Text, Reasoning = defaultComponents.Reasoning, Image = defaultComponents.Image, Source = defaultComponents.Source, File = defaultComponents.File, Unstable_Audio: Audio = defaultComponents.Unstable_Audio, tools = {}, } = {}, }) => {
const api = useAssistantApi();
const part = useAssistantState(({ part }) => part);
const type = part.type;
if (type === "tool-call") {
const addResult = (result) => api.part().addToolResult(result);
const resume = api.part().resumeToolCall;
if ("Override" in tools)
return _jsx(tools.Override, { ...part, addResult: addResult, resume: resume });
const Tool = tools.by_name?.[part.toolName] ?? tools.Fallback;
return (_jsx(ToolUIDisplay, { ...part, Fallback: Tool, addResult: addResult, resume: resume }));
}
if (part.status?.type === "requires-action")
throw new Error("Encountered unexpected requires-action status");
switch (type) {
case "text":
return _jsx(Text, { ...part });
case "reasoning":
return _jsx(Reasoning, { ...part });
case "source":
return _jsx(Source, { ...part });
case "image":
return _jsx(Image, { ...part });
case "file":
return _jsx(File, { ...part });
case "audio":
return _jsx(Audio, { ...part });
case "data":
return null;
default:
const unhandledType = type;
throw new Error(`Unknown message part type: ${unhandledType}`);
}
};
const MessagePartImpl = ({ partIndex, components }) => {
return (_jsx(PartByIndexProvider, { index: partIndex, children: _jsx(MessagePartComponent, { components: components }) }));
};
const MessagePart = memo(MessagePartImpl, (prev, next) => prev.partIndex === next.partIndex &&
prev.components?.Text === next.components?.Text &&
prev.components?.Reasoning === next.components?.Reasoning &&
prev.components?.Source === next.components?.Source &&
prev.components?.Image === next.components?.Image &&
prev.components?.File === next.components?.File &&
prev.components?.Unstable_Audio === next.components?.Unstable_Audio &&
prev.components?.tools === next.components?.tools &&
prev.components?.Group === next.components?.Group);
const EmptyPartFallback = ({ status, component: Component }) => {
return (_jsx(TextMessagePartProvider, { text: "", isRunning: status.type === "running", children: _jsx(Component, { type: "text", text: "", status: status }) }));
};
const COMPLETE_STATUS = Object.freeze({
type: "complete",
});
const EmptyPartsImpl = ({ components }) => {
const status = useAssistantState((s) => (s.message.status ?? COMPLETE_STATUS));
if (components?.Empty)
return _jsx(components.Empty, { status: status });
return (_jsx(EmptyPartFallback, { status: status, component: components?.Text ?? defaultComponents.Text }));
};
const EmptyParts = memo(EmptyPartsImpl, (prev, next) => prev.components?.Empty === next.components?.Empty &&
prev.components?.Text === next.components?.Text);
/**
* Renders the parts of a message grouped by a custom grouping function.
*
* This component allows you to group message parts based on any criteria you define.
* The grouping function receives all message parts and returns an array of groups,
* where each group has a key and an array of part indices.
*
* @example
* ```tsx
* // Group by parent ID (default behavior)
* <MessagePrimitive.Unstable_PartsGrouped
* components={{
* Text: ({ text }) => <p className="message-text">{text}</p>,
* Image: ({ image }) => <img src={image} alt="Message image" />,
* Group: ({ groupKey, indices, children }) => {
* if (!groupKey) return <>{children}</>;
* return (
* <div className="parent-group border rounded p-4">
* <h4>Parent ID: {groupKey}</h4>
* {children}
* </div>
* );
* }
* }}
* />
* ```
*
* @example
* ```tsx
* // Group by tool name
* import { groupMessagePartsByToolName } from "@assistant-ui/react";
*
* <MessagePrimitive.Unstable_PartsGrouped
* groupingFunction={groupMessagePartsByToolName}
* components={{
* Group: ({ groupKey, indices, children }) => {
* if (!groupKey) return <>{children}</>;
* return (
* <div className="tool-group">
* <h4>Tool: {groupKey}</h4>
* {children}
* </div>
* );
* }
* }}
* />
* ```
*/
export const MessagePrimitiveUnstable_PartsGrouped = ({ groupingFunction, components }) => {
const contentLength = useAssistantState(({ message }) => message.parts.length);
const messageGroups = useMessagePartsGrouped(groupingFunction);
const partsElements = useMemo(() => {
if (contentLength === 0) {
return _jsx(EmptyParts, { components: components });
}
return messageGroups.map((group, groupIndex) => {
const GroupComponent = components?.Group ?? defaultComponents.Group;
return (_jsx(GroupComponent, { groupKey: group.groupKey, indices: group.indices, children: group.indices.map((partIndex) => (_jsx(MessagePart, { partIndex: partIndex, components: components }, partIndex))) }, `group-${groupIndex}-${group.groupKey ?? "ungrouped"}`));
});
}, [messageGroups, components, contentLength]);
return _jsx(_Fragment, { children: partsElements });
};
MessagePrimitiveUnstable_PartsGrouped.displayName =
"MessagePrimitive.Unstable_PartsGrouped";
/**
* Renders the parts of a message grouped by their parent ID.
* This is a convenience wrapper around Unstable_PartsGrouped with parent ID grouping.
*
* @deprecated Use MessagePrimitive.Unstable_PartsGrouped instead for more flexibility
*/
export const MessagePrimitiveUnstable_PartsGroupedByParentId = ({ components, ...props }) => {
return (_jsx(MessagePrimitiveUnstable_PartsGrouped, { ...props, components: components, groupingFunction: groupMessagePartsByParentId }));
};
MessagePrimitiveUnstable_PartsGroupedByParentId.displayName =
"MessagePrimitive.Unstable_PartsGroupedByParentId";
//# sourceMappingURL=MessagePartsGrouped.js.map