// @assistant-ui/react — TypeScript/React library for AI Chat
// Version: (unspecified) · 240 lines · 10.9 kB · JavaScript
"use client";
import { useMemo } from "react";
import { ThreadMessageConverter } from "./ThreadMessageConverter.js";
import { getExternalStoreMessages, symbolInnerMessage, } from "./getExternalStoreMessage.js";
import { fromThreadMessageLike } from "./ThreadMessageLike.js";
import { getAutoStatus, isAutoStatus } from "./auto-status.js";
/**
 * Collapses one chunk of converted messages into a single message-like object.
 *
 * Behaviour by role of each entry, processed in order:
 * - "tool": the entry is merged into the already-collected tool-call part that
 *   has the same `toolCallId` (result / artifact / isError / messages are copied
 *   over and the entry is appended to the part's `symbolInnerMessage` list).
 * - "system" / "user": the function returns immediately with a copy of that
 *   entry (content normalised to a parts array); later entries are not visited.
 * - "assistant": content parts are appended to the accumulator. id, createdAt,
 *   status, attachments and metadata are only taken over while the accumulator
 *   has no content parts yet.
 *
 * Every content part is tagged under `symbolInnerMessage` with the message it
 * came from.
 *
 * @param messages - the converted messages of one chunk
 * @returns the joined message-like object
 * @throws {Error} when a tool entry has no matching tool-call part, when its
 *   `toolName` disagrees with the matching part, or when a role is unknown
 */
const joinExternalMessages = (messages) => {
    // Metadata fields that are concatenated (rather than overwritten) on merge.
    const listMetadataKeys = ["unstable_annotations", "unstable_data", "steps"];
    const joined = {
        role: "assistant",
        content: [],
    };
    for (const entry of messages) {
        if (entry.role === "tool") {
            const position = joined.content.findIndex((part) => part.type === "tool-call" && part.toolCallId === entry.toolCallId);
            if (position === -1) {
                throw new Error(`Tool call ${entry.toolCallId} ${entry.toolName} not found in assistant message`);
            }
            const existing = joined.content[position];
            // A tool entry may omit its name; only compare when it is given.
            if (entry.toolName !== undefined && existing.toolName !== entry.toolName) {
                throw new Error(`Tool call name ${entry.toolCallId} ${entry.toolName} does not match existing tool call ${existing.toolName}`);
            }
            joined.content[position] = {
                ...existing,
                [symbolInnerMessage]: [...(existing[symbolInnerMessage] ?? []), entry],
                result: entry.result,
                artifact: entry.artifact,
                isError: entry.isError,
                messages: entry.messages,
            };
            continue;
        }
        // Normalise string content to a single text part, then tag each part
        // with its source message. This happens before the role is inspected.
        const rawParts = typeof entry.content === "string"
            ? [{ type: "text", text: entry.content }]
            : entry.content;
        const taggedParts = rawParts.map((part) => ({
            ...part,
            [symbolInnerMessage]: [entry],
        }));
        const kind = entry.role;
        if (kind === "system" || kind === "user") {
            return {
                ...entry,
                content: taggedParts,
            };
        }
        if (kind !== "assistant") {
            throw new Error(`Unknown message role: ${kind}`);
        }
        if (joined.content.length === 0) {
            joined.id = entry.id;
            joined.createdAt ??= entry.createdAt;
            joined.status ??= entry.status;
            if (entry.attachments) {
                joined.attachments = [
                    ...(joined.attachments ?? []),
                    ...entry.attachments,
                ];
            }
            const incoming = entry.metadata;
            if (incoming) {
                joined.metadata ??= {};
                const target = joined.metadata;
                if (incoming.unstable_state) {
                    target.unstable_state = incoming.unstable_state;
                }
                for (const key of listMetadataKeys) {
                    if (incoming[key]) {
                        target[key] = [...(target[key] ?? []), ...incoming[key]];
                    }
                }
                if (incoming.custom) {
                    target.custom = {
                        ...(target.custom ?? {}),
                        ...incoming.custom,
                    };
                }
                if (incoming.submittedFeedback) {
                    target.submittedFeedback = incoming.submittedFeedback;
                }
            }
            // TODO keep this in sync
        }
        joined.content.push(...taggedParts);
    }
    return joined;
};
/**
 * Groups converted messages into chunks that will each become one thread message.
 *
 * A new chunk is started before an output when any of these holds:
 * - the previous output was not an assistant/tool output (or there was none),
 * - the output's role is "user" or "system",
 * - the current chunk already contains an assistant output marked with join
 *   strategy "none" and this output is not a tool output.
 *
 * @param callbackResults - list of `{ input, outputs }` records, in order
 * @param joinStrategy - when "none", every assistant output seals its chunk;
 *   an individual output can request the same via `convertConfig.joinStrategy`
 * @returns list of `{ inputs, outputs }` chunks; `inputs` holds the distinct
 *   consecutive inputs that contributed outputs to the chunk
 */
const chunkExternalMessages = (callbackResults, joinStrategy) => {
    const chunks = [];
    const emptyChunk = () => ({ inputs: [], outputs: [] });
    let open = emptyChunk();
    let inAssistantRun = false;
    // Set once the open chunk holds an assistant output that must not be
    // joined with a following non-tool output.
    let sealed = false;
    const closeChunk = () => {
        if (open.outputs.length > 0) {
            chunks.push(open);
        }
        open = emptyChunk();
        inAssistantRun = false;
        sealed = false;
    };
    for (const entry of callbackResults) {
        for (const out of entry.outputs) {
            const role = out.role;
            const startsNewChunk = !inAssistantRun ||
                role === "user" ||
                role === "system" ||
                (sealed && role !== "tool");
            if (startsNewChunk) {
                closeChunk();
            }
            inAssistantRun = role === "assistant" || role === "tool";
            // Record the input once even if it produced several outputs here.
            if (open.inputs.at(-1) !== entry.input) {
                open.inputs.push(entry.input);
            }
            open.outputs.push(out);
            const optsOutOfJoin = joinStrategy === "none" ||
                out.convertConfig?.joinStrategy === "none";
            if (role === "assistant" && optsOutOfJoin) {
                sealed = true;
            }
        }
    }
    closeChunk();
    return chunks;
};
/**
 * Converts external store messages into thread messages in one uncached pass.
 *
 * Each message is run through `callback`, the results are grouped with
 * `chunkExternalMessages`, every chunk is folded with `joinExternalMessages`
 * and turned into a thread message via `fromThreadMessageLike`. The chunk index
 * (as a string) and an auto-computed status are passed along as arguments.
 *
 * @param messages - external messages to convert, in order
 * @param callback - called as `callback(message, metadata)`; may return a single
 *   message-like object or an array of them
 * @param isRunning - forwarded to `getAutoStatus`
 * @param metadata - optional; passed to `callback`, and its `error` is forwarded
 *   to `getAutoStatus` for the last chunk only. A nullish value is replaced by
 *   an empty object (the same normalisation `useExternalMessageConverter`
 *   applies) so that omitting it no longer throws.
 * @returns the converted thread messages; each one carries, under
 *   `symbolInnerMessage`, the inputs of the chunk it was built from
 */
export const convertExternalMessages = (messages, callback, isRunning, metadata) => {
    // Previously `metadata.error` was dereferenced unguarded and threw a
    // TypeError for a nullish `metadata` as soon as there was a last chunk.
    const resolvedMetadata = metadata ?? {};
    const callbackResults = [];
    for (const message of messages) {
        const output = callback(message, resolvedMetadata);
        const outputs = Array.isArray(output) ? output : [output];
        callbackResults.push({ input: message, outputs });
    }
    const chunks = chunkExternalMessages(callbackResults);
    const lastIdx = chunks.length - 1;
    return chunks.map((chunk, idx) => {
        const isLast = idx === lastIdx;
        const joined = joinExternalMessages(chunk.outputs);
        // NOTE(review): the "suspended" and "pending" flags handed to
        // getAutoStatus were computed by two identical expressions; they are
        // computed once here. Confirm whether "suspended" was meant to use a
        // different predicate.
        const hasUnresolvedToolCalls = typeof joined.content === "object" &&
            joined.content.some((part) => part.type === "tool-call" && part.result === undefined);
        const autoStatus = getAutoStatus(isLast, isRunning, hasUnresolvedToolCalls, hasUnresolvedToolCalls, isLast ? resolvedMetadata.error : undefined);
        const newMessage = fromThreadMessageLike(joined, idx.toString(), autoStatus);
        newMessage[symbolInnerMessage] = chunk.inputs;
        return newMessage;
    });
};
/**
 * React hook that converts external store messages into thread messages,
 * reusing earlier work where the inputs are unchanged.
 *
 * Three caches live in a memoised state object: per-message callback results,
 * per-chunk grouping results (keyed by the chunk's first output), and a
 * `ThreadMessageConverter` for the final thread messages. The state object is
 * rebuilt, and all caches start empty, whenever `callback` or `metadata`
 * changes identity.
 *
 * @param options.callback - called as `callback(message, metadata)`; may return
 *   one message-like object or an array
 * @param options.messages - external messages, in order
 * @param options.isRunning - forwarded to `getAutoStatus`
 * @param options.joinStrategy - forwarded to `chunkExternalMessages`
 * @param options.metadata - optional; defaults to `{}`. Its `error` is forwarded
 *   to `getAutoStatus` for the last chunk only.
 * @returns the thread messages; the returned array carries the original
 *   `messages` under `symbolInnerMessage`
 */
export const useExternalMessageConverter = ({ callback, messages, isRunning, joinStrategy, metadata, }) => {
    const state = useMemo(() => {
        return {
            metadata: metadata ?? {},
            callback,
            callbackCache: new WeakMap(),
            chunkCache: new WeakMap(),
            converterCache: new ThreadMessageConverter(),
        };
    }, [callback, metadata]);
    return useMemo(() => {
        // Step 1: run (or recall) the callback for every message.
        const convertOne = (message) => {
            const known = state.callbackCache.get(message);
            if (known) {
                return known;
            }
            const raw = state.callback(message, state.metadata);
            const fresh = {
                input: message,
                outputs: Array.isArray(raw) ? raw : [raw],
            };
            state.callbackCache.set(message, fresh);
            return fresh;
        };
        const callbackResults = [];
        for (const message of messages) {
            callbackResults.push(convertOne(message));
        }
        // Step 2: group into chunks, handing back the previous chunk object
        // when it holds exactly the same outputs so later identity checks hit.
        const reuseChunk = (chunk) => {
            const firstOutput = chunk.outputs[0];
            if (!firstOutput) {
                return chunk;
            }
            const previous = state.chunkCache.get(firstOutput);
            if (previous && shallowArrayEqual(previous.outputs, chunk.outputs)) {
                return previous;
            }
            state.chunkCache.set(firstOutput, chunk);
            return chunk;
        };
        const chunks = chunkExternalMessages(callbackResults, joinStrategy).map((chunk) => reuseChunk(chunk));
        const lastIdx = chunks.length - 1;
        // Step 3: build thread messages. `cache` is whatever the converter
        // supplies as a previously produced message (presumably for the same
        // chunk — confirm in ThreadMessageConverter).
        const threadMessages = state.converterCache.convertMessages(chunks, (cache, chunk, idx) => {
            const isLast = idx === lastIdx;
            const joined = joinExternalMessages(chunk.outputs);
            // Serves as both the "suspended" and the "pending" flag below.
            const hasUnresolvedToolCalls = typeof joined.content === "object" &&
                joined.content.some((part) => part.type === "tool-call" && part.result === undefined);
            const autoStatus = getAutoStatus(isLast, isRunning, hasUnresolvedToolCalls, hasUnresolvedToolCalls, isLast ? state.metadata.error : undefined);
            // A cached assistant message with an auto status is only reusable
            // if that status is the very one just computed.
            const statusStillValid = (cached) => cached.role !== "assistant" ||
                !isAutoStatus(cached.status) ||
                cached.status === autoStatus;
            if (cache &&
                statusStillValid(cache) &&
                shallowArrayEqual(getExternalStoreMessages(cache), chunk.inputs)) {
                return cache;
            }
            const newMessage = fromThreadMessageLike(joined, idx.toString(), autoStatus);
            newMessage[symbolInnerMessage] = chunk.inputs;
            return newMessage;
        });
        threadMessages[symbolInnerMessage] = messages;
        return threadMessages;
    }, [state, messages, isRunning, joinStrategy]);
};
/**
 * Tells whether two arrays have the same length and strictly equal (`===`)
 * elements at every index. Elements are not compared recursively.
 *
 * @param a - first array
 * @param b - second array
 * @returns `true` when every position matches, otherwise `false`
 */
const shallowArrayEqual = (a, b) => {
    const size = a.length;
    if (size !== b.length) {
        return false;
    }
    // Walk forward until the first mismatch; reaching the end means equal.
    let i = 0;
    while (i < size && a[i] === b[i]) {
        i += 1;
    }
    return i === size;
};
//# sourceMappingURL=external-message-converter.js.map