@statelyai/agent
Version:
Stateful agents that make decisions based on finite-state machine models
589 lines (578 loc) • 16.1 kB
JavaScript
// src/agent.ts
import {
createActor,
fromTransition
} from "xstate";
// src/planners/simplePlanner.ts
import { tool } from "ai";
// src/utils.ts
import hash from "object-hash";
/**
 * Flattens the transitions of every state node in a resolved snapshot.
 *
 * Iterates `state._nodes` in order, then each node's per-event transition
 * lists, copying each transition and normalizing a string `guard` into a
 * `{ type }` object.
 *
 * @param state - A resolved machine snapshot exposing `_nodes`.
 * @returns A flat array of (shallow-copied) transition objects.
 */
function getAllTransitions(state) {
  const collected = [];
  for (const node of state._nodes) {
    for (const eventTransitions of node.transitions.values()) {
      for (const transition of eventTransitions) {
        collected.push({
          ...transition,
          // TODO: fix
          guard: typeof transition.guard === "string" ? { type: transition.guard } : transition.guard
        });
      }
    }
  }
  return collected;
}
/**
 * Recursively collects the transitions defined on a state node and all of its
 * descendants (the node's own transitions first, then each child's, depth-first
 * in `Object.values(states)` order).
 *
 * String guards are normalized into `{ type }` objects on the copied transitions.
 *
 * @param stateNode - A state node exposing a `transitions` Map and a `states` record.
 * @returns A flat array of shallow-copied transition objects.
 */
function getAllMachineTransitions(stateNode) {
  const normalize = (transition) => ({
    ...transition,
    // TODO: fix
    guard: typeof transition.guard === "string" ? { type: transition.guard } : transition.guard
  });
  const own = [...stateNode.transitions.values()].flatMap((eventTransitions) => eventTransitions.map(normalize));
  const nested = Object.values(stateNode.states).flatMap((child) => getAllMachineTransitions(child));
  return own.concat(nested);
}
/**
 * Surrounds `content` with an opening and closing XML-style tag named `tagName`.
 * No escaping is performed on either argument.
 */
function wrapInXml(tagName, content) {
  const [open, close] = [`<${tagName}>`, `</${tagName}>`];
  return `${open}${content}${close}`;
}
/**
 * Produces a loosely-unique id: the current millisecond timestamp in base 36
 * followed by up to seven random base-36 characters. Not cryptographically secure.
 */
function randomId() {
  return [Date.now().toString(36), Math.random().toString(36).substring(2, 9)].join("");
}
// Per-machine cache of transition hashes; weakly keyed so machines can be collected.
var machineHashes = /* @__PURE__ */ new WeakMap();
/**
 * Returns a hash of all transitions in `machine` (via `object-hash`), computing
 * it on first request and serving it from `machineHashes` afterwards.
 */
function getMachineHash(machine) {
  if (!machineHashes.has(machine)) {
    machineHashes.set(machine, hash(getAllMachineTransitions(machine.root)));
  }
  return machineHashes.get(machine);
}
// src/templates/defaultText.ts
/**
 * Default prompt template: an optional `<context>` preamble containing the
 * JSON-serialized `data.context` (only when it is truthy), followed by
 * `data.goal`. Surrounding whitespace is trimmed.
 */
var defaultTextTemplate = (data) => {
  const preambleParts = [];
  if (data.context) {
    preambleParts.push(wrapInXml("context", JSON.stringify(data.context)));
  }
  const preamble = preambleParts.join("\n");
  return `\n${preamble}\n${data.goal}\n`.trim();
};
// src/text.ts
import {
fromObservable,
fromPromise,
toObserver
} from "xstate";
/**
 * Builds the message list for a model call: prior messages from
 * `options.messages` followed by a new user message containing `prompt`.
 *
 * @param agent - Passed to `options.messages` when it is a function.
 * @param prompt - Content of the appended user message.
 * @param options - `messages` may be an array, a (possibly async) function of
 *   the agent, or absent.
 * @returns A new array; a caller-supplied array is not mutated.
 */
async function getMessages(agent, prompt, options) {
  const history = typeof options.messages === "function"
    ? await options.messages(agent)
    : options.messages || [];
  return history.concat({
    role: "user",
    content: prompt
  });
}
/**
 * Generates text through `agent.adapter.generateText` and records both the
 * prompt and the reply in the agent's message log.
 *
 * @param agent - Supplies `defaultOptions`, `adapter` and `addMessage`.
 * @param options - Per-call options merged over `agent.defaultOptions`.
 *   `prompt` is either a string or a (possibly async) function of the agent;
 *   `template` (default `defaultTextTemplate`) combines it with `context`.
 * @returns The adapter result plus the `correlationId` (generated when not
 *   given) and `parentCorrelationId` used for both logged messages.
 */
async function agentGenerateText(agent, options) {
  const resolvedOptions = {
    ...agent.defaultOptions,
    ...options,
    correlationId: options.correlationId ?? randomId()
  };
  const template = resolvedOptions.template ?? defaultTextTemplate;
  // Id of the user prompt message; the assistant reply points back to it via `responseId`.
  const id = randomId();
  const goal = typeof resolvedOptions.prompt === "string" ? resolvedOptions.prompt : await resolvedOptions.prompt(agent);
  const promptWithContext = template({
    goal,
    context: resolvedOptions.context
  });
  const messages = await getMessages(agent, promptWithContext, resolvedOptions);
  agent.addMessage({
    id,
    role: "user",
    content: promptWithContext,
    timestamp: Date.now(),
    correlationId: resolvedOptions.correlationId,
    parentCorrelationId: resolvedOptions.parentCorrelationId
  });
  const result = await agent.adapter.generateText({
    ...resolvedOptions,
    // The prompt is already the last entry of `messages` (appended by getMessages).
    prompt: void 0,
    messages
  });
  agent.addMessage({
    content: result.text,
    // A fresh id so the reply does not share the prompt message's id; the link
    // to the prompt is carried by `responseId` instead.
    id: randomId(),
    role: "assistant",
    timestamp: Date.now(),
    responseId: id,
    result,
    correlationId: resolvedOptions.correlationId,
    parentCorrelationId: resolvedOptions.parentCorrelationId
  });
  return {
    ...result,
    parentCorrelationId: resolvedOptions.parentCorrelationId,
    correlationId: resolvedOptions.correlationId
  };
}
/**
 * Streams text through `agent.adapter.streamText`, logging the templated user
 * prompt immediately and the assistant reply once the adapter's `onFinish`
 * callback fires.
 *
 * @param agent - Supplies `defaultOptions`, `adapter` and `addMessage`.
 * @param options - Per-call options merged over `agent.defaultOptions`;
 *   `prompt` is a string or a (possibly async) function of the agent.
 * @returns The adapter's stream result plus `correlationId`/`parentCorrelationId`.
 */
async function agentStreamText(agent, options) {
const resolvedOptions = {
...agent.defaultOptions,
...options,
correlationId: options.correlationId ?? randomId()
};
const template = resolvedOptions.template ?? defaultTextTemplate;
// Id of the user prompt message; the assistant message references it via `responseId`.
const id = randomId();
const goal = typeof resolvedOptions.prompt === "string" ? resolvedOptions.prompt : await resolvedOptions.prompt(agent);
const promptWithContext = template({
goal,
context: resolvedOptions.context
});
const messages = await getMessages(agent, promptWithContext, resolvedOptions);
agent.addMessage({
role: "user",
content: promptWithContext,
id,
timestamp: Date.now(),
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId
});
const result = await agent.adapter.streamText({
...resolvedOptions,
// The prompt is already the last entry of `messages`.
prompt: void 0,
messages,
// Logs the completed reply as an assistant message with its own id.
onFinish: async (res) => {
agent.addMessage({
role: "assistant",
// NOTE(review): responseMessages/toolCalls/toolResults are always recorded as
// empty arrays here, even if `res` carries them — confirm against the AI SDK version in use.
result: {
text: res.text,
finishReason: res.finishReason,
logprobs: void 0,
responseMessages: [],
toolCalls: [],
toolResults: [],
usage: res.usage,
warnings: res.warnings,
rawResponse: res.rawResponse,
roundtrips: [],
// TODO: how do we get this information?,
steps: res.steps,
response: res.response,
experimental_providerMetadata: res.experimental_providerMetadata
},
content: res.text,
id: randomId(),
timestamp: Date.now(),
responseId: id,
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId
});
}
});
// Object spread copies only own enumerable properties; textStream/fullStream are
// re-read explicitly, presumably because the adapter exposes them as accessors
// (e.g. prototype getters) that the spread would miss — verify before removing.
return {
...result,
textStream: result.textStream,
fullStream: result.fullStream,
parentCorrelationId: resolvedOptions.parentCorrelationId,
correlationId: resolvedOptions.correlationId
};
}
/**
 * Creates observable actor logic that streams text from the agent and emits
 * each `text-delta` stream part to subscribers.
 *
 * Per-invocation `input` overrides `defaultOptions`. When the stream ends,
 * subscribers are notified with `complete`, so an invoking machine can react
 * to completion. If starting or consuming the stream fails, they receive
 * `error` (previously the failure surfaced only as an unhandled rejection and
 * the observable never settled).
 *
 * @param agent - The agent to stream with.
 * @param defaultOptions - Options applied to every invocation.
 */
function fromTextStream(agent, defaultOptions) {
  return fromObservable(({ input }) => {
    const observers = /* @__PURE__ */ new Set();
    // Invokes `method` on every current observer (when the observer defines it).
    const notify = (method, value) => {
      observers.forEach((observer) => {
        observer[method]?.(value);
      });
    };
    (async () => {
      try {
        const result = await agentStreamText(agent, {
          ...defaultOptions,
          ...input,
          context: input.context
        });
        for await (const part of result.fullStream) {
          // Only text deltas are forwarded; other stream part types are ignored.
          if (part.type === "text-delta") {
            notify("next", part);
          }
        }
      } catch (error) {
        notify("error", error);
        return;
      }
      notify("complete");
    })();
    return {
      subscribe: (...args) => {
        const observer = toObserver(...args);
        observers.add(observer);
        return {
          unsubscribe: () => {
            observers.delete(observer);
          }
        };
      }
    };
  });
}
/**
 * Creates promise actor logic that generates text with the agent.
 *
 * `defaultOptions` are applied first and the per-invocation `input` overrides
 * them, matching `fromTextStream`. Previously the spread order was reversed,
 * so the defaults silently overrode whatever the invoking machine passed.
 *
 * @param agent - The agent to generate with.
 * @param defaultOptions - Options applied to every invocation unless `input` overrides them.
 */
function fromText(agent, defaultOptions) {
  return fromPromise(async ({ input }) => {
    return await agentGenerateText(agent, {
      ...defaultOptions,
      ...input,
      context: input.context
    });
  });
}
// src/planners/simplePlanner.ts
/**
 * Returns all transitions available from `state` in `machine`, or an empty
 * array when no machine is provided.
 */
function getTransitions(state, machine) {
  return machine ? getAllTransitions(machine.resolveState(state)) : [];
}
/**
 * Planner prompt: the default text template followed by an instruction to
 * make at most one tool call. Surrounding whitespace is trimmed.
 */
var simplePlannerPromptTemplate = (data) => {
  const instruction = "Make at most one tool call to achieve the above goal. If the goal cannot be achieved with any tool calls, do not make any tool call.";
  return `\n${defaultTextTemplate(data)}\n${instruction}\n`.trim();
};
/**
 * Single-step planner: exposes each allowed event as a tool and asks the model
 * for (at most) one tool call, whose result becomes the plan's `nextEvent`.
 *
 * Candidate events come from the machine's transitions for `input.state` when
 * `input.machine` is given, otherwise from `input.events` directly. Only event
 * types that are keys of `input.events` become tools, and the value stored
 * under each key is used as the tool's `parameters` schema. A missing
 * `input.events` now yields no tools (and so `undefined`) instead of a TypeError.
 *
 * @param agent - Agent used for message history and `generateText`.
 * @param input - Planner input (`goal`, `state`, `events`, optional `machine`, plus
 *   any extra generateText options, which are forwarded).
 * @returns A plan object, or `undefined` when no tool is available or none was called.
 */
async function simplePlanner(agent, input) {
  const events = input.events ?? {};
  const transitions = input.machine ? getTransitions(input.state, input.machine) : Object.entries(events).map(([eventType, { description }]) => ({
    eventType,
    description
  }));
  // Hoisted out of the filter so the allowed-event lookup is not rebuilt per transition.
  const allowedEventTypes = new Set(Object.keys(events));
  const toolTransitions = transitions.filter((t) => allowedEventTypes.has(t.eventType)).map((t) => ({
    type: "function",
    eventType: t.eventType,
    description: t.description,
    // Dots in event types are replaced with underscores to form the tool name.
    name: t.eventType.replace(/\./g, "_")
  }));
  const toolMap = {};
  for (const toolTransitionData of toolTransitions) {
    const toolZodType = events[toolTransitionData.eventType];
    if (!toolZodType) {
      continue;
    }
    toolMap[toolTransitionData.name] = tool({
      description: toolZodType?.description ?? toolTransitionData.description,
      parameters: toolZodType,
      // The tool "result" is the event that would be sent; nothing is executed here.
      execute: async (params) => {
        const event = {
          type: toolTransitionData.eventType,
          ...params
        };
        return event;
      }
    });
  }
  if (!Object.keys(toolMap).length) {
    return void 0;
  }
  const prompt = simplePlannerPromptTemplate({
    context: input.state.context,
    goal: input.goal
  });
  // NOTE(review): getMessages appends `prompt` as a user message here, and when
  // `agent.generateText` is agentGenerateText it appends the (templated) prompt
  // to these messages again, so the model likely sees the prompt twice — confirm intent.
  const messages = await getMessages(agent, prompt, input);
  const result = await agent.generateText({
    toolChoice: "required",
    ...input,
    prompt,
    messages,
    tools: toolMap
  });
  const singleResult = result.toolResults[0];
  if (!singleResult) {
    console.warn("No tool call results returned");
    return void 0;
  }
  return {
    goal: input.goal,
    state: input.state,
    // Only yields the event if asked for the same state the plan was made for.
    execute: async (state) => {
      if (JSON.stringify(state) === JSON.stringify(input.state)) {
        return singleResult.result;
      }
      return void 0;
    },
    nextEvent: singleResult.result,
    sessionId: agent.sessionId,
    timestamp: Date.now()
  };
}
// src/decision.ts
import { fromPromise as fromPromise2 } from "xstate";
/**
 * Runs the planner (default `simplePlanner`) for `options.goal`. When the
 * returned plan has a `nextEvent`, the plan is recorded on the agent and the
 * event is handed to `options.execute` (if provided) before returning.
 *
 * @param agent - Supplies defaults (`defaultOptions`, `events`, `model`) and `addPlan`.
 * @param options - Decision options merged over `agent.defaultOptions`; extra
 *   keys are forwarded to the planner.
 * @returns The planner's result (possibly `undefined`).
 */
async function agentDecide(agent, options) {
  const merged = {
    ...agent.defaultOptions,
    ...options
  };
  const {
    planner = simplePlanner,
    goal,
    events = agent.events,
    state,
    machine,
    model = agent.model,
    ...rest
  } = merged;
  const plan = await planner(agent, { model, goal, events, state, machine, ...rest });
  if (!plan?.nextEvent) {
    return plan;
  }
  agent.addPlan(plan);
  await merged.execute?.(plan.nextEvent);
  return plan;
}
/**
 * Creates promise actor logic that asks the agent to decide the next event
 * for the actor's parent and sends the chosen event to that parent.
 *
 * `input` may be a goal string or an input object; it overrides `defaultInput`.
 * `context: true` includes the parent's whole context in the planner state;
 * any other `context` value is used as given.
 *
 * @returns Promise logic resolving to the plan, or `undefined` when there is no parent.
 */
function fromDecision(agent, defaultInput) {
  return fromPromise2(async ({ input, self }) => {
    // Relies on XState's private `_parent` field (presumably the invoking actor).
    const parentRef = self._parent;
    if (!parentRef) {
      return;
    }
    const snapshot = parentRef.getSnapshot();
    const resolvedInput = {
      ...defaultInput,
      ...(typeof input === "string" ? { goal: input } : input)
    };
    const state = {
      value: snapshot.value,
      context: resolvedInput.context === true ? snapshot.context : resolvedInput.context
    };
    return await agentDecide(agent, {
      machine: parentRef.logic,
      state,
      execute: async (event) => {
        parentRef.send(event);
      },
      ...resolvedInput
    });
  });
}
// src/adapters/vercel.ts
import { generateText, streamText } from "ai";
// Default adapter: forwards text generation and streaming to the `ai` package's functions.
var vercelAdapter = {
  generateText: generateText,
  streamText: streamText
};
// src/agent.ts
// Routes each agent.* event to [context array it is appended to, emitted event type].
// The payload is read from, and emitted under, a key equal to the emitted type.
var agentEventRoutes = /* @__PURE__ */ new Map([
  ["agent.feedback", ["feedback", "feedback"]],
  ["agent.observe", ["observations", "observation"]],
  ["agent.message", ["messages", "message"]],
  ["agent.plan", ["plans", "plan"]]
]);
/**
 * Transition logic backing an agent actor. Appends feedback, observations,
 * messages and plans to the corresponding context arrays (mutating them in
 * place) and emits a matching event for listeners.
 */
var agentLogic = fromTransition(
  (state, event, { emit }) => {
    const route = agentEventRoutes.get(event.type);
    if (route === void 0) {
      // Any other event type leaves the state untouched.
      return state;
    }
    const [collectionKey, emittedType] = route;
    const payload = event[emittedType];
    state[collectionKey].push(payload);
    emit({
      type: emittedType,
      [emittedType]: payload
    });
    return state;
  },
  () => ({
    feedback: [],
    messages: [],
    observations: [],
    plans: []
  })
);
/**
 * Creates and starts an agent actor, augmented with helpers for messaging,
 * feedback, observations, plans, text generation and decision making.
 *
 * Any options not destructured below are stored (with `model`) in
 * `agent.defaultOptions`, which agentGenerateText, agentStreamText and
 * agentDecide spread under their per-call options.
 *
 * @returns The started agent actor.
 */
function createAgent({
  name,
  description,
  model,
  events,
  context,
  planner = simplePlanner,
  stringify = JSON.stringify,
  getMemory,
  logic = agentLogic,
  adapter = vercelAdapter,
  ...generateTextOptions
}) {
  // `context`, `planner` and `stringify` are not used in this function; destructuring
  // them keeps them out of `generateTextOptions`.
  const agent = createActor(logic);
  agent.events = events;
  agent.model = model;
  agent.name = name;
  agent.description = description;
  agent.adapter = adapter;
  agent.defaultOptions = { ...generateTextOptions, model };
  agent.select = (selector) => {
    return selector(agent.getSnapshot().context);
  };
  agent.memory = getMemory ? getMemory(agent) : void 0;
  // Returns the value of `agent.on` (presumably an XState subscription) so callers can stop listening.
  agent.onMessage = (callback) => {
    return agent.on("message", (ev) => callback(ev.message));
  };
  agent.decide = (opts) => {
    return agentDecide(agent, opts);
  };
  agent.addMessage = (messageInput) => {
    const message = {
      ...messageInput,
      id: messageInput.id ?? randomId(),
      timestamp: messageInput.timestamp ?? Date.now(),
      sessionId: agent.sessionId,
      correlationId: messageInput.correlationId ?? randomId()
    };
    agent.send({
      type: "agent.message",
      message
    });
    return message;
  };
  agent.getMessages = () => agent.getSnapshot().context.messages;
  agent.generateText = (opts) => agentGenerateText(agent, opts);
  agent.streamText = (opts) => agentStreamText(agent, opts);
  agent.addFeedback = (feedbackInput) => {
    const feedback = {
      ...feedbackInput,
      attributes: { ...feedbackInput.attributes },
      reward: feedbackInput.reward ?? 0,
      timestamp: feedbackInput.timestamp ?? Date.now(),
      sessionId: agent.sessionId
    };
    agent.send({
      type: "agent.feedback",
      feedback
    });
    return feedback;
  };
  agent.getFeedback = () => agent.getSnapshot().context.feedback;
  agent.addObservation = (observationInput) => {
    const { prevState, event, state } = observationInput;
    const observation = {
      prevState,
      event,
      state,
      id: observationInput.id ?? randomId(),
      sessionId: agent.sessionId,
      timestamp: observationInput.timestamp ?? Date.now(),
      machineHash: observationInput.machine ? getMachineHash(observationInput.machine) : void 0
    };
    agent.send({
      type: "agent.observe",
      observation
    });
    return observation;
  };
  agent.getObservations = () => agent.getSnapshot().context.observations;
  agent.addPlan = (plan) => {
    agent.send({
      type: "agent.plan",
      plan
    });
  };
  agent.getPlans = () => agent.getSnapshot().context.plans;
  /**
   * Observes every snapshot of `actorRef` and, when `getInput(observation)`
   * returns decision input, lets the agent decide and send the next event.
   *
   * @returns An object whose `unsubscribe` stops further observations.
   */
  agent.interact = (actorRef, getInput) => {
    let prevState = void 0;
    let subscribed = true;
    async function handleObservation(observationInput) {
      const observation = agent.addObservation(observationInput);
      // Advance prevState before awaiting the decision: executing the decided event
      // sends to actorRef, which can fire the next inspection callback while this call
      // is still suspended, and that observation must see this state as its prevState.
      prevState = observationInput.state;
      const input = getInput?.(observation);
      if (input) {
        await agentDecide(agent, {
          machine: actorRef.src,
          state: observation.state,
          execute: async (event) => {
            actorRef.send(event);
          },
          ...input
        });
      }
    }
    const inspection = actorRef.system.inspect({
      next: async (inspEvent) => {
        if (!subscribed || inspEvent.actorRef !== actorRef || inspEvent.type !== "@xstate.snapshot") {
          return;
        }
        const observationInput = {
          event: inspEvent.event,
          prevState,
          state: inspEvent.snapshot,
          machine: actorRef.src
        };
        await handleObservation(observationInput);
      }
    });
    // `_processingStatus === 1` presumably means the actor is already running
    // (XState internal); observe its current snapshot right away in that case.
    if (actorRef._processingStatus === 1) {
      handleObservation({
        prevState: void 0,
        event: { type: "" },
        // TODO: unknown events?
        state: actorRef.getSnapshot(),
        machine: actorRef.src
      });
    }
    return {
      unsubscribe: () => {
        subscribed = false;
        // Some XState v5 releases return a subscription from `system.inspect` and others
        // return nothing (assumption — verify against the installed version); detach the
        // inspection observer when possible, otherwise the `subscribed` flag still
        // suppresses further observations.
        inspection?.unsubscribe?.();
      }
    };
  };
  agent.types = {};
  agent.start();
  return agent;
}
export {
createAgent,
fromDecision,
fromText,
fromTextStream
};