UNPKG

@statelyai/agent

Version:

Stateful agents that make decisions based on finite-state machine models

622 lines (609 loc) 17.9 kB
"use strict"; var __create = Object.create; var __defProp = Object.defineProperty; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __getOwnPropNames = Object.getOwnPropertyNames; var __getProtoOf = Object.getPrototypeOf; var __hasOwnProp = Object.prototype.hasOwnProperty; var __export = (target, all) => { for (var name in all) __defProp(target, name, { get: all[name], enumerable: true }); }; var __copyProps = (to, from, except, desc) => { if (from && typeof from === "object" || typeof from === "function") { for (let key of __getOwnPropNames(from)) if (!__hasOwnProp.call(to, key) && key !== except) __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable }); } return to; }; var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__getProtoOf(mod)) : {}, __copyProps( // If the importer is in node compatibility mode or this is not an ESM // file that has been converted to a CommonJS file using a Babel- // compatible transform (i.e. "__esModule" has not been set), then set // "default" to the CommonJS "module.exports" for node compatibility. isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target, mod )); var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod); // src/index.ts var src_exports = {}; __export(src_exports, { createAgent: () => createAgent, fromDecision: () => fromDecision, fromText: () => fromText, fromTextStream: () => fromTextStream }); module.exports = __toCommonJS(src_exports); // src/agent.ts var import_xstate3 = require("xstate"); // src/planners/simplePlanner.ts var import_ai = require("ai"); // src/utils.ts var import_object_hash = __toESM(require("object-hash")); function getAllTransitions(state) { const nodes = state._nodes; const transitions = nodes.map((node) => [...node.transitions.values()]).map((nodeTransitions) => { return nodeTransitions.map((nodeEventTransitions) => { return nodeEventTransitions.map((transition) => { return { ...transition, guard: typeof transition.guard === "string" ? { type: transition.guard } : transition.guard // TODO: fix }; }); }); }).flat(2); return transitions; } function getAllMachineTransitions(stateNode) { const transitions = [...stateNode.transitions.values()].map((nodeTransitions) => { return nodeTransitions.map((transition) => { return { ...transition, guard: typeof transition.guard === "string" ? { type: transition.guard } : transition.guard // TODO: fix }; }); }).flat(2); for (const s of Object.values(stateNode.states)) { const stateTransitions = getAllMachineTransitions(s); transitions.push(...stateTransitions); } return transitions; } function wrapInXml(tagName, content) { return `<${tagName}>${content}</${tagName}>`; } function randomId() { const timestamp = Date.now().toString(36); const random = Math.random().toString(36).substring(2, 9); return timestamp + random; } var machineHashes = /* @__PURE__ */ new WeakMap(); function getMachineHash(machine) { if (machineHashes.has(machine)) return machineHashes.get(machine); const transitions = getAllMachineTransitions(machine.root); const machineHash = (0, import_object_hash.default)(transitions); machineHashes.set(machine, machineHash); return machineHash; } // src/templates/defaultText.ts var defaultTextTemplate = (data) => { const preamble = [ data.context ? wrapInXml("context", JSON.stringify(data.context)) : void 0 ].filter(Boolean).join("\n"); return ` ${preamble} ${data.goal} `.trim(); }; // src/text.ts var import_xstate = require("xstate"); async function getMessages(agent, prompt, options) { let messages = []; if (typeof options.messages === "function") { messages = await options.messages(agent); } else if (options.messages) { messages = options.messages; } messages = messages.concat({ role: "user", content: prompt }); return messages; } async function agentGenerateText(agent, options) { const resolvedOptions = { ...agent.defaultOptions, ...options, correlationId: options.correlationId ?? randomId() }; const template = resolvedOptions.template ?? defaultTextTemplate; const id = randomId(); const goal = typeof resolvedOptions.prompt === "string" ? resolvedOptions.prompt : await resolvedOptions.prompt(agent); const promptWithContext = template({ goal, context: resolvedOptions.context }); const messages = await getMessages(agent, promptWithContext, resolvedOptions); agent.addMessage({ id, role: "user", content: promptWithContext, timestamp: Date.now(), correlationId: resolvedOptions.correlationId, parentCorrelationId: resolvedOptions.parentCorrelationId }); const result = await agent.adapter.generateText({ ...resolvedOptions, prompt: void 0, messages }); agent.addMessage({ content: result.text, id, role: "assistant", timestamp: Date.now(), responseId: id, result, correlationId: resolvedOptions.correlationId, parentCorrelationId: resolvedOptions.parentCorrelationId }); return { ...result, parentCorrelationId: resolvedOptions.parentCorrelationId, correlationId: resolvedOptions.correlationId }; } async function agentStreamText(agent, options) { const resolvedOptions = { ...agent.defaultOptions, ...options, correlationId: options.correlationId ?? randomId() }; const template = resolvedOptions.template ?? defaultTextTemplate; const id = randomId(); const goal = typeof resolvedOptions.prompt === "string" ? resolvedOptions.prompt : await resolvedOptions.prompt(agent); const promptWithContext = template({ goal, context: resolvedOptions.context }); const messages = await getMessages(agent, promptWithContext, resolvedOptions); agent.addMessage({ role: "user", content: promptWithContext, id, timestamp: Date.now(), correlationId: resolvedOptions.correlationId, parentCorrelationId: resolvedOptions.parentCorrelationId }); const result = await agent.adapter.streamText({ ...resolvedOptions, prompt: void 0, messages, onFinish: async (res) => { agent.addMessage({ role: "assistant", result: { text: res.text, finishReason: res.finishReason, logprobs: void 0, responseMessages: [], toolCalls: [], toolResults: [], usage: res.usage, warnings: res.warnings, rawResponse: res.rawResponse, roundtrips: [], // TODO: how do we get this information?, steps: res.steps, response: res.response, experimental_providerMetadata: res.experimental_providerMetadata }, content: res.text, id: randomId(), timestamp: Date.now(), responseId: id, correlationId: resolvedOptions.correlationId, parentCorrelationId: resolvedOptions.parentCorrelationId }); } }); return { ...result, textStream: result.textStream, fullStream: result.fullStream, parentCorrelationId: resolvedOptions.parentCorrelationId, correlationId: resolvedOptions.correlationId }; } function fromTextStream(agent, defaultOptions) { return (0, import_xstate.fromObservable)(({ input }) => { const observers = /* @__PURE__ */ new Set(); (async () => { const result = await agentStreamText(agent, { ...defaultOptions, ...input, context: input.context }); for await (const part of result.fullStream) { if (part.type === "text-delta") { observers.forEach((observer) => { observer.next?.(part); }); } } })(); return { subscribe: (...args) => { const observer = (0, import_xstate.toObserver)(...args); observers.add(observer); return { unsubscribe: () => { observers.delete(observer); } }; } }; }); } function fromText(agent, defaultOptions) { return (0, import_xstate.fromPromise)(async ({ input }) => { return await agentGenerateText(agent, { ...input, ...defaultOptions, context: input.context }); }); } // src/planners/simplePlanner.ts function getTransitions(state, machine) { if (!machine) { return []; } const resolvedState = machine.resolveState(state); return getAllTransitions(resolvedState); } var simplePlannerPromptTemplate = (data) => { return ` ${defaultTextTemplate(data)} Make at most one tool call to achieve the above goal. If the goal cannot be achieved with any tool calls, do not make any tool call. `.trim(); }; async function simplePlanner(agent, input) { const transitions = input.machine ? getTransitions(input.state, input.machine) : Object.entries(input.events).map(([eventType, { description }]) => ({ eventType, description })); const filter = (eventType) => Object.keys(input.events).includes(eventType); const functionNameMapping = {}; const toolTransitions = transitions.filter((t) => { return filter(t.eventType); }).map((t) => { const name = t.eventType.replace(/\./g, "_"); functionNameMapping[name] = t.eventType; return { type: "function", eventType: t.eventType, description: t.description, name }; }); const toolMap = {}; for (const toolTransitionData of toolTransitions) { const toolZodType = input.events?.[toolTransitionData.eventType]; if (!toolZodType) { continue; } toolMap[toolTransitionData.name] = (0, import_ai.tool)({ description: toolZodType?.description ?? toolTransitionData.description, parameters: toolZodType, execute: async (params) => { const event = { type: toolTransitionData.eventType, ...params }; return event; } }); } if (!Object.keys(toolMap).length) { return void 0; } const prompt = simplePlannerPromptTemplate({ context: input.state.context, goal: input.goal }); const messages = await getMessages(agent, prompt, input); const result = await agent.generateText({ toolChoice: "required", ...input, prompt, messages, tools: toolMap }); const singleResult = result.toolResults[0]; if (!singleResult) { console.warn("No tool call results returned"); return void 0; } return { goal: input.goal, state: input.state, execute: async (state) => { if (JSON.stringify(state) === JSON.stringify(input.state)) { return singleResult.result; } return void 0; }, nextEvent: singleResult.result, sessionId: agent.sessionId, timestamp: Date.now() }; } // src/decision.ts var import_xstate2 = require("xstate"); async function agentDecide(agent, options) { const resolvedOptions = { ...agent.defaultOptions, ...options }; const { planner = simplePlanner, goal, events = agent.events, state, machine, model = agent.model, ...otherPlanInput } = resolvedOptions; const plan = await planner(agent, { model, goal, events, state, machine, ...otherPlanInput }); if (plan?.nextEvent) { agent.addPlan(plan); await resolvedOptions.execute?.(plan.nextEvent); } return plan; } function fromDecision(agent, defaultInput) { return (0, import_xstate2.fromPromise)(async ({ input, self }) => { const parentRef = self._parent; if (!parentRef) { return; } const snapshot = parentRef.getSnapshot(); const inputObject = typeof input === "string" ? { goal: input } : input; const resolvedInput = { ...defaultInput, ...inputObject }; const contextToInclude = resolvedInput.context === true ? ( // include entire context parentRef.getSnapshot().context ) : resolvedInput.context; const state = { value: snapshot.value, context: contextToInclude }; const plan = await agentDecide(agent, { machine: parentRef.logic, state, execute: async (event) => { parentRef.send(event); }, ...resolvedInput }); return plan; }); } // src/adapters/vercel.ts var import_ai2 = require("ai"); var vercelAdapter = { generateText: import_ai2.generateText, streamText: import_ai2.streamText }; // src/agent.ts var agentLogic = (0, import_xstate3.fromTransition)( (state, event, { emit }) => { switch (event.type) { case "agent.feedback": { state.feedback.push(event.feedback); emit({ type: "feedback", // @ts-ignore TODO: fix types in XState feedback: event.feedback }); break; } case "agent.observe": { state.observations.push(event.observation); emit({ type: "observation", // @ts-ignore TODO: fix types in XState observation: event.observation }); break; } case "agent.message": { state.messages.push(event.message); emit({ type: "message", // @ts-ignore TODO: fix types in XState message: event.message }); break; } case "agent.plan": { state.plans.push(event.plan); emit({ type: "plan", // @ts-ignore TODO: fix types in XState plan: event.plan }); break; } default: break; } return state; }, () => ({ feedback: [], messages: [], observations: [], plans: [] }) ); function createAgent({ name, description, model, events, context, planner = simplePlanner, stringify = JSON.stringify, getMemory, logic = agentLogic, adapter = vercelAdapter, ...generateTextOptions }) { const agent = (0, import_xstate3.createActor)(logic); agent.events = events; agent.model = model; agent.name = name; agent.description = description; agent.adapter = adapter; agent.defaultOptions = { ...generateTextOptions, model }; agent.select = (selector) => { return selector(agent.getSnapshot().context); }; agent.memory = getMemory ? getMemory(agent) : void 0; agent.onMessage = (callback) => { agent.on("message", (ev) => callback(ev.message)); }; agent.decide = (opts) => { return agentDecide(agent, opts); }; agent.addMessage = (messageInput) => { const message = { ...messageInput, id: messageInput.id ?? randomId(), timestamp: messageInput.timestamp ?? Date.now(), sessionId: agent.sessionId, correlationId: messageInput.correlationId ?? randomId() }; agent.send({ type: "agent.message", message }); return message; }; agent.getMessages = () => agent.getSnapshot().context.messages; agent.generateText = (opts) => agentGenerateText(agent, opts); agent.streamText = (opts) => agentStreamText(agent, opts); agent.addFeedback = (feedbackInput) => { const feedback = { ...feedbackInput, attributes: { ...feedbackInput.attributes }, reward: feedbackInput.reward ?? 0, timestamp: feedbackInput.timestamp ?? Date.now(), sessionId: agent.sessionId }; agent.send({ type: "agent.feedback", feedback }); return feedback; }; agent.getFeedback = () => agent.getSnapshot().context.feedback; agent.addObservation = (observationInput) => { const { prevState, event, state } = observationInput; const observation = { prevState, event, state, id: observationInput.id ?? randomId(), sessionId: agent.sessionId, timestamp: observationInput.timestamp ?? Date.now(), machineHash: observationInput.machine ? getMachineHash(observationInput.machine) : void 0 }; agent.send({ type: "agent.observe", observation }); return observation; }; agent.getObservations = () => agent.getSnapshot().context.observations; agent.addPlan = (plan) => { agent.send({ type: "agent.plan", plan }); }; agent.getPlans = () => agent.getSnapshot().context.plans; agent.interact = (actorRef, getInput) => { let prevState = void 0; let subscribed = true; async function handleObservation(observationInput) { const observation = agent.addObservation(observationInput); const input = getInput?.(observation); if (input) { await agentDecide(agent, { machine: actorRef.src, state: observation.state, execute: async (event) => { actorRef.send(event); }, ...input }); } prevState = observationInput.state; } actorRef.system.inspect({ next: async (inspEvent) => { if (!subscribed || inspEvent.actorRef !== actorRef || inspEvent.type !== "@xstate.snapshot") { return; } const observationInput = { event: inspEvent.event, prevState, state: inspEvent.snapshot, machine: actorRef.src }; await handleObservation(observationInput); } }); if (actorRef._processingStatus === 1) { handleObservation({ prevState: void 0, event: { type: "" }, // TODO: unknown events? state: actorRef.getSnapshot(), machine: actorRef.src }); } return { unsubscribe: () => { subscribed = false; } // TODO: make this actually unsubscribe }; }; agent.types = {}; agent.start(); return agent; } // Annotate the CommonJS export names for ESM import in node: 0 && (module.exports = { createAgent, fromDecision, fromText, fromTextStream });