UNPKG

@axflow/models

Version:

Zero-dependency, modular SDK for building robust natural language applications

410 lines (407 loc) 12.7 kB
"use strict"; var __defProp = Object.defineProperty; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __getOwnPropNames = Object.getOwnPropertyNames; var __hasOwnProp = Object.prototype.hasOwnProperty; var __export = (target, all) => { for (var name in all) __defProp(target, name, { get: all[name], enumerable: true }); }; var __copyProps = (to, from, except, desc) => { if (from && typeof from === "object" || typeof from === "function") { for (let key of __getOwnPropNames(from)) if (!__hasOwnProp.call(to, key) && key !== except) __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable }); } return to; }; var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod); // src/react/index.ts var react_exports = {}; __export(react_exports, { useChat: () => useChat }); module.exports = __toCommonJS(react_exports); // src/react/use-chat.ts var import_react = require("react"); var import_shared = require("@axflow/models/shared"); function uuidv4() { return crypto.randomUUID(); } var mergeToolCallIntoMessage = (partialChunkToolCall, msg, content) => { const msgContainsChunkTool = (msg.toolCalls || []).some( (tool) => tool.index === partialChunkToolCall.index ); if (!msgContainsChunkTool) { return { ...msg, content, toolCalls: [...msg.toolCalls || [], (0, import_shared.toolCallWithDefaults)(partialChunkToolCall)] }; } else { return { ...msg, toolCalls: msg.toolCalls.map((toolCall) => { if (toolCall.index !== partialChunkToolCall.index) { return toolCall; } else { return { ...toolCall, ...partialChunkToolCall, function: { ...toolCall.function, ...partialChunkToolCall.function, arguments: (toolCall.function.arguments || "") + (partialChunkToolCall.function?.arguments || "") } }; } }) }; } }; async function handleStreamingResponse(response, messagesRef, setMessages, accessor, functionCallAccessor, toolCallsAccessor, onNewMessage, createMessageId) { const responseBody = response.body; if (responseBody === null) { throw new import_shared.HttpError( "Expected response.body to be a stream but response.body is null", response ); } let id = null; for await (const chunk of (0, import_shared.StreamToIterable)(import_shared.NdJsonStream.decode(responseBody))) { let messages = messagesRef.current; if (chunk.type !== "chunk") { if (!id) { id = createMessageId(); messages = messages.concat({ id, role: "assistant", data: [chunk.value], content: "", created: Date.now() }); } else { messages = messages.map((msg) => { return msg.id !== id ? msg : { ...msg, data: (msg.data || []).concat(chunk.value) }; }); } } else { const chunkContent = accessor(chunk.value); const chunkFunctionCall = functionCallAccessor(chunk.value); const chunkToolCalls = toolCallsAccessor(chunk.value); if (!id) { id = createMessageId(); const message = { id, role: "assistant", content: chunkContent ?? "", created: Date.now() }; if (chunkFunctionCall) { message.functionCall = { name: chunkFunctionCall.name ?? "", arguments: chunkFunctionCall.arguments ?? "" }; } if (chunkToolCalls) { message.toolCalls = chunkToolCalls.map(import_shared.toolCallWithDefaults); } messages = messages.concat(message); } else { messages = messages.map((msg) => { if (msg.id !== id) { return msg; } const content = msg.content + (chunkContent ?? ""); if (chunkFunctionCall) { const name = msg.functionCall.name + (chunkFunctionCall.name ?? ""); const args = msg.functionCall.arguments + (chunkFunctionCall.arguments ?? ""); return { ...msg, content, functionCall: { name, arguments: args } }; } else if (chunkToolCalls) { for (const chunkToolCall of chunkToolCalls) { msg = mergeToolCallIntoMessage(chunkToolCall, msg, content); } return msg; } else { return { ...msg, content }; } }); } } setMessages(messages); } const newMessage = messagesRef.current.find((msg) => msg.id === id); onNewMessage(newMessage); } async function handleJsonResponse(response, messagesRef, setMessages, accessor, functionCallAccessor, toolCallsAccessor, onNewMessage, createMessageId) { const responseBody = await response.json(); const content = accessor(responseBody); const functionCall = functionCallAccessor(responseBody); const toolCalls = toolCallsAccessor(responseBody); const newMessage = { id: createMessageId(), role: "assistant", content: content ?? "", created: Date.now() }; if (functionCall) { newMessage.functionCall = { name: functionCall.name ?? "", arguments: functionCall.arguments ?? "" }; } if (toolCalls) { newMessage.toolCalls = toolCalls.map(import_shared.toolCallWithDefaults); } const messages = messagesRef.current.concat(newMessage); setMessages(messages); onNewMessage(newMessage); } async function request(prepare, messagesRef, setMessages, url, headers, accessor, functionCallAccessor, toolCallsAccessor, loadingRef, setLoading, setError, onError, onNewMessage, onSuccess, createMessageId) { if (loadingRef.current) { return; } setLoading(true); setError(null); const requestBody = prepare(); let response; try { response = await (0, import_shared.POST)(url, { headers: { ...headers, "content-type": "application/json; charset=utf-8" }, body: JSON.stringify(requestBody) }); const contentType = response.headers.get("content-type") || ""; const isStreaming = contentType.toLowerCase() === "application/x-ndjson; charset=utf-8"; const handler = isStreaming ? handleStreamingResponse : handleJsonResponse; await handler( response, messagesRef, setMessages, accessor, functionCallAccessor, toolCallsAccessor, onNewMessage, createMessageId ); onSuccess(); } catch (error) { setError(error); onError(error); } finally { setLoading(false); } } async function stableAppend(message, messagesRef, setMessages, url, headers, body, accessor, functionCallAccessor, toolCallsAccessor, loadingRef, setLoading, setError, onError, onNewMessage, setFunctions, setTools, createMessageId) { function prepare() { const history = messagesRef.current; const requestBody = typeof body === "function" ? body(message, history) : { ...body, messages: history.concat(message) }; setMessages(history.concat(message)); onNewMessage(message); return requestBody; } return request( prepare, messagesRef, setMessages, url, headers, accessor, functionCallAccessor, toolCallsAccessor, loadingRef, setLoading, setError, onError, onNewMessage, () => { setFunctions([]); setTools([]); }, // Clear functions after each request (similar to clearing user input) createMessageId ); } async function stableReload(messagesRef, setMessages, url, headers, body, accessor, functionCallAccessor, toolCallsAccessor, loadingRef, setLoading, setError, onError, onNewMessage, createMessageId) { function prepare() { const messages = messagesRef.current; const history = []; let lastMessage = null; for (let i = messages.length - 1; i >= 0; i--) { const msg = messages[i]; const role = msg.role; if (lastMessage === null && (role === "user" || role === "system")) { lastMessage = msg; } else if (lastMessage !== null) { history.unshift(msg); } } if (lastMessage === null) { throw new Error("Cannot reload empty conversation"); } const requestBody = typeof body === "function" ? body(lastMessage, history) : { ...body, messages: history.concat(lastMessage) }; if (messages[messages.length - 1].id !== lastMessage.id) { setMessages(history.concat(lastMessage)); } return requestBody; } return request( prepare, messagesRef, setMessages, url, headers, accessor, functionCallAccessor, toolCallsAccessor, loadingRef, setLoading, setError, onError, onNewMessage, () => { }, createMessageId ); } var DEFAULT_URL = "/api/chat"; var DEFAULT_CREATE_MESSAGE_ID = uuidv4; var DEFAULT_ACCESSOR = (value) => { return typeof value === "string" ? value : void 0; }; var DEFAULT_FUNCTION_CALL_ACCESSOR = (_value) => { return void 0; }; var DEFAULT_TOOL_CALLS_ACCESSOR = (_value) => { return void 0; }; var DEFAULT_BODY = (message, history) => ({ messages: [...history, message] }); var DEFAULT_HEADERS = {}; var DEFAULT_ON_ERROR = (error) => { console.error(error); }; var DEFAULT_ON_MESSAGES_CHANGE = (_messages) => { }; var DEFAULT_ON_NEW_MESSAGE = (_message) => { }; function useChat(options) { options ??= {}; const initialInput = options.initialInput ?? ""; const [input, setInput] = (0, import_react.useState)(initialInput); const initialMessages = options.initialMessages ?? []; const [messages, _setMessages] = (0, import_react.useState)(initialMessages); const messagesRef = (0, import_react.useRef)(initialMessages); const initialFunctions = options.initialFunctions ?? []; const [functions, setFunctions] = (0, import_react.useState)(initialFunctions); const initialTools = options.initialTools ?? []; const [tools, setTools] = (0, import_react.useState)(initialTools); const [loading, _setLoading] = (0, import_react.useState)(false); const loadingRef = (0, import_react.useRef)(false); const [error, setError] = (0, import_react.useState)(null); const url = options.url ?? DEFAULT_URL; const createMessageId = options.createMessageId ?? DEFAULT_CREATE_MESSAGE_ID; const accessor = options.accessor ?? DEFAULT_ACCESSOR; const functionCallAccessor = options.functionCallAccessor ?? DEFAULT_FUNCTION_CALL_ACCESSOR; const toolCallsAccessor = options.toolCallsAccessor ?? DEFAULT_TOOL_CALLS_ACCESSOR; const body = options.body ?? DEFAULT_BODY; const headers = options.headers ?? DEFAULT_HEADERS; const onError = options.onError ?? DEFAULT_ON_ERROR; const onMessagesChange = options.onMessagesChange ?? DEFAULT_ON_MESSAGES_CHANGE; const onNewMessage = options.onNewMessage ?? DEFAULT_ON_NEW_MESSAGE; const setMessages = (0, import_react.useCallback)( (messages2) => { _setMessages(messages2); messagesRef.current = messages2; onMessagesChange(messages2); }, [messagesRef, _setMessages, onMessagesChange] ); const setLoading = (0, import_react.useCallback)( (loading2) => { _setLoading(loading2); loadingRef.current = loading2; }, [loadingRef, _setLoading] ); function onChange(e) { if (typeof e === "string") { setInput(e); } else { setInput(e.target.value); } } function onSubmit(e) { if (e) { e.preventDefault(); } const newMessage = { id: createMessageId(), role: "user", content: input, created: Date.now() }; if (functions.length > 0) { newMessage.functions = functions; } if (tools.length > 0) { newMessage.tools = tools; } stableAppend( newMessage, messagesRef, setMessages, url, headers, body, accessor, functionCallAccessor, toolCallsAccessor, loadingRef, setLoading, setError, onError, onNewMessage, setFunctions, setTools, createMessageId ); setInput(""); } function reload() { stableReload( messagesRef, setMessages, url, headers, body, accessor, functionCallAccessor, toolCallsAccessor, loadingRef, setLoading, setError, onError, onNewMessage, createMessageId ); } return { input, setInput, messages, setMessages, functions, setFunctions, setTools, loading, error, onChange, onSubmit, reload }; } // Annotate the CommonJS export names for ESM import in node: 0 && (module.exports = { useChat });