UNPKG

@asterai/client

Version:
246 lines (245 loc) 7.88 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.QueryResponse = exports.AsteraiAgent = void 0;
const protobufjs_1 = require("protobufjs");
const buffer_1 = require("buffer");
const eventsource_parser_1 = require("eventsource-parser");
const config_1 = require("./config");
/** SSE `data` prefix that marks a streamed LLM token. */
const LLM_TOKEN_PREFIX = "llm-token: ";
/** SSE `data` prefix that marks a protobuf-encoded plugin output. */
const PLUGIN_OUTPUT_PREFIX = "plugin-output: ";
/**
 * Client for querying a single asterai app (agent).
 */
class AsteraiAgent {
    /**
     * @param {object} args
     * @param {string} args.queryKey - Sent as a Bearer token with every query.
     * @param {string} args.appId - ID of the app to query.
     * @param {string[]} [args.appProtos] - Protobuf schema sources. They are parsed
     *   once here and used to decode plugin outputs in query responses.
     * @param {string} [args.apiBaseUrl] - Overrides `DEFAULT_API_BASE_URL`.
     */
    constructor(args) {
        this.apiBaseUrl = config_1.DEFAULT_API_BASE_URL;
        this.protos = [];
        this.queryKey = args.queryKey;
        this.appId = args.appId;
        if (args.appProtos) {
            for (const proto of args.appProtos) {
                this.protos.push((0, protobufjs_1.parse)(proto).root);
            }
        }
        if (args.apiBaseUrl) {
            this.apiBaseUrl = args.apiBaseUrl;
        }
    }
    /**
     * Create a new instance of this agent with a different query key.
     *
     * Read more about query keys here:
     * https://docs.asterai.io/querying_an_app.html#query-keys
     *
     * @param {string} queryKey
     * @returns {AsteraiAgent}
     */
    withQueryKey(queryKey) {
        const agent = new AsteraiAgent({
            queryKey,
            appId: this.appId,
            appProtos: [],
            apiBaseUrl: this.apiBaseUrl,
        });
        // Reuse the already-parsed protos instead of re-parsing the sources.
        agent.protos = this.protos;
        return agent;
    }
    /**
     * Sends a query to the app's SSE endpoint.
     *
     * @param {object} args
     * @param {string} args.query - Sent verbatim as the request body.
     * @param {string} [args.conversationId] - Sent as the `conversation_id`
     *   query parameter when provided.
     * @returns {Promise<QueryResponse>} Resolves once `fetch` resolves. The
     *   streamed body is consumed through the returned object's callbacks.
     */
    async query(args) {
        const abortController = new AbortController();
        let url = `${this.apiBaseUrl}/app/${this.appId}/query/sse`;
        if (args.conversationId) {
            url += `?conversation_id=${encodeURIComponent(args.conversationId)}`;
        }
        const response = await fetch(url, {
            method: "POST",
            headers: {
                "Authorization": `Bearer ${this.queryKey}`,
                "Accept": "text/event-stream",
            },
            body: args.query,
            signal: abortController.signal,
        });
        return new QueryResponse(response, abortController, this.protos);
    }
    /**
     * Fetches the agent's summary markdown file, containing
     * a natural language description of the functions and
     * plugins available in the agent.
     *
     * This is useful context for agent-to-agent communication.
     *
     * @returns {Promise<string>} The raw response body text.
     */
    async fetchSummary() {
        const url = `${this.apiBaseUrl}/app/${this.appId}/summary.md`;
        const response = await fetch(url, {
            method: "GET",
        });
        return response.text();
    }
}
exports.AsteraiAgent = AsteraiAgent;
/**
 * The response object for an agent's query.
 *
 * Parses the SSE body as it streams in and dispatches LLM tokens, decoded
 * plugin outputs and a single end event to the registered callbacks.
 * The end state is `{ reason: "finished" }` when the stream closes normally.
 * Otherwise it is `{ reason: "aborted" }`, with an `error` property when the
 * stream failed.
 */
class QueryResponse {
    /**
     * @param {Response} response - The fetch response whose body is streamed.
     * @param {AbortController} abortController - Controls the underlying request.
     * @param {Array} protos - Parsed protobuf roots used to decode plugin outputs.
     */
    constructor(response, abortController, protos) {
        this.isActive = true;
        this.onTokenCallbacks = [];
        this.onPluginOutputCallbacks = [];
        this.onEndCallbacks = [];
        // Final state passed to `callOnEnd`. It is kept so that `onEnd`
        // callbacks registered after the end can still be notified.
        this.endState = undefined;
        this.response = response;
        this.abortController = abortController;
        this.protos = protos;
        // Report setup/stream failures through `onEnd`. Rethrowing here would
        // only produce an unhandled rejection, and `text()` would never settle.
        this.setupResponse().catch((error) => {
            this.endWithError(error);
        });
    }
    /**
     * Starts consuming the response body and feeding it to the SSE parser.
     * Rejects if the body is missing or the stream fails.
     */
    async setupResponse() {
        const parser = (0, eventsource_parser_1.createParser)((event) => {
            if (event.type !== "event") {
                return;
            }
            if (event.data.startsWith(LLM_TOKEN_PREFIX)) {
                this.callOnToken(event.data.substring(LLM_TOKEN_PREFIX.length));
            }
            else if (event.data.startsWith(PLUGIN_OUTPUT_PREFIX) && this.protos.length > 0) {
                // Plugin outputs are ignored when no protos were supplied,
                // since they could not be decoded.
                const rawOutput = event.data.substring(PLUGIN_OUTPUT_PREFIX.length);
                this.callPluginOutput(this.decodePluginOutput(rawOutput));
            }
        });
        // Each response gets its own decoder. `stream: true` carries incomplete
        // multi-byte UTF-8 sequences over to the next chunk instead of
        // replacing them with U+FFFD.
        const decoder = new TextDecoder();
        const feed = (chunk) => {
            parser.feed(decoder.decode(chunk, { stream: true }));
        };
        const flush = () => {
            const rest = decoder.decode();
            if (rest) {
                parser.feed(rest);
            }
        };
        const body = this.response.body;
        if (!body) {
            throw new Error("Response body is null");
        }
        if (!body.pipeTo) {
            // In some Node environments, the type of `body` is `PassThrough`.
            const passThrough = body;
            passThrough.addListener("data", (chunk) => {
                // Errors thrown inside a stream listener would otherwise
                // escape as uncaught exceptions.
                try {
                    feed(chunk);
                }
                catch (error) {
                    this.endWithError(error);
                }
            });
            // "end" is the readable-side event, emitted once all data has been
            // consumed. "finish" is writable-side only and may fire earlier,
            // or never for a plain Readable.
            passThrough.addListener("end", () => {
                try {
                    flush();
                }
                catch (error) {
                    this.endWithError(error);
                    return;
                }
                this.callOnEnd({ reason: "finished" });
            });
            passThrough.addListener("error", (error) => {
                this.endWithError(error);
            });
        }
        else {
            // This is the expected and normal flow, where the type of `body`
            // is `ReadableStream`.
            await body.pipeTo(new WritableStream({
                write: (chunk) => {
                    feed(chunk);
                },
                close: () => {
                    flush();
                    this.callOnEnd({ reason: "finished" });
                },
                // Called with the source's error when the body stream errors
                // (including when the request is aborted).
                abort: (reason) => {
                    this.endWithError(reason);
                },
            }));
        }
    }
    /**
     * Registers a callback for each streamed LLM token.
     * It is ignored once the response has ended.
     */
    onToken(callback) {
        if (!this.isActive) {
            return;
        }
        this.onTokenCallbacks.push(callback);
    }
    /**
     * Registers a callback for each decoded plugin output `{ name, value }`.
     * It is ignored once the response has ended.
     */
    onPluginOutput(callback) {
        if (!this.isActive) {
            return;
        }
        this.onPluginOutputCallbacks.push(callback);
    }
    /**
     * Registers a callback for the end of the response. If the response has
     * already ended, the callback is invoked immediately (synchronously) with
     * the final state.
     */
    onEnd(callback) {
        if (!this.isActive) {
            if (this.endState) {
                callback(this.endState);
            }
            return;
        }
        this.onEndCallbacks.push(callback);
    }
    /**
     * Returns the full LLM response as text.
     *
     * Only tokens streamed after this call are included. Rejects with the
     * stream error, or with an `Error` if the query was aborted without one.
     *
     * @returns {Promise<string>}
     */
    text() {
        let response = "";
        this.onToken((token) => {
            response += token;
        });
        return new Promise((resolve, reject) => {
            this.onEnd((state) => {
                if (state.reason === "aborted") {
                    reject(state.error !== undefined
                        ? state.error
                        : new Error("asterai query was aborted"));
                    return;
                }
                resolve(response);
            });
        });
    }
    callOnToken(token) {
        if (!this.isActive) {
            return;
        }
        for (const callback of this.onTokenCallbacks) {
            callback(token);
        }
    }
    callPluginOutput(args) {
        if (!this.isActive) {
            return;
        }
        for (const callback of this.onPluginOutputCallbacks) {
            callback(args);
        }
    }
    /**
     * Marks the response as ended and notifies `onEnd` callbacks. Only the
     * first call has any effect.
     */
    callOnEnd(state) {
        if (!this.isActive) {
            return;
        }
        this.isActive = false;
        this.endState = state;
        for (const callback of this.onEndCallbacks) {
            callback(state);
        }
    }
    /**
     * Ends the response as `"aborted"` with `error` attached to the end state.
     * Also aborts the underlying request if it is still running.
     * Internal: used for stream and decoding failures.
     */
    endWithError(error) {
        this.callOnEnd({ reason: "aborted", error });
        if (!this.abortController.signal.aborted) {
            this.abortController.abort();
        }
    }
    /** Aborts the underlying request and ends the response as `"aborted"`. */
    abort() {
        if (this.abortController.signal.aborted) {
            return;
        }
        this.abortController.abort();
        this.callOnEnd({ reason: "aborted" });
    }
    /**
     * Decodes a plugin output of the form `<message type name>: <base64 protobuf>`.
     *
     * @param {string} rawOutput - Event data with the plugin-output prefix removed.
     * @returns {{ name: string, value: object }}
     * @throws {Error} If the message is malformed, its type is unknown, or
     *   decoding fails.
     */
    decodePluginOutput(rawOutput) {
        const separatorIndex = rawOutput.indexOf(":");
        if (separatorIndex === -1) {
            throw new Error("malformed asterai plugin message: missing ':' separator");
        }
        const messageTypeName = rawOutput.slice(0, separatorIndex);
        // Base64 never contains whitespace, so trimming tolerates the
        // separator being written with or without a following space.
        const encodedMessage = rawOutput.slice(separatorIndex + 1).trim();
        let messageType;
        for (const proto of this.protos) {
            try {
                messageType = proto.lookupType(messageTypeName);
            }
            catch (error) {
                // `lookupType` throws when the type is not in this root.
                // Keep searching the remaining roots.
                continue;
            }
            if (messageType) {
                break;
            }
        }
        if (!messageType) {
            throw new Error(`no message type found for ${messageTypeName}`);
        }
        try {
            const buffer = buffer_1.Buffer.from(encodedMessage, "base64");
            const decodedMessage = messageType.decode(buffer);
            const value = messageType.toObject(decodedMessage, {
                longs: String,
                enums: String,
                bytes: String,
            });
            return {
                name: messageTypeName,
                value,
            };
        }
        catch (error) {
            throw new Error(`failed to decode asterai plugin message: ${error}`, { cause: error });
        }
    }
}
exports.QueryResponse = QueryResponse;