@assistant-ui/react
Version:
TypeScript/React library for AI Chat
227 lines • 9.83 kB
JavaScript
"use client";
import { jsx as _jsx, jsxs as _jsxs, Fragment as _Fragment } from "react/jsx-runtime";
import { memo, useMemo, } from "react";
import { useAssistantState, useAssistantApi, PartByIndexProvider, TextMessagePartProvider, } from "../../context/index.js";
import { MessagePartPrimitiveText } from "../messagePart/MessagePartText.js";
import { MessagePartPrimitiveImage } from "../messagePart/MessagePartImage.js";
import { MessagePartPrimitiveInProgress } from "../messagePart/MessagePartInProgress.js";
import { useShallow } from "zustand/shallow";
/**
* Creates a group state manager for a specific part type.
* Returns functions to start, end, and finalize groups.
*/
const createGroupState = (groupType) => {
let start = -1;
return {
startGroup: (index) => {
if (start === -1) {
start = index;
}
},
endGroup: (endIndex, ranges) => {
if (start !== -1) {
ranges.push({
type: groupType,
startIndex: start,
endIndex,
});
start = -1;
}
},
finalize: (endIndex, ranges) => {
if (start !== -1) {
ranges.push({
type: groupType,
startIndex: start,
endIndex,
});
}
},
};
};
/**
* Groups consecutive tool-call and reasoning message parts into ranges.
* Always groups tool calls and reasoning parts, even if there's only one.
*/
const groupMessageParts = (messageTypes) => {
const ranges = [];
const toolGroup = createGroupState("toolGroup");
const reasoningGroup = createGroupState("reasoningGroup");
for (let i = 0; i < messageTypes.length; i++) {
const type = messageTypes[i];
if (type === "tool-call") {
reasoningGroup.endGroup(i - 1, ranges);
toolGroup.startGroup(i);
}
else if (type === "reasoning") {
toolGroup.endGroup(i - 1, ranges);
reasoningGroup.startGroup(i);
}
else {
toolGroup.endGroup(i - 1, ranges);
reasoningGroup.endGroup(i - 1, ranges);
ranges.push({ type: "single", index: i });
}
}
toolGroup.finalize(messageTypes.length - 1, ranges);
reasoningGroup.finalize(messageTypes.length - 1, ranges);
return ranges;
};
const useMessagePartsGroups = () => {
const messageTypes = useAssistantState(useShallow((s) => s.message.parts.map((c) => c.type)));
return useMemo(() => {
if (messageTypes.length === 0) {
return [];
}
return groupMessageParts(messageTypes);
}, [messageTypes]);
};
const ToolUIDisplay = ({ Fallback, ...props }) => {
const Render = useAssistantState(({ tools }) => {
const Render = tools.tools[props.toolName] ?? Fallback;
if (Array.isArray(Render))
return Render[0] ?? Fallback;
return Render;
});
if (!Render)
return null;
return _jsx(Render, { ...props });
};
const defaultComponents = {
Text: () => (_jsxs("p", { style: { whiteSpace: "pre-line" }, children: [_jsx(MessagePartPrimitiveText, {}), _jsx(MessagePartPrimitiveInProgress, { children: _jsx("span", { style: { fontFamily: "revert" }, children: " \u25CF" }) })] })),
Reasoning: () => null,
Source: () => null,
Image: () => _jsx(MessagePartPrimitiveImage, {}),
File: () => null,
Unstable_Audio: () => null,
ToolGroup: ({ children }) => children,
ReasoningGroup: ({ children }) => children,
};
const MessagePartComponent = ({ components: { Text = defaultComponents.Text, Reasoning = defaultComponents.Reasoning, Image = defaultComponents.Image, Source = defaultComponents.Source, File = defaultComponents.File, Unstable_Audio: Audio = defaultComponents.Unstable_Audio, tools = {}, } = {}, }) => {
const api = useAssistantApi();
const part = useAssistantState(({ part }) => part);
const type = part.type;
if (type === "tool-call") {
const addResult = api.part().addToolResult;
const resume = api.part().resumeToolCall;
if ("Override" in tools)
return _jsx(tools.Override, { ...part, addResult: addResult, resume: resume });
const Tool = tools.by_name?.[part.toolName] ?? tools.Fallback;
return (_jsx(ToolUIDisplay, { ...part, Fallback: Tool, addResult: addResult, resume: resume }));
}
if (part.status?.type === "requires-action")
throw new Error("Encountered unexpected requires-action status");
switch (type) {
case "text":
return _jsx(Text, { ...part });
case "reasoning":
return _jsx(Reasoning, { ...part });
case "source":
return _jsx(Source, { ...part });
case "image":
return _jsx(Image, { ...part });
case "file":
return _jsx(File, { ...part });
case "audio":
return _jsx(Audio, { ...part });
case "data":
return null;
default:
const unhandledType = type;
throw new Error(`Unknown message part type: ${unhandledType}`);
}
};
/**
* Renders a single message part at the specified index.
*
* This component provides direct access to render a specific message part
* within the current message context, using the provided component configuration.
*
* @example
* ```tsx
* <MessagePrimitive.PartByIndex
* index={0}
* components={{
* Text: MyTextComponent,
* Image: MyImageComponent
* }}
* />
* ```
*/
export const MessagePrimitivePartByIndex = memo(({ index, components }) => {
return (_jsx(PartByIndexProvider, { index: index, children: _jsx(MessagePartComponent, { components: components }) }));
}, (prev, next) => prev.index === next.index &&
prev.components?.Text === next.components?.Text &&
prev.components?.Reasoning === next.components?.Reasoning &&
prev.components?.Source === next.components?.Source &&
prev.components?.Image === next.components?.Image &&
prev.components?.File === next.components?.File &&
prev.components?.Unstable_Audio === next.components?.Unstable_Audio &&
prev.components?.tools === next.components?.tools &&
prev.components?.ToolGroup === next.components?.ToolGroup &&
prev.components?.ReasoningGroup === next.components?.ReasoningGroup);
MessagePrimitivePartByIndex.displayName = "MessagePrimitive.PartByIndex";
const EmptyPartFallback = ({ status, component: Component }) => {
return (_jsx(TextMessagePartProvider, { text: "", isRunning: status.type === "running", children: _jsx(Component, { type: "text", text: "", status: status }) }));
};
const COMPLETE_STATUS = Object.freeze({
type: "complete",
});
const EmptyPartsImpl = ({ components }) => {
const status = useAssistantState((s) => (s.message.status ?? COMPLETE_STATUS));
if (components?.Empty)
return _jsx(components.Empty, { status: status });
return (_jsx(EmptyPartFallback, { status: status, component: components?.Text ?? defaultComponents.Text }));
};
const EmptyParts = memo(EmptyPartsImpl, (prev, next) => prev.components?.Empty === next.components?.Empty &&
prev.components?.Text === next.components?.Text);
/**
* Renders the parts of a message with support for multiple content types.
*
* This component automatically handles different types of message content including
* text, images, files, tool calls, and more. It provides a flexible component
* system for customizing how each content type is rendered.
*
* @example
* ```tsx
* <MessagePrimitive.Parts
* components={{
* Text: ({ text }) => <p className="message-text">{text}</p>,
* Image: ({ image }) => <img src={image} alt="Message image" />,
* tools: {
* by_name: {
* calculator: CalculatorTool,
* weather: WeatherTool,
* },
* Fallback: DefaultToolComponent
* }
* }}
* />
* ```
*/
export const MessagePrimitiveParts = ({ components, }) => {
const contentLength = useAssistantState(({ message }) => message.parts.length);
const messageRanges = useMessagePartsGroups();
const partsElements = useMemo(() => {
if (contentLength === 0) {
return _jsx(EmptyParts, { components: components });
}
return messageRanges.map((range) => {
if (range.type === "single") {
return (_jsx(MessagePrimitivePartByIndex, { index: range.index, components: components }, range.index));
}
else if (range.type === "toolGroup") {
const ToolGroupComponent = components?.ToolGroup ?? defaultComponents.ToolGroup;
return (_jsx(ToolGroupComponent, { startIndex: range.startIndex, endIndex: range.endIndex, children: Array.from({ length: range.endIndex - range.startIndex + 1 }, (_, i) => (_jsx(MessagePrimitivePartByIndex, { index: range.startIndex + i, components: components }, i))) }, `tool-${range.startIndex}`));
}
else {
// reasoningGroup
const ReasoningGroupComponent = components?.ReasoningGroup ?? defaultComponents.ReasoningGroup;
return (_jsx(ReasoningGroupComponent, { startIndex: range.startIndex, endIndex: range.endIndex, children: Array.from({ length: range.endIndex - range.startIndex + 1 }, (_, i) => (_jsx(MessagePrimitivePartByIndex, { index: range.startIndex + i, components: components }, i))) }, `reasoning-${range.startIndex}`));
}
});
}, [messageRanges, components, contentLength]);
return _jsx(_Fragment, { children: partsElements });
};
MessagePrimitiveParts.displayName = "MessagePrimitive.Parts";
//# sourceMappingURL=MessageParts.js.map