@statelyai/agent
Version:
Stateful agents that make decisions based on finite-state machine models
622 lines (609 loc) • 17.9 kB
JavaScript
;
// Module interop helpers (these appear to be bundler-generated).
// Short aliases for the Object built-ins used by the helpers below.
var __create = Object.create;
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __getProtoOf = Object.getPrototypeOf;
var __hasOwnProp = Object.prototype.hasOwnProperty;
// Defines every key of `all` on `target` as an enumerable getter, using the
// function stored under that key as the getter.
var __export = (target, all) => {
for (var name in all)
__defProp(target, name, { get: all[name], enumerable: true });
};
// Copies the own property names of `from` (an object or a function) onto `to`
// as getters that read through to `from`. Keys already present on `to` and the
// key equal to `except` are skipped. Enumerability follows the source property
// descriptor and defaults to enumerable when no descriptor is found.
// Returns `to`.
var __copyProps = (to, from, except, desc) => {
if (from && typeof from === "object" || typeof from === "function") {
for (let key of __getOwnPropNames(from))
if (!__hasOwnProp.call(to, key) && key !== except)
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
}
return to;
};
// Wraps a required module: the result inherits from the module's prototype
// (or is a plain object when `mod` is null/undefined) and receives the
// module's own properties through __copyProps.
var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__getProtoOf(mod)) : {}, __copyProps(
// If the importer is in node compatibility mode or this is not an ESM
// file that has been converted to a CommonJS file using a Babel-
// compatible transform (i.e. "__esModule" has not been set), then set
// "default" to the CommonJS "module.exports" for node compatibility.
isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target,
mod
));
// Returns a fresh object carrying `__esModule: true` that exposes the
// properties of `mod` through __copyProps.
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
// src/index.ts
// Public API of the package. Each export is registered as a lazy getter, so
// the functions may be declared further down in this file.
var src_exports = {};
__export(src_exports, {
createAgent: () => createAgent,
fromDecision: () => fromDecision,
fromText: () => fromText,
fromTextStream: () => fromTextStream
});
// Expose the export object as the CommonJS module, flagged with `__esModule`.
module.exports = __toCommonJS(src_exports);
// src/agent.ts
var import_xstate3 = require("xstate");
// src/planners/simplePlanner.ts
var import_ai = require("ai");
// src/utils.ts
var import_object_hash = __toESM(require("object-hash"));
/**
 * Collects every transition defined on the state nodes of a resolved state.
 *
 * Walks `state._nodes`, visits each node's `transitions` map and returns one
 * flat array of shallow transition copies. A guard given as a string is
 * normalized to `{ type: guard }`; any other guard value is kept as-is.
 *
 * @param {{ _nodes: Array<{ transitions: Map<string, object[]> }> }} state
 *   State exposing the active nodes.
 * @returns {object[]} Flat list of transition copies, in node/map order.
 */
function getAllTransitions(state) {
  const collected = [];
  for (const node of state._nodes) {
    for (const eventTransitions of node.transitions.values()) {
      for (const transition of eventTransitions) {
        const normalizedGuard =
          typeof transition.guard === "string"
            ? { type: transition.guard }
            : transition.guard;
        // TODO: fix
        collected.push({ ...transition, guard: normalizedGuard });
      }
    }
  }
  return collected;
}
/**
 * Recursively gathers the transitions of a state node and all of its
 * descendant states.
 *
 * The node's own transitions come first, followed by those of each child in
 * `Object.values(stateNode.states)` order. String guards are normalized to
 * `{ type: guard }`; other guard values are kept as-is.
 *
 * @param {{ transitions: Map<string, object[]>, states: object }} stateNode
 *   Node to start from.
 * @returns {object[]} Flat list of shallow transition copies.
 */
function getAllMachineTransitions(stateNode) {
  const collected = [];
  for (const eventTransitions of stateNode.transitions.values()) {
    for (const transition of eventTransitions) {
      const normalizedGuard =
        typeof transition.guard === "string"
          ? { type: transition.guard }
          : transition.guard;
      // TODO: fix
      collected.push({ ...transition, guard: normalizedGuard });
    }
  }
  Object.values(stateNode.states).forEach((childNode) => {
    collected.push(...getAllMachineTransitions(childNode));
  });
  return collected;
}
/**
 * Surrounds `content` with an opening and closing tag named `tagName`.
 * Neither argument is escaped.
 *
 * @param {string} tagName - Tag name used for both tags.
 * @param {*} content - Value interpolated between the tags.
 * @returns {string} The wrapped string.
 */
function wrapInXml(tagName, content) {
  const openingTag = `<${tagName}>`;
  const closingTag = `</${tagName}>`;
  return `${openingTag}${content}${closingTag}`;
}
/**
 * Builds a short identifier from the current time and a random suffix, both
 * rendered in base 36. Uses `Math.random`, so it is not suitable for
 * security-sensitive purposes.
 *
 * @returns {string} Base-36 timestamp followed by up to seven random characters.
 */
function randomId() {
  const timePart = Date.now().toString(36);
  const randomPart = Math.random().toString(36).slice(2, 9);
  return `${timePart}${randomPart}`;
}
// Cache of computed hashes keyed by machine object.
var machineHashes = /* @__PURE__ */ new WeakMap();
/**
 * Returns a hash of all transitions reachable from `machine.root`.
 * The hash is computed once per machine object and served from the cache
 * afterwards.
 *
 * @param {{ root: object }} machine - Machine whose root node is traversed.
 * @returns {string} Hash of the machine's transitions.
 */
function getMachineHash(machine) {
  if (!machineHashes.has(machine)) {
    const transitions = getAllMachineTransitions(machine.root);
    machineHashes.set(machine, (0, import_object_hash.default)(transitions));
  }
  return machineHashes.get(machine);
}
// src/templates/defaultText.ts
/**
 * Default prompt template: an optional `<context>` element holding the
 * JSON-serialized context, followed on the next line by the goal. The result
 * is trimmed, so a missing (falsy) context yields just the goal.
 *
 * @param {{ goal: *, context?: * }} data - Template input.
 * @returns {string} The rendered prompt.
 */
var defaultTextTemplate = (data) => {
  const sections = [];
  if (data.context) {
    sections.push(wrapInXml("context", JSON.stringify(data.context)));
  }
  const preamble = sections.join("\n");
  return `${preamble}\n${data.goal}`.trim();
};
// src/text.ts
var import_xstate = require("xstate");
/**
 * Resolves the message history to send and appends the prompt as a final
 * user message.
 *
 * `options.messages` may be a function (called with the agent and awaited), a
 * ready-made list, or absent (treated as an empty history). The supplied list
 * is not mutated.
 *
 * @param {object} agent - Passed to `options.messages` when it is a function.
 * @param {string} prompt - Content of the appended user message.
 * @param {{ messages?: Array|Function }} options - Source of prior messages.
 * @returns {Promise<Array>} History followed by the new user message.
 */
async function getMessages(agent, prompt, options) {
  const history =
    typeof options.messages === "function"
      ? await options.messages(agent)
      : options.messages || [];
  const userMessage = { role: "user", content: prompt };
  return history.concat(userMessage);
}
/**
 * Generates a text response through the agent's adapter and records both
 * sides of the exchange in the agent's message history.
 *
 * Options are merged over `agent.defaultOptions`. The prompt (a string, or a
 * function called with the agent) is rendered through `options.template`
 * (falling back to the default text template) together with `options.context`,
 * appended to the resolved message history, and sent to
 * `agent.adapter.generateText`.
 *
 * @param {object} agent - Agent providing `defaultOptions`, `adapter` and
 *   `addMessage`.
 * @param {object} options - Generation options; `prompt` is required.
 *   `correlationId` is generated when not supplied.
 * @returns {Promise<object>} The adapter result extended with `correlationId`
 *   and `parentCorrelationId`.
 */
async function agentGenerateText(agent, options) {
  const resolvedOptions = {
    ...agent.defaultOptions,
    ...options,
    correlationId: options.correlationId ?? randomId()
  };
  const { correlationId, parentCorrelationId } = resolvedOptions;
  const template = resolvedOptions.template ?? defaultTextTemplate;
  const requestId = randomId();
  const goal =
    typeof resolvedOptions.prompt === "string"
      ? resolvedOptions.prompt
      : await resolvedOptions.prompt(agent);
  const promptWithContext = template({
    goal,
    context: resolvedOptions.context
  });
  const messages = await getMessages(agent, promptWithContext, resolvedOptions);
  agent.addMessage({
    id: requestId,
    role: "user",
    content: promptWithContext,
    timestamp: Date.now(),
    correlationId,
    parentCorrelationId
  });
  const result = await agent.adapter.generateText({
    ...resolvedOptions,
    // The prompt is already the last entry of `messages`.
    prompt: void 0,
    messages
  });
  agent.addMessage({
    content: result.text,
    // The response gets its own id; reusing the request id would leave two
    // history entries with the same id. `responseId` links it to the request,
    // matching what the streaming variant records.
    id: randomId(),
    role: "assistant",
    timestamp: Date.now(),
    responseId: requestId,
    result,
    correlationId,
    parentCorrelationId
  });
  return {
    ...result,
    parentCorrelationId,
    correlationId
  };
}
/**
 * Streams a text response through the agent's adapter.
 *
 * The rendered prompt is recorded as a user message before the stream starts;
 * the assistant message is recorded from the adapter's `onFinish` callback
 * and points back to the request through `responseId`.
 *
 * @param {object} agent - Agent providing `defaultOptions`, `adapter` and
 *   `addMessage`.
 * @param {object} options - Streaming options; `prompt` is required.
 *   `correlationId` is generated when not supplied.
 * @returns {Promise<object>} The adapter result with `textStream`,
 *   `fullStream`, `correlationId` and `parentCorrelationId` set explicitly.
 */
async function agentStreamText(agent, options) {
  const settings = {
    ...agent.defaultOptions,
    ...options,
    correlationId: options.correlationId ?? randomId()
  };
  const renderPrompt = settings.template ?? defaultTextTemplate;
  const requestId = randomId();

  let goal;
  if (typeof settings.prompt === "string") {
    goal = settings.prompt;
  } else {
    goal = await settings.prompt(agent);
  }

  const promptWithContext = renderPrompt({ goal, context: settings.context });
  const messages = await getMessages(agent, promptWithContext, settings);

  agent.addMessage({
    role: "user",
    content: promptWithContext,
    id: requestId,
    timestamp: Date.now(),
    correlationId: settings.correlationId,
    parentCorrelationId: settings.parentCorrelationId
  });

  // Records the finished response in the agent's history.
  const recordAssistantMessage = async (res) => {
    const recordedResult = {
      text: res.text,
      finishReason: res.finishReason,
      logprobs: void 0,
      responseMessages: [],
      toolCalls: [],
      toolResults: [],
      usage: res.usage,
      warnings: res.warnings,
      rawResponse: res.rawResponse,
      roundtrips: [],
      // TODO: how do we get this information?,
      steps: res.steps,
      response: res.response,
      experimental_providerMetadata: res.experimental_providerMetadata
    };
    agent.addMessage({
      role: "assistant",
      result: recordedResult,
      content: res.text,
      id: randomId(),
      timestamp: Date.now(),
      responseId: requestId,
      correlationId: settings.correlationId,
      parentCorrelationId: settings.parentCorrelationId
    });
  };

  const result = await agent.adapter.streamText({
    ...settings,
    prompt: void 0,
    messages,
    onFinish: recordAssistantMessage
  });

  // The streams are read off the result and assigned explicitly in addition
  // to the spread of its own enumerable properties.
  const { textStream, fullStream } = result;
  return {
    ...result,
    textStream,
    fullStream,
    parentCorrelationId: settings.parentCorrelationId,
    correlationId: settings.correlationId
  };
}
/**
 * Creates observable actor logic that streams text from the agent.
 *
 * Each `text-delta` part of the stream is forwarded to subscribers through
 * `next`. When the stream ends, subscribers receive `complete`; when starting
 * or reading the stream fails, they receive `error` with the failure instead
 * of the rejection going unhandled.
 *
 * @param {object} agent - Agent used to stream the text.
 * @param {object} [defaultOptions] - Options applied beneath the actor input.
 * @returns {object} Observable actor logic.
 */
function fromTextStream(agent, defaultOptions) {
  return (0, import_xstate.fromObservable)(({ input }) => {
    const observers = /* @__PURE__ */ new Set();

    // Calls `method` on every current subscriber that implements it.
    const notify = (method, value) => {
      for (const observer of observers) {
        observer[method]?.(value);
      }
    };

    const pump = async () => {
      const result = await agentStreamText(agent, {
        ...defaultOptions,
        ...input,
        context: input.context
      });
      for await (const part of result.fullStream) {
        if (part.type === "text-delta") {
          notify("next", part);
        }
      }
    };

    // Started eagerly; the outcome is always reported to subscribers so the
    // promise never rejects unobserved.
    pump().then(
      () => notify("complete"),
      (error) => notify("error", error)
    );

    return {
      subscribe: (...args) => {
        const observer = (0, import_xstate.toObserver)(...args);
        observers.add(observer);
        return {
          unsubscribe: () => {
            observers.delete(observer);
          }
        };
      }
    };
  });
}
/**
 * Creates promise actor logic that generates text with the agent.
 *
 * `defaultOptions` act as defaults: values supplied through the actor's input
 * take precedence over them, the same merge order the streaming and decision
 * actor logic use.
 *
 * @param {object} agent - Agent used to generate the text.
 * @param {object} [defaultOptions] - Options applied beneath the actor input.
 * @returns {object} Promise actor logic resolving to the generation result.
 */
function fromText(agent, defaultOptions) {
  return (0, import_xstate.fromPromise)(async ({ input }) => {
    return await agentGenerateText(agent, {
      // Defaults first so per-invocation input can override them.
      ...defaultOptions,
      ...input,
      context: input.context
    });
  });
}
// src/planners/simplePlanner.ts
/**
 * Lists the transitions available from `state` in `machine`.
 *
 * @param {object} state - State passed to `machine.resolveState`.
 * @param {object} [machine] - Machine used to resolve the state.
 * @returns {object[]} The resolved state's transitions, or an empty array
 *   when no machine is given.
 */
function getTransitions(state, machine) {
  return machine ? getAllTransitions(machine.resolveState(state)) : [];
}
/**
 * Prompt template used by the simple planner: the default text template
 * followed by an instruction limiting the model to a single tool call.
 *
 * @param {{ goal: *, context?: * }} data - Template input.
 * @returns {string} The rendered, trimmed prompt.
 */
var simplePlannerPromptTemplate = (data) => {
  const instruction =
    "Make at most one tool call to achieve the above goal. If the goal cannot be achieved with any tool calls, do not make any tool call.";
  return `${defaultTextTemplate(data)}\n${instruction}`.trim();
};
/**
 * Plans the next event by offering each allowed event to the model as a tool
 * and taking the first tool result.
 *
 * Candidate events come from the machine's transitions for `input.state` when
 * a machine is given, otherwise from `input.events` directly. Only event types
 * present in `input.events` are offered. Dots in event types are replaced by
 * underscores to form the tool names.
 *
 * @param {object} agent - Agent providing `generateText` and `sessionId`.
 * @param {object} input - Planner input (`goal`, `state`, `events`, optional
 *   `machine`, plus any options forwarded to `agent.generateText`).
 * @returns {Promise<object|undefined>} A plan (`goal`, `state`, `execute`,
 *   `nextEvent`, `sessionId`, `timestamp`), or `undefined` when no event can
 *   be offered or the model returned no tool result.
 */
async function simplePlanner(agent, input) {
  // Missing events mean there is nothing to offer as a tool; treat them as an
  // empty set so the planner returns "no plan" instead of throwing.
  const events = input.events ?? {};
  const allowedEventTypes = new Set(Object.keys(events));

  const transitions = input.machine
    ? getTransitions(input.state, input.machine)
    : Object.entries(events).map(([eventType, { description }]) => ({
        eventType,
        description
      }));

  const toolMap = {};
  for (const { eventType, description } of transitions) {
    if (!allowedEventTypes.has(eventType)) {
      continue;
    }
    const eventSchema = events[eventType];
    if (!eventSchema) {
      continue;
    }
    const toolName = eventType.replace(/\./g, "_");
    toolMap[toolName] = (0, import_ai.tool)({
      description: eventSchema.description ?? description,
      parameters: eventSchema,
      // The tool's result is the event to send.
      execute: async (params) => ({
        type: eventType,
        ...params
      })
    });
  }

  if (!Object.keys(toolMap).length) {
    return void 0;
  }

  const prompt = simplePlannerPromptTemplate({
    context: input.state.context,
    goal: input.goal
  });
  const messages = await getMessages(agent, prompt, input);
  // NOTE(review): both `prompt` and `messages` (which already ends with the
  // prompt) are forwarded; confirm the receiving generateText does not append
  // the prompt a second time.
  const result = await agent.generateText({
    toolChoice: "required",
    ...input,
    prompt,
    messages,
    tools: toolMap
  });

  const singleResult = result.toolResults[0];
  if (!singleResult) {
    console.warn("No tool call results returned");
    return void 0;
  }

  return {
    goal: input.goal,
    state: input.state,
    // Yields the planned event only when called with a state that serializes
    // identically to the one the plan was made for.
    execute: async (state) => {
      if (JSON.stringify(state) === JSON.stringify(input.state)) {
        return singleResult.result;
      }
      return void 0;
    },
    nextEvent: singleResult.result,
    sessionId: agent.sessionId,
    timestamp: Date.now()
  };
}
// src/decision.ts
var import_xstate2 = require("xstate");
/**
 * Asks a planner for the next event and, when it produces one, records the
 * plan on the agent and hands the event to `options.execute`.
 *
 * Options are merged over `agent.defaultOptions`. `planner` defaults to the
 * simple planner, `events` to `agent.events` and `model` to `agent.model`;
 * every other option is forwarded to the planner unchanged.
 *
 * @param {object} agent - Agent providing defaults and `addPlan`.
 * @param {object} options - Decision options (`goal`, `state`, `machine`,
 *   `events`, `model`, `planner`, `execute`, ...).
 * @returns {Promise<object|undefined>} Whatever the planner returned.
 */
async function agentDecide(agent, options) {
  const merged = { ...agent.defaultOptions, ...options };
  const {
    planner = simplePlanner,
    goal,
    events = agent.events,
    state,
    machine,
    model = agent.model,
    ...forwarded
  } = merged;

  const plannerInput = { model, goal, events, state, machine, ...forwarded };
  const plan = await planner(agent, plannerInput);

  if (!plan?.nextEvent) {
    return plan;
  }
  agent.addPlan(plan);
  await merged.execute?.(plan.nextEvent);
  return plan;
}
/**
 * Creates promise actor logic that lets the agent decide the next event for
 * the invoking (parent) actor and sends that event to it.
 *
 * The input may be a goal string or an options object and is merged over
 * `defaultInput`. With `context: true` the parent's whole snapshot context is
 * given to the planner; otherwise the provided `context` value is used.
 * Resolves to `undefined` without deciding when the actor has no parent.
 *
 * @param {object} agent - Agent making the decision.
 * @param {object} [defaultInput] - Defaults applied beneath the actor input.
 * @returns {object} Promise actor logic resolving to the plan.
 */
function fromDecision(agent, defaultInput) {
  const decide = async ({ input, self }) => {
    const parentRef = self._parent;
    if (!parentRef) {
      return;
    }
    const snapshot = parentRef.getSnapshot();

    const givenInput = typeof input === "string" ? { goal: input } : input;
    const resolvedInput = { ...defaultInput, ...givenInput };

    let contextToInclude = resolvedInput.context;
    if (contextToInclude === true) {
      // include entire context
      contextToInclude = parentRef.getSnapshot().context;
    }

    const sendToParent = async (event) => {
      parentRef.send(event);
    };
    const plan = await agentDecide(agent, {
      machine: parentRef.logic,
      state: { value: snapshot.value, context: contextToInclude },
      execute: sendToParent,
      ...resolvedInput
    });
    return plan;
  };
  return (0, import_xstate2.fromPromise)(decide);
}
// src/adapters/vercel.ts
var import_ai2 = require("ai");
// Default adapter: exposes the `generateText` and `streamText` functions of
// the `ai` package under the two method names the agent calls on its adapter.
var vercelAdapter = {
generateText: import_ai2.generateText,
streamText: import_ai2.streamText
};
// src/agent.ts
// Routing table for agent events: which state list receives the payload,
// which event property carries it, and the type of the emitted event.
const agentEventRoutes = /* @__PURE__ */ new Map([
  ["agent.feedback", { store: "feedback", payloadKey: "feedback", emits: "feedback" }],
  ["agent.observe", { store: "observations", payloadKey: "observation", emits: "observation" }],
  ["agent.message", { store: "messages", payloadKey: "message", emits: "message" }],
  ["agent.plan", { store: "plans", payloadKey: "plan", emits: "plan" }]
]);
/**
 * Transition logic backing the agent actor.
 *
 * Each recognized event appends its payload to the matching list on the state
 * (in place) and emits an event carrying the same payload. Unrecognized
 * events leave the state untouched. The state starts with four empty lists.
 */
var agentLogic = (0, import_xstate3.fromTransition)(
  (state, event, { emit }) => {
    const route = agentEventRoutes.get(event.type);
    if (route) {
      const payload = event[route.payloadKey];
      state[route.store].push(payload);
      emit({
        type: route.emits,
        // @ts-ignore TODO: fix types in XState
        [route.payloadKey]: payload
      });
    }
    return state;
  },
  () => {
    const initialState = {
      feedback: [],
      messages: [],
      observations: [],
      plans: []
    };
    return initialState;
  }
);
/**
 * Creates an agent: an actor (started before it is returned) decorated with
 * helpers for text generation, decision making and for recording messages,
 * feedback, observations and plans.
 *
 * @param {object} settings - Agent settings.
 * @param {string} settings.name - Stored on `agent.name`.
 * @param {string} [settings.description] - Stored on `agent.description`.
 * @param {*} settings.model - Stored on `agent.model` and in the default options.
 * @param {object} [settings.events] - Stored on `agent.events`.
 * @param {*} [settings.context] - Accepted but not used by this function.
 * @param {Function} [settings.planner] - Planner used for decisions; defaults
 *   to the simple planner. Stored in the default options so every decision
 *   made with this agent picks it up.
 * @param {Function} [settings.stringify] - Accepted but not used by this function.
 * @param {Function} [settings.getMemory] - Called with the agent; its result
 *   is stored on `agent.memory`.
 * @param {object} [settings.logic] - Actor logic; defaults to the built-in
 *   agent logic.
 * @param {object} [settings.adapter] - Text generation adapter; defaults to
 *   the adapter backed by the `ai` package.
 * Any remaining settings become default generation options.
 * @returns {object} The started agent actor.
 */
function createAgent({
  name,
  description,
  model,
  events,
  // `context` and `stringify` stay in the destructuring so they are not
  // swept into the default generation options below.
  context,
  planner = simplePlanner,
  stringify = JSON.stringify,
  getMemory,
  logic = agentLogic,
  adapter = vercelAdapter,
  ...generateTextOptions
}) {
  const agent = (0, import_xstate3.createActor)(logic);
  agent.events = events;
  agent.model = model;
  agent.name = name;
  agent.description = description;
  agent.adapter = adapter;
  // The planner is part of the defaults; without it a custom planner passed
  // here would never reach the decision logic, which reads its planner from
  // the merged options.
  agent.defaultOptions = { ...generateTextOptions, model, planner };
  agent.select = (selector) => {
    return selector(agent.getSnapshot().context);
  };
  agent.memory = getMemory ? getMemory(agent) : void 0;
  // Returns whatever `agent.on` returns so the caller can stop listening.
  agent.onMessage = (callback) => {
    return agent.on("message", (ev) => callback(ev.message));
  };
  agent.decide = (opts) => {
    return agentDecide(agent, opts);
  };
  // Fills in id, timestamp and correlation id when absent, stamps the
  // agent's session id and records the message.
  agent.addMessage = (messageInput) => {
    const message = {
      ...messageInput,
      id: messageInput.id ?? randomId(),
      timestamp: messageInput.timestamp ?? Date.now(),
      sessionId: agent.sessionId,
      correlationId: messageInput.correlationId ?? randomId()
    };
    agent.send({
      type: "agent.message",
      message
    });
    return message;
  };
  agent.getMessages = () => agent.getSnapshot().context.messages;
  agent.generateText = (opts) => agentGenerateText(agent, opts);
  agent.streamText = (opts) => agentStreamText(agent, opts);
  // Copies `attributes`, defaults `reward` to 0 and the timestamp to now,
  // then records the feedback.
  agent.addFeedback = (feedbackInput) => {
    const feedback = {
      ...feedbackInput,
      attributes: { ...feedbackInput.attributes },
      reward: feedbackInput.reward ?? 0,
      timestamp: feedbackInput.timestamp ?? Date.now(),
      sessionId: agent.sessionId
    };
    agent.send({
      type: "agent.feedback",
      feedback
    });
    return feedback;
  };
  agent.getFeedback = () => agent.getSnapshot().context.feedback;
  // Records a (prevState, event, state) observation; when a machine is
  // supplied, its transition hash is attached as `machineHash`.
  agent.addObservation = (observationInput) => {
    const { prevState, event, state } = observationInput;
    const observation = {
      prevState,
      event,
      state,
      id: observationInput.id ?? randomId(),
      sessionId: agent.sessionId,
      timestamp: observationInput.timestamp ?? Date.now(),
      machineHash: observationInput.machine ? getMachineHash(observationInput.machine) : void 0
    };
    agent.send({
      type: "agent.observe",
      observation
    });
    return observation;
  };
  agent.getObservations = () => agent.getSnapshot().context.observations;
  agent.addPlan = (plan) => {
    agent.send({
      type: "agent.plan",
      plan
    });
  };
  agent.getPlans = () => agent.getSnapshot().context.plans;
  // Observes `actorRef`: every snapshot is recorded as an observation and,
  // when `getInput` returns a value for it, the agent decides on an event and
  // sends it to the actor.
  agent.interact = (actorRef, getInput) => {
    let prevState = void 0;
    let subscribed = true;
    async function handleObservation(observationInput) {
      const observation = agent.addObservation(observationInput);
      const input = getInput?.(observation);
      if (input) {
        await agentDecide(agent, {
          machine: actorRef.src,
          state: observation.state,
          execute: async (event) => {
            actorRef.send(event);
          },
          ...input
        });
      }
      prevState = observationInput.state;
    }
    actorRef.system.inspect({
      next: async (inspEvent) => {
        // Only snapshot events of the observed actor are of interest.
        if (!subscribed || inspEvent.actorRef !== actorRef || inspEvent.type !== "@xstate.snapshot") {
          return;
        }
        const observationInput = {
          event: inspEvent.event,
          prevState,
          state: inspEvent.snapshot,
          machine: actorRef.src
        };
        await handleObservation(observationInput);
      }
    });
    // NOTE(review): status 1 presumably means the actor is already running,
    // in which case its current snapshot is observed right away. Confirm
    // against the XState version in use. The call is not awaited, so a
    // rejection here is not handled.
    if (actorRef._processingStatus === 1) {
      handleObservation({
        prevState: void 0,
        event: { type: "" },
        // TODO: unknown events?
        state: actorRef.getSnapshot(),
        machine: actorRef.src
      });
    }
    return {
      unsubscribe: () => {
        subscribed = false;
      }
      // TODO: make this actually unsubscribe
    };
  };
  agent.types = {};
  agent.start();
  return agent;
}
// Annotate the CommonJS export names for ESM import in node:
// The `0 &&` guard means this assignment never executes; the statement is
// present only so the export names appear literally in the source text.
0 && (module.exports = {
createAgent,
fromDecision,
fromText,
fromTextStream
});