/*
 * UNPKG listing: @ant-design/x-sdk
 *
 * Version: (not shown in this listing)
 *
 * placeholder for @ant-design/x-sdk
 *
 * Listing size: 283 lines (279 loc), 9.22 kB
 */
import { useEvent } from 'rc-util';
import React, { useState } from 'react';
import { useChatStore } from "./store";

// Compiled TypeScript enum of message lifecycle states. Not referenced by the
// code below, which uses the string literals directly. `var` is required: the
// IIFE argument reads the binding before it is assigned.
var MessageStatusEnum = /*#__PURE__*/function (MessageStatusEnum) {
  MessageStatusEnum["local"] = "local";
  MessageStatusEnum["loading"] = "loading";
  MessageStatusEnum["updating"] = "updating";
  MessageStatusEnum["success"] = "success";
  MessageStatusEnum["error"] = "error";
  MessageStatusEnum["abort"] = "abort";
  return MessageStatusEnum;
}(MessageStatusEnum || {});

/**
 * Wrap a non-array value in a one-element array; arrays are returned as-is.
 */
function toArray(item) {
  return Array.isArray(item) ? item : [item];
}

// Module-level "is requesting" flags keyed by `conversationKey`, shared by
// every useXChat instance that imports this module.
const IsRequestingMap = new Map();

/**
 * Chat state hook.
 *
 * Holds the message list (via `useChatStore`, scoped by `conversationKey`) and
 * sends requests through `config.provider`.
 *
 * @param {object} config
 * @param {Array} [config.defaultMessages] Initial messages. Each entry is spread
 *   over `{ id: 'default_<index>', status: 'local' }`, so an entry may override
 *   either field.
 * @param {*|Function} [config.requestPlaceholder] Message, or factory
 *   `(requestParams, { messages }) => message`, shown with status 'loading'
 *   while a request is pending.
 * @param {*|Function} [config.requestFallback] Message, or (possibly async)
 *   factory `(requestParams, { error, messageInfo, messages }) => message`, that
 *   replaces the pending message when the request fails.
 * @param {Function} [config.parser] Maps one stored message payload to one or
 *   more bubble payloads for `parsedMessages`.
 * @param {object} [config.provider] Request provider. Required by `onRequest`,
 *   `onReload` and `abort` (they throw without it).
 * @param {*} [config.conversationKey] Store key. When truthy, the requesting
 *   flag is read from / written to the module-level `IsRequestingMap`.
 * @returns {{onRequest: Function, messages: Array, parsedMessages: Array,
 *   setMessages: Function, setMessage: Function, abort: Function,
 *   isRequesting: boolean, onReload: Function}}
 */
export default function useXChat(config) {
  const {
    defaultMessages,
    requestFallback,
    requestPlaceholder,
    parser,
    provider,
    conversationKey
  } = config;

  // ========================= Agent Messages =========================
  const idRef = React.useRef(0);
  const requestHandlerRef = React.useRef(undefined);
  const [isRequesting, setIsRequesting] = useState(false);
  const { messages, setMessages, getMessages, setMessage } = useChatStore(
    () => (defaultMessages || []).map((info, index) => ({
      id: `default_${index}`,
      status: 'local',
      ...info
    })),
    conversationKey
  );

  // Build a message record with a fresh sequential id (`msg_<n>`).
  const createMessage = (message, status, extra) => {
    const msg = { id: `msg_${idRef.current}`, message, status };
    if (extra) {
      msg.extra = extra;
    }
    idRef.current += 1;
    return msg;
  };

  // ========================= BubbleMessages =========================
  // Expand every stored message into one bubble per parsed payload. When a
  // message parses into several payloads, the bubble ids get an `_<index>`
  // suffix. `parser` is a dependency so a new parser is not masked by the memo.
  const parsedMessages = React.useMemo(() => {
    const list = [];
    messages.forEach(agentMsg => {
      const rawParsedMsg = parser ? parser(agentMsg.message) : agentMsg.message;
      const bubbleMsgs = toArray(rawParsedMsg);
      bubbleMsgs.forEach((bubbleMsg, bubbleMsgIndex) => {
        let key = agentMsg.id;
        if (bubbleMsgs.length > 1) {
          key = `${key}_${bubbleMsgIndex}`;
        }
        list.push({ id: key, message: bubbleMsg, status: agentMsg.status });
      });
    });
    return list;
  }, [messages, parser]);

  // ============================ Request =============================
  // Drop messages still in 'loading' state and unwrap their payloads.
  // Messages in 'error' / 'abort' state are kept.
  const getFilteredMessages = msgs =>
    msgs.filter(info => info.status !== 'loading').map(info => info.message);

  // Runs on every render: hands the provider a getter over the current store
  // contents and records the provider's request handle for `abort`.
  provider?.injectGetMessages(() => {
    return getFilteredMessages(getMessages());
  });
  requestHandlerRef.current = provider?.request;

  // For agent to use. Filters out messages in 'loading' state only.
  const getRequestMessages = () => getFilteredMessages(getMessages());

  /**
   * Start a request through `provider`.
   *
   * - With `opts.reload`, the existing message `opts.updatingId` is reused. It
   *   is switched to the placeholder (if configured) and later overwritten by
   *   the response.
   * - Otherwise, the local message(s) from `provider.transformLocalMessage`
   *   are appended, followed by a 'loading' placeholder when
   *   `requestPlaceholder` is set.
   */
  const innerOnRequest = (requestParams, opts) => {
    if (!provider) {
      return;
    }
    const { updatingId, reload } = opts || {};
    // Id of the message that stands for the pending response. In the append
    // branch it is assigned inside the setMessages updater.
    // NOTE(review): this relies on ./store running updaters synchronously.
    let loadingMsgId = null;

    // Created in both branches (even though reload ignores them) so that the
    // id sequence stays the same as before.
    const localMessage = provider.transformLocalMessage(requestParams);
    const localMessages = toArray(localMessage).map(message =>
      createMessage(message, 'local', opts?.extra)
    );

    // Evaluate `requestPlaceholder` (a value or a factory) against `nextMessages`.
    const resolvePlaceholder = nextMessages =>
      typeof requestPlaceholder === 'function'
        ? requestPlaceholder(requestParams, { messages: getFilteredMessages(nextMessages) })
        : requestPlaceholder;

    if (reload) {
      loadingMsgId = updatingId;
      setMessages(ori => {
        const nextMessages = [...ori];
        if (requestPlaceholder) {
          const placeholderMsg = resolvePlaceholder(nextMessages);
          // The array is copied, but the matching message object is mutated
          // in place (shared with `ori`).
          nextMessages.forEach(info => {
            if (info.id === updatingId) {
              info.status = 'loading';
              info.message = placeholderMsg;
              if (opts?.extra) {
                info.extra = opts?.extra;
              }
            }
          });
        }
        return nextMessages;
      });
    } else {
      // Append the local message(s), then the placeholder message if configured.
      setMessages(ori => {
        let nextMessages = [...ori, ...localMessages];
        if (requestPlaceholder) {
          const placeholderMsg = resolvePlaceholder(nextMessages);
          const loadingMsg = createMessage(placeholderMsg, 'loading');
          loadingMsgId = loadingMsg.id;
          nextMessages = [...nextMessages, loadingMsg];
        }
        return nextMessages;
      });
    }

    // Id of the message that receives response data. Set on the first update.
    let updatingMsgId = null;

    // Apply one provider update (chunk or final chunks) to the response message.
    const updateMessage = (status, chunk, chunks, responseHeaders) => {
      let msg = getMessages().find(info => info.id === updatingMsgId);
      if (!msg) {
        if (reload && updatingId) {
          // First update of a reload: write into the message being reloaded.
          msg = getMessages().find(info => info.id === updatingId);
          if (msg) {
            msg.status = status;
            msg.message = provider.transformMessage({ chunk, status, chunks, responseHeaders });
            setMessages(ori => {
              return [...ori];
            });
            updatingMsgId = msg.id;
          }
        } else {
          // Create the response message, replacing the placeholder (if any).
          const transformData = provider.transformMessage({ chunk, status, chunks, responseHeaders });
          msg = createMessage(transformData, status);
          setMessages(ori => {
            const oriWithoutPending = ori.filter(info => info.id !== loadingMsgId);
            return [...oriWithoutPending, msg];
          });
          updatingMsgId = msg.id;
        }
      } else {
        // Update the existing response message immutably.
        setMessages(ori => {
          return ori.map(info => {
            if (info.id === updatingMsgId) {
              const transformData = provider.transformMessage({
                originMessage: info.message,
                chunk,
                chunks,
                status,
                responseHeaders
              });
              return { ...info, message: transformData, status };
            }
            return info;
          });
        });
      }
      return msg;
    };

    provider.injectRequest({
      onUpdate: (chunk, headers) => {
        updateMessage('updating', chunk, [], headers);
      },
      onSuccess: (chunks, headers) => {
        setIsRequesting(false);
        if (conversationKey) {
          IsRequestingMap.delete(conversationKey);
        }
        updateMessage('success', undefined, chunks, headers);
      },
      onError: async error => {
        setIsRequesting(false);
        if (conversationKey) {
          IsRequestingMap.delete(conversationKey);
        }
        // `error?.name` guards against a missing or non-Error value. An
        // unguarded `.name` would throw inside this async handler and leave
        // the pending message stuck in 'loading'.
        const failStatus = error?.name === 'AbortError' ? 'abort' : 'error';
        if (requestFallback) {
          let fallbackMsg;
          if (typeof requestFallback === 'function') {
            const messages = getRequestMessages();
            const msg = getMessages().find(info => info.id === loadingMsgId || info.id === updatingMsgId);
            fallbackMsg = await requestFallback(requestParams, { error, messageInfo: msg, messages });
          } else {
            fallbackMsg = requestFallback;
          }
          // Replace the pending / partial response message with the fallback.
          setMessages(ori => [
            ...ori.filter(info => info.id !== loadingMsgId && info.id !== updatingMsgId),
            createMessage(fallbackMsg, failStatus)
          ]);
        } else {
          // No fallback configured: keep the pending / partial message(s) and
          // mark them as 'error' or 'abort'.
          setMessages(ori => {
            return ori.map(info => {
              if (info.id === loadingMsgId || info.id === updatingMsgId) {
                return { ...info, status: failStatus };
              }
              return info;
            });
          });
        }
      }
    });

    setIsRequesting(true);
    if (conversationKey) {
      IsRequestingMap.set(conversationKey, true);
    }
    provider.request.run(provider.transformParams(requestParams, provider.request.options));
  };

  /**
   * Send a new request. Stable across renders (wrapped in `useEvent`).
   * @throws {Error} when no `provider` is configured.
   */
  const onRequest = useEvent((requestParams, opts) => {
    if (!provider) {
      throw new Error('provider is required');
    }
    innerOnRequest(requestParams, opts);
  });

  /**
   * Re-run a request, writing the response into the existing message `id`.
   * @throws {Error} when no `provider` is configured or `id` is not in the store.
   */
  const onReload = (id, requestParams, opts) => {
    if (!provider) {
      throw new Error('provider is required');
    }
    if (!id || !getMessages().find(info => info.id === id)) {
      throw new Error(`message [${id}] is not found`);
    }
    innerOnRequest(requestParams, {
      updatingId: id,
      reload: true,
      extra: opts?.extra
    });
  };

  return {
    onRequest,
    messages,
    parsedMessages,
    setMessages,
    setMessage,
    abort: () => {
      if (!provider) {
        throw new Error('provider is required');
      }
      requestHandlerRef.current?.abort();
    },
    isRequesting: conversationKey ? IsRequestingMap.get(conversationKey) || false : isRequesting,
    onReload
  };
}