UNPKG

@ant-design/pro-chat

Version:
613 lines (600 loc) 26.6 kB
// Babel runtime helpers from the compiled build. The code below uses native
// async/await and spread syntax and no longer references them; they are kept
// so the module's import list is unchanged.
import _slicedToArray from "@babel/runtime/helpers/esm/slicedToArray";
import _typeof from "@babel/runtime/helpers/esm/typeof";
import _objectSpread from "@babel/runtime/helpers/esm/objectSpread2";
import _regeneratorRuntime from "@babel/runtime/helpers/esm/regeneratorRuntime";
import _asyncToGenerator from "@babel/runtime/helpers/esm/asyncToGenerator";
import { merge, template } from 'lodash-es';
import { LOADING_FLAT } from "../const/message";
import { fetchSSE } from "../utils/fetch";
import { getSlicedMessagesWithConfig, isFunctionMessage } from "../utils/message";
import { setNamespace } from "../utils/storeDebug";
import { nanoid } from "../utils/uuid";
import { initialModelConfig } from "./initialState";
import { messagesReducer } from "./reducers/message";
import { chatSelectors } from "./selectors";

const t = setNamespace('chat/message');

/**
 * Renders `inputTemplate` into the content of every `user` message.
 *
 * The template uses `{{ ... }}` interpolation and receives the original
 * message content as `text` (e.g. "Answer briefly: {{text}}").
 * A message whose rendering throws is logged and passed through unchanged;
 * non-user messages are returned as-is.
 *
 * @param {Array} messages - messages to transform.
 * @param {string} inputTemplate - lodash template source.
 * @returns {Array} a new array of messages.
 */
const compileUserMessages = (messages, inputTemplate) => {
  const compiler = template(inputTemplate, { interpolate: /{{([\S\s]+?)}}/g });
  return messages.map((message) => {
    if (message.role !== 'user') return message;
    try {
      return { ...message, content: compiler({ text: message.content }) };
    } catch (error) {
      console.error(error);
      return message;
    }
  });
};

/**
 * Chat actions for the pro-chat store.
 *
 * @param {Function} set - store setter, called here as `set(partial, false, actionName)`.
 * @param {Function} get - returns the current store state (data, callbacks and these actions).
 * @returns {object} map of chat actions.
 */
export const chatAction = (set, get) => ({
  /**
   * Clears the chat history. If an `onResetMessage` callback is configured it
   * is awaited first, and only then are the messages reset.
   */
  async clearMessage() {
    const { dispatchMessage, onResetMessage } = get();
    if (onResetMessage) {
      await onResetMessage();
    }
    dispatchMessage({ type: 'resetMessages' });
    // TODO: need callback after reset
  },

  /** Removes the message with the given id. */
  deleteMessage(id) {
    get().dispatchMessage({ id, type: 'deleteMessage' });
  },

  /**
   * Replaces the content of message `id`.
   *
   * The previous build also called `get().updateMessageContent` here, i.e.
   * this very action, which recursed until the stack overflowed.
   * `dispatchMessage` already notifies `onChatsChange`.
   */
  async updateMessageContent(id, content) {
    const { dispatchMessage } = get();
    dispatchMessage({ id, key: 'content', type: 'updateMessage', value: content });
  },

  /**
   * Applies `payload` to the chats via `messagesReducer`, stores the result,
   * and passes the new chats to `onChatsChange` when one is configured.
   */
  dispatchMessage(payload) {
    const { chats, onChatsChange } = get();
    const nextChats = messagesReducer(chats, payload);
    set({ chats: nextChats }, false, t('dispatchMessage'));
    onChatsChange?.(nextChats);
  },

  /**
   * Streams an assistant reply for `messages` into the message `assistantId`.
   *
   * Preprocesses the history (truncation, input template, system role),
   * calls `defaultModelFetcher` through `fetchSSE`, and feeds each received
   * chunk into the smooth-output queue from `createSmoothMessage`. Chat
   * loading is switched off only after that queue has drained.
   *
   * @param {Array} messages - chat history to send.
   * @param {string} assistantId - id of the placeholder assistant message to fill.
   * @returns {Promise<{isFunctionCall: boolean}>} whether the streamed output
   *   was detected as a function call by `isFunctionMessage`.
   */
  async generateMessage(messages, assistantId) {
    const {
      dispatchMessage,
      toggleChatLoading,
      config,
      defaultModelFetcher,
      createSmoothMessage,
      deleteMessage,
      transformToChatMessage,
      onChatEnd,
      onChatStart,
      onChatGenerate,
    } = get();

    const abortController = toggleChatLoading(
      true,
      assistantId,
      t('generateMessage(start)', { assistantId, messages }),
    );

    // ========== Preprocess messages ========== //
    // 1. Truncate the history according to the config.
    const slicedMessages = getSlicedMessagesWithConfig(messages, config);
    // 2. Render the input template into user messages.
    const templatedMessages = config.inputTemplate
      ? compileUserMessages(slicedMessages, config.inputTemplate)
      : slicedMessages;
    // 3. Prepend the system role. A new array is built instead of calling
    //    `unshift`: without an input template, `templatedMessages` is whatever
    //    getSlicedMessagesWithConfig returned, which may be the caller's
    //    `messages` array, and mutating it could leak the system message into
    //    the chat history.
    const postMessages = config.systemRole
      ? [{ content: config.systemRole, role: 'system' }, ...templatedMessages]
      : templatedMessages;

    if (onChatStart) {
      onChatStart(postMessages);
    }

    const fetcher = () =>
      defaultModelFetcher(
        { messages: postMessages, model: config.model, ...config.params },
        { signal: abortController?.signal },
      );

    let output = '';
    let isFunctionCall = false;
    const { startAnimation, stopAnimation, outputQueue, mixRequestResponse } =
      createSmoothMessage(assistantId);

    await fetchSSE(fetcher, {
      signal: abortController?.signal,
      // On cancel, remove the message that is still in its loading state.
      onCancel() {
        deleteMessage(assistantId);
      },
      onErrorHandle(error) {
        dispatchMessage({ id: assistantId, key: 'error', type: 'updateMessage', value: error });
      },
      async onAbort() {
        if (onChatEnd) {
          onChatEnd(assistantId, 'abort');
        }
        stopAnimation();
      },
      async onFinish(type) {
        stopAnimation();
        if (onChatEnd) {
          onChatEnd(assistantId, type);
        }
        // Flush whatever is still queued, 15 entries per tick instead of the default 2.
        if (outputQueue.length > 0 && !isFunctionCall) {
          await startAnimation(15);
        }
      },
      async onMessageHandle(text, response) {
        output += text;

        // Remember the fields of an object-shaped response that carries
        // `content`, so the renderer dispatches them together with the text.
        if (response && typeof response === 'object' && 'content' in response) {
          for (const [key, value] of Object.entries(response)) {
            mixRequestResponse[key] = value;
          }
        }

        // startAnimation() resolves immediately when the animation is already
        // running, so calling it on every chunk is safe; it is not awaited.
        if (!isFunctionCall) {
          void startAnimation();
        }

        if (abortController?.signal.aborted) return;

        const chunk = transformToChatMessage ? await transformToChatMessage(text, output) : text;
        if (onChatGenerate) {
          onChatGenerate(chunk);
        }
        outputQueue.push(chunk);

        // TODO: need a function call judge callback
        if (isFunctionMessage(output)) {
          isFunctionCall = true;
        }
      },
    });

    // Keep chat loading on until the output queue has been fully rendered,
    // re-checking every 30 ms.
    let timeoutId;
    const checkAndToggleChatLoading = () => {
      clearTimeout(timeoutId);
      const drained =
        outputQueue === undefined || outputQueue.length === 0 || outputQueue.toString() === '';
      if (drained) {
        toggleChatLoading(false, undefined, t('generateMessage(end)'));
      } else {
        timeoutId = setTimeout(checkAndToggleChatLoading, 30);
      }
    };
    checkAndToggleChatLoading();

    return { isFunctionCall };
  },

  /**
   * Adds a loading assistant message replying to `userMessageId`, tags it
   * with the configured model, and generates its content from `messages`.
   */
  async realFetchAIResponse(messages, userMessageId) {
    const { dispatchMessage, generateMessage, config, getMessageId } = get();

    // Add an empty message to hold the AI response. Keep this order: if it
    // were reversed, `messages` would contain the newly added AI message.
    const mid = await getMessageId(messages, userMessageId);
    dispatchMessage({
      id: mid,
      message: LOADING_FLAT,
      parentId: userMessageId,
      role: 'assistant',
      type: 'addMessage',
    });

    // TODO: need a callback before generate message
    // Record which model produced this message in its `fromModel` extra.
    // TODO: model info is needed here
    dispatchMessage({ id: mid, key: 'fromModel', type: 'updateMessageExtra', value: config.model });

    await generateMessage(messages, mid);
  },

  /**
   * Regenerates the reply for the message `messageId`.
   *
   * For a `user` or `function` message the history up to and including it is
   * used. For an `assistant` message the history up to its parent is used,
   * or up to the message itself when the parent cannot be found. The reply is
   * generated for the latest user message in that history. Other roles and
   * unknown ids are ignored.
   */
  async resendMessage(messageId) {
    // 1. Build the relevant history.
    const chats = chatSelectors.currentChats(get());
    const currentIndex = chats.findIndex((c) => c.id === messageId);
    if (currentIndex < 0) return;

    const currentMessage = chats[currentIndex];
    let contextMessages = [];

    switch (currentMessage.role) {
      case 'function':
      case 'user': {
        contextMessages = chats.slice(0, currentIndex + 1);
        break;
      }
      case 'assistant': {
        // The message came from the AI, so find the user message it answers.
        const userId = currentMessage.parentId;
        const userIndex = chats.findIndex((c) => c.id === userId);
        // Without a resolvable parent, behave like the user/function case.
        contextMessages = chats.slice(0, userIndex < 0 ? currentIndex + 1 : userIndex + 1);
        break;
      }
    }

    if (contextMessages.length <= 0) return;

    const { realFetchAIResponse } = get();
    const latestMsg = contextMessages.filter((s) => s.role === 'user').at(-1);
    if (!latestMsg) return;

    await realFetchAIResponse(contextMessages, latestMsg.id);
  },

  /**
   * Appends a user message with content `message` (ignored when falsy), then
   * requests an AI response for the updated history.
   */
  async sendMessage(message) {
    const { dispatchMessage, realFetchAIResponse } = get();
    if (!message) return;

    const userId = nanoid();
    dispatchMessage({ id: userId, message, role: 'user', type: 'addMessage' });

    // Todo: need a callback before send message
    // Read the history again so it includes the message just added.
    const messages = chatSelectors.currentChats(get());
    await realFetchAIResponse(messages, userId);
  },

  /**
   * Stops the running generation. If the last chat is still the loading
   * placeholder, its content is cleared first. Then the current
   * AbortController, if any, is aborted and loading is turned off.
   */
  stopGenerateMessage() {
    const { abortController, toggleChatLoading, chatLoadingId, chats, dispatchMessage } = get();

    if (chats && chats.length > 0) {
      const lastChat = chats[chats.length - 1];
      if (lastChat.content === LOADING_FLAT) {
        dispatchMessage({ id: chatLoadingId, key: 'content', type: 'updateMessage', value: '' });
      }
    }

    if (!abortController) return;
    abortController.abort();
    toggleChatLoading(false);
  },

  /**
   * Turns chat loading on or off.
   *
   * Turning it on creates a fresh AbortController, stores it with
   * `chatLoadingId = id`, and returns it. Turning it off clears both and
   * returns undefined. `action` is forwarded as the `set` action name.
   */
  toggleChatLoading(loading, id, action) {
    if (!loading) {
      set({ abortController: undefined, chatLoadingId: undefined }, false, action);
      return undefined;
    }
    const abortController = new AbortController();
    set({ abortController, chatLoadingId: id }, false, action);
    return abortController;
  },

  /**
   * Default model request.
   *
   * `params` is deep-merged over `initialModelConfig` (plus `stream: true`).
   * If the store's `request` is a function it is called as
   * `request(messages, payload, signal)`. Otherwise the payload is POSTed as
   * JSON to `request` (when it is a string) or to `/api/openai/chat`.
   */
  defaultModelFetcher(params, options) {
    const { request } = get();
    const payload = merge(
      { model: initialModelConfig.model, stream: true, ...initialModelConfig.params },
      params,
    );

    if (typeof request === 'function') {
      return request(payload.messages, payload, options?.signal);
    }

    const url = typeof request === 'string' ? request : '/api/openai/chat';
    return fetch(url, {
      body: JSON.stringify(payload),
      headers: { 'Content-Type': 'application/json' },
      method: 'POST',
      signal: options?.signal,
    });
  },

  /**
   * Resolves the id for a new message. It delegates to
   * `genMessageId(messages, parentId)` when that is a function, and
   * otherwise returns `nanoid()`.
   */
  async getMessageId(messages, parentId) {
    const { genMessageId } = get();
    if (typeof genMessageId === 'function') {
      return genMessageId(messages, parentId);
    }
    return nanoid();
  },

  /**
   * Creates a typewriter-style renderer for the message `id`.
   *
   * Entries pushed onto `outputQueue` are appended to an internal buffer,
   * `speed` entries per tick every 16 ms. The accumulated buffer is
   * dispatched as the message content. If `mixRequestResponse` has a
   * `content` key, its fields are dispatched along with the content.
   *
   * @param {string} id - id of the message being rendered.
   * @returns {{startAnimation: Function, stopAnimation: Function, outputQueue: Array,
   *   isAnimationActive: boolean, mixRequestResponse: object}}
   */
  createSmoothMessage(id) {
    const { dispatchMessage } = get();

    let buffer = '';
    // why use queue: https://shareg.pt/GLBrjpK
    const outputQueue = [];
    const mixRequestResponse = {};
    let animationTimeoutId = null;
    let isAnimationActive = false;

    // Stops the animation; entries still queued stay in `outputQueue`.
    const stopAnimation = () => {
      isAnimationActive = false;
      if (animationTimeoutId !== null) {
        clearTimeout(animationTimeoutId);
        animationTimeoutId = null;
      }
    };

    // Renders the queue smoothly and resolves once it is empty. It resolves
    // immediately if an animation is already running.
    // NOTE: stopAnimation() cancels the pending tick, so the promise of an
    // interrupted run does not settle.
    const startAnimation = (speed = 2) =>
      new Promise((resolve) => {
        if (isAnimationActive) {
          resolve();
          return;
        }
        isAnimationActive = true;

        const updateText = () => {
          // The animation is no longer active, so stop updating the text.
          if (!isAnimationActive) {
            clearTimeout(animationTimeoutId);
            animationTimeoutId = null;
            resolve();
            return;
          }

          if (outputQueue.length === 0) {
            // Everything has been shown; reset the animation state.
            isAnimationActive = false;
            animationTimeoutId = null;
            resolve();
            return;
          }

          // Take up to `speed` queued entries.
          buffer += outputQueue.splice(0, speed).join('');

          const contentUpdate = { id, key: 'content', type: 'updateMessage', value: buffer };
          dispatchMessage(
            'content' in mixRequestResponse
              ? { ...mixRequestResponse, ...contentUpdate }
              : contentUpdate,
          );

          // A 16 ms delay between ticks produces the typewriter effect.
          animationTimeoutId = setTimeout(updateText, 16);
        };

        updateText();
      });

    return {
      startAnimation,
      stopAnimation,
      outputQueue,
      // Getter so callers see the current state rather than the value at creation time.
      get isAnimationActive() {
        return isAnimationActive;
      },
      mixRequestResponse,
    };
  },

  /** Returns the id of the message currently being generated, if any. */
  getChatLoadingId() {
    const { chatLoadingId } = get();
    return chatLoadingId;
  },
});