/*
 * @ant-design/x — Craft AI-driven interfaces effortlessly.
 * Version: (not recorded)
 * Source listing: 152 lines (145 loc), 4.75 kB, JavaScript.
 */
import { useEvent } from 'rc-util';
import React from 'react';
import useSyncState from "./useSyncState";
/**
 * Normalize a value to an array.
 *
 * @param {*} item - Any value.
 * @returns {Array} `item` itself (same reference) when it is already an array,
 *   otherwise a new single-element array containing `item`.
 */
function toArray(item) {
  if (Array.isArray(item)) {
    return item;
  }
  return [item];
}
/**
 * React hook that keeps a chat message list in sync with requests sent
 * through `config.agent`.
 *
 * @param {object} config
 * @param {Array<object>} [config.defaultMessages] - Initial entries. Each is spread into
 *   `{ id: 'default_<index>', status: 'local', ...info }`, so its own fields win.
 * @param {object} [config.agent] - Object exposing `request(info, callbacks)`; required by `onRequest`.
 * @param {*|Function} [config.requestFallback] - Message (or possibly-async function returning one)
 *   appended with status `'error'` when the request fails. When falsy, this request's pending
 *   messages are simply removed on failure.
 * @param {*|Function} [config.requestPlaceholder] - Message (or function returning one) appended
 *   with status `'loading'` until the first update arrives.
 * @param {Function} [config.parser] - Maps one agent message to one bubble message or an array of them.
 * @returns {{onRequest: Function, messages: Array, parsedMessages: Array, setMessages: Function}}
 */
export default function useXChat(config) {
  const {
    defaultMessages,
    agent,
    requestFallback,
    requestPlaceholder,
    parser
  } = config;

  // ========================= Agent Messages =========================
  // Counter backing ids of messages created by this hook: `msg_0`, `msg_1`, ...
  const idRef = React.useRef(0);
  const [messages, setMessages, getMessages] = useSyncState(() => (defaultMessages || []).map((info, index) => ({
    id: `default_${index}`,
    status: 'local',
    ...info
  })));

  const createMessage = (message, status) => {
    const msg = {
      id: `msg_${idRef.current}`,
      message,
      status
    };
    idRef.current += 1;
    return msg;
  };

  // ========================= BubbleMessages =========================
  // Flatten every agent message into one or more bubble messages. `parser` must be
  // a dependency as well, otherwise a newly supplied parser would be ignored until
  // the message list itself changes.
  const parsedMessages = React.useMemo(() => {
    const list = [];
    messages.forEach(agentMsg => {
      const rawParsedMsg = parser ? parser(agentMsg.message) : agentMsg.message;
      const bubbleMsgs = toArray(rawParsedMsg);
      bubbleMsgs.forEach((bubbleMsg, bubbleMsgIndex) => {
        // Suffix the index only when one agent message yields several bubbles,
        // so single-bubble ids equal the agent message id.
        const key = bubbleMsgs.length > 1 ? `${agentMsg.id}_${bubbleMsgIndex}` : agentMsg.id;
        list.push({
          id: key,
          message: bubbleMsg,
          status: agentMsg.status
        });
      });
    });
    return list;
  }, [messages, parser]);

  // ============================ Request =============================
  // Payloads of all messages that are neither still loading nor failed.
  const getFilteredMessages = msgs => msgs.filter(info => info.status !== 'loading' && info.status !== 'error').map(info => info.message);

  // For agent to use. Will filter out loading and error message
  const getRequestMessages = () => getFilteredMessages(getMessages());

  const onRequest = useEvent(message => {
    if (!agent) throw new Error('The agent parameter is required when using the onRequest method in an agent generated by useXAgent.');

    // Ids of this request's placeholder and of its (streamed) reply, once created.
    let loadingMsgId = null;
    let updatingMsgId = null;

    // Drop this request's placeholder and reply; ids are read at call time.
    const withoutPending = list => list.filter(info => info.id !== loadingMsgId && info.id !== updatingMsgId);

    // Append the user's message, plus a placeholder when one is configured.
    setMessages(ori => {
      let nextMessages = [...ori, createMessage(message, 'local')];
      if (requestPlaceholder) {
        const placeholderMsg = typeof requestPlaceholder === 'function' ? requestPlaceholder(message, {
          messages: getFilteredMessages(nextMessages)
        }) : requestPlaceholder;
        const loadingMsg = createMessage(placeholderMsg, 'loading');
        loadingMsgId = loadingMsg.id;
        nextMessages = [...nextMessages, loadingMsg];
      }
      return nextMessages;
    });

    // First call creates the reply (replacing the placeholder); later calls
    // update that reply in place.
    const updateMessage = (nextMessage, status) => {
      const exists = getMessages().some(info => info.id === updatingMsgId);
      if (!exists) {
        const msg = createMessage(nextMessage, status);
        setMessages(ori => [...ori.filter(info => info.id !== loadingMsgId), msg]);
        updatingMsgId = msg.id;
        return;
      }
      setMessages(ori => ori.map(info => info.id === updatingMsgId ? {
        ...info,
        message: nextMessage,
        status
      } : info));
    };

    agent.request({
      message,
      messages: getRequestMessages()
    }, {
      onUpdate: chunk => {
        updateMessage(chunk, 'loading');
      },
      onSuccess: result => {
        updateMessage(result, 'success');
      },
      onError: async error => {
        if (!requestFallback) {
          // Remove directly
          setMessages(withoutPending);
          return;
        }
        let fallbackMsg;
        try {
          fallbackMsg = typeof requestFallback === 'function' ? await requestFallback(message, {
            error,
            messages: getRequestMessages()
          }) : requestFallback;
        } catch (fallbackError) {
          // The fallback itself failed: still clear the pending messages so the
          // placeholder does not stay forever, then surface the failure.
          setMessages(withoutPending);
          throw fallbackError;
        }
        // Replace pending messages with the fallback, marked as error.
        setMessages(ori => [...withoutPending(ori), createMessage(fallbackMsg, 'error')]);
      }
    });
  });

  return {
    onRequest,
    messages,
    parsedMessages,
    setMessages
  };
}