UNPKG

@statelyai/agent

Version:

Stateful agents that make decisions based on finite-state machine models

507 lines (429 loc) 11.3 kB
/**
 * Unit tests for `createAgent()`: the agent's public surface, its message /
 * feedback / observation stores, plan events, custom adapters, and
 * correlation IDs.
 *
 * No real model is ever called: `model` is always a stub (`{} as any`), and
 * tests that need model output supply a fake `adapter` instead.
 */
import { test, expect, vi } from 'vitest';
import {
  AgentGenerateTextResult,
  AgentMessage,
  createAgent,
  type AIAdapter,
} from './';
import { createActor, createMachine } from 'xstate';
import { GenerateTextResult } from 'ai';
import { z } from 'zod';

/**
 * Builds a fake adapter that implements only `method`.
 *
 * The fake calls the caller-supplied `onFinish` callback (when present) with
 * `{ text: 'response' }`, then resolves with that same object. The
 * correlation-ID tests below observe the assistant reply in the agent's
 * message history after the call; presumably `onFinish` is what triggers that
 * recording (confirm against the agent implementation).
 *
 * @param method - Which adapter method to fake.
 * @returns An adapter object with just that one method populated.
 */
function createFakeAdapter(method: 'generateText' | 'streamText'): AIAdapter {
  return {
    [method]: async (opts: any) => {
      const res = {
        text: 'response',
      };
      opts.onFinish?.(res);
      return res as AgentGenerateTextResult;
    },
  } as any as AIAdapter;
}

test('an agent has the expected interface', () => {
  const agent = createAgent({
    name: 'test',
    events: {},
    model: {} as any,
  });

  const methods = [
    'decide',
    'generateText',
    'streamText',
    'addMessage',
    'addObservation',
    'addFeedback',
    'addPlan',
    'getMessages',
    'getObservations',
    'getFeedback',
    'getPlans',
    'interact',
  ] as const;

  // `toBeDefined()` would accept any non-undefined value; these members are
  // methods, so require them to actually be callable. The member name is used
  // as the assertion message so a failure points at the missing method.
  for (const method of methods) {
    expect(agent[method], method).toBeTypeOf('function');
  }
});

test('agent.addMessage() adds to message history', () => {
  const agent = createAgent({
    name: 'test',
    events: {},
    model: {} as any,
  });

  agent.addMessage({
    content: 'msg 1',
    role: 'user',
  });
  const messageHistory = agent.addMessage({
    content: 'response 1',
    role: 'assistant',
  });

  expect(messageHistory.sessionId).toEqual(agent.sessionId);

  // Both read paths (raw context selection and the getter) must expose the
  // same stored messages.
  for (const messages of [
    agent.select((c) => c.messages),
    agent.getMessages(),
  ]) {
    expect(messages).toContainEqual(
      expect.objectContaining({
        content: 'msg 1',
      })
    );
    expect(messages).toContainEqual(
      expect.objectContaining({
        content: 'response 1',
        sessionId: expect.any(String),
        timestamp: expect.any(Number),
      })
    );
  }
});

test('agent.addFeedback() adds to feedback', () => {
  const agent = createAgent({
    name: 'test',
    events: {},
    model: {} as any,
  });

  const feedback = agent.addFeedback({
    attributes: {
      score: -1,
    },
    goal: 'Win the game',
    observationId: 'obs-1',
  });

  expect(feedback.sessionId).toEqual(agent.sessionId);

  // The stored record keeps the supplied fields and gains session/timestamp
  // metadata.
  const expected = expect.objectContaining({
    attributes: {
      score: -1,
    },
    goal: 'Win the game',
    observationId: 'obs-1',
    sessionId: expect.any(String),
    timestamp: expect.any(Number),
  });
  expect(agent.select((c) => c.feedback)).toContainEqual(expected);
  expect(agent.getFeedback()).toContainEqual(expected);
});

test('agent.addObservation() adds to observations', () => {
  const agent = createAgent({
    name: 'test',
    events: {},
    model: {} as any,
  });

  const observation = agent.addObservation({
    prevState: { value: 'playing', context: {} },
    event: { type: 'play', position: 3 },
    state: { value: 'lost', context: {} },
  });

  expect(observation.sessionId).toEqual(agent.sessionId);

  const expected = expect.objectContaining({
    prevState: { value: 'playing', context: {} },
    event: { type: 'play', position: 3 },
    state: { value: 'lost', context: {} },
    sessionId: expect.any(String),
    timestamp: expect.any(Number),
  });
  // Check both read paths, as the message and feedback tests do.
  expect(agent.select((c) => c.observations)).toContainEqual(expected);
  expect(agent.getObservations()).toContainEqual(expected);
});

test('agent.addObservation() adds to observations with machine hash', () => {
  const agent = createAgent({
    name: 'test',
    events: {},
    model: {} as any,
  });

  const machine = createMachine({
    initial: 'playing',
    states: {
      playing: {
        on: {
          play: 'lost',
        },
      },
      lost: {},
    },
  });

  // Passing the machine alongside the transition should attach a machineHash
  // to the stored observation.
  const observation = agent.addObservation({
    prevState: { value: 'playing', context: {} },
    event: { type: 'play', position: 3 },
    state: { value: 'lost', context: {} },
    machine,
  });

  expect(observation.sessionId).toEqual(agent.sessionId);

  const expected = expect.objectContaining({
    prevState: { value: 'playing', context: {} },
    event: { type: 'play', position: 3 },
    state: { value: 'lost', context: {} },
    machineHash: expect.any(String),
    sessionId: expect.any(String),
    timestamp: expect.any(Number),
  });
  expect(agent.select((c) => c.observations)).toContainEqual(expected);
  expect(agent.getObservations()).toContainEqual(expected);
});

test('agent.interact() observes machine actors (no 2nd arg)', () => {
  const machine = createMachine({
    initial: 'a',
    states: {
      a: {
        on: { NEXT: 'b' },
      },
      b: {},
    },
  });

  const agent = createAgent({
    name: 'test',
    events: {},
    model: {} as any,
  });

  const actor = createActor(machine);

  agent.interact(actor);

  actor.start();

  // Starting the actor records the initial state with no previous state.
  expect(agent.select((c) => c.observations)).toContainEqual(
    expect.objectContaining({
      prevState: undefined,
      state: expect.objectContaining({ value: 'a' }),
    })
  );
  expect(agent.getObservations()).toContainEqual(
    expect.objectContaining({
      prevState: undefined,
      state: expect.objectContaining({ value: 'a' }),
    })
  );

  actor.send({ type: 'NEXT' });

  // Each subsequent transition is recorded as prevState -> event -> state.
  expect(agent.select((c) => c.observations)).toContainEqual(
    expect.objectContaining({
      prevState: expect.objectContaining({ value: 'a' }),
      event: { type: 'NEXT' },
      state: expect.objectContaining({ value: 'b' }),
    })
  );
});

test('Agents can use a custom adapter', async () => {
  const adapter = {
    generateText: async () => {
      return {
        text: 'Response',
      } as any;
    },
  } as unknown as AIAdapter;

  const agent = createAgent({
    name: 'test',
    events: {},
    adapter,
    model: {} as any,
  });

  const res = await agent.generateText({
    prompt: 'Question?',
  });

  expect(res.text).toEqual('Response');
});

test('You can listen for feedback events', () => {
  const fn = vi.fn();

  const agent = createAgent({
    name: 'test',
    events: {},
    adapter: {} as any,
    model: {} as any,
  });

  agent.on('feedback', fn);

  agent.addFeedback({
    attributes: {
      score: -1,
    },
    goal: 'Win the game',
    observationId: 'obs-1',
  });

  expect(fn).toHaveBeenCalled();
});

test('You can listen for plan events', async () => {
  const fn = vi.fn();

  const agent = createAgent({
    name: 'test',
    model: {} as any,
    events: {
      WIN: z.object({}),
    },
    adapter: {
      // Fake model: insists on being offered exactly one tool and "chooses"
      // it by returning a tool result whose type is that tool's name.
      generateText: async (arg) => {
        const keys = Object.keys(arg.tools!);
        if (keys.length !== 1) {
          throw new Error('Expected only 1 choice');
        }
        return {
          toolResults: [
            {
              result: {
                type: keys[0],
              },
            },
          ],
        } as any as AgentGenerateTextResult;
      },
      streamText: {} as any,
    },
  });

  agent.on('plan', fn);

  // From 'playing' the only transition is WIN, so the single offered choice
  // (and hence the planned next event) must be WIN.
  await agent.decide({
    goal: 'Win the game',
    state: {
      value: 'playing',
      context: {},
    },
    machine: createMachine({
      initial: 'playing',
      states: {
        playing: {
          on: {
            WIN: {
              target: 'won',
            },
          },
        },
        won: {},
      },
    }),
  });

  expect(fn).toHaveBeenCalledWith(
    expect.objectContaining({
      plan: expect.objectContaining({
        nextEvent: {
          type: 'WIN',
        },
      }),
    })
  );
});

test('agent.types provides context and event types', () => {
  const agent = createAgent({
    model: {} as any,
    events: {
      setScore: z.object({
        score: z.number(),
      }),
    },
    context: {
      score: z.number(),
    },
  });

  // Compile-time checks only: the zod `context` schema must surface as a
  // typed `agent.types.context`.
  agent.types satisfies { context: any; events: any };

  agent.types.context satisfies { score: number };

  // @ts-expect-error
  agent.types.context satisfies { score: string };
});

test.each(['generateText', 'streamText'] as const)(
  'can provide a correlation ID (%s)',
  async (method) => {
    const agent = createAgent({
      model: {} as any,
      events: {},
      adapter: createFakeAdapter(method),
    });

    // Subscribe before the call so the assistant message cannot be missed.
    const promise = new Promise<AgentMessage>((res) => {
      agent.onMessage((msg) => {
        if (msg.role === 'assistant') {
          res(msg);
        }
      });
    });

    await agent[method]({
      prompt: 'hi',
      correlationId: 'c-1',
    });

    const msg = await promise;

    expect(msg.correlationId).toBe('c-1');
    expect(msg.parentCorrelationId).toBe(undefined);
  }
);

test.each(['generateText', 'streamText'] as const)(
  'correlation IDs are automatically generated if not provided (%s)',
  async (method) => {
    const agent = createAgent({
      model: {} as any,
      events: {},
      adapter: createFakeAdapter(method),
    });

    await agent[method]({
      prompt: 'hi',
    });

    const messages = agent.getMessages();

    // The user prompt and the assistant reply each get a non-empty ID...
    expect(messages[0]?.correlationId).toEqual(expect.stringMatching(/.+/));
    expect(messages[0]?.role).toBe('user');
    expect(messages[1]?.correlationId).toEqual(expect.stringMatching(/.+/));
    expect(messages[1]?.role).toBe('assistant');
    // ...and it is the same ID, tying the request/response pair together.
    expect(messages[0]!.correlationId).toEqual(messages[1]!.correlationId);
  }
);

test.each(['generateText', 'streamText'] as const)(
  'can provide a parent correlation ID (%s)',
  async (method) => {
    const agent = createAgent({
      model: {} as any,
      events: {},
      adapter: createFakeAdapter(method),
    });

    await agent[method]({
      prompt: 'hi',
      correlationId: 'c-1',
      parentCorrelationId: 'c-0',
    });

    const msg = agent.getMessages().find((msg) => msg.role === 'assistant')!;

    expect(msg.correlationId).toBe('c-1');
    expect(msg.parentCorrelationId).toBe('c-0');
  }
);

test.each(['generateText', 'streamText'] as const)(
  'can add feedback to a correlation (%s)',
  async (method) => {
    const agent = createAgent({
      name: 'test',
      model: {} as any,
      events: {},
      adapter: createFakeAdapter(method),
    });

    const res = await agent[method]({
      prompt: 'test',
    });

    // Feedback keyed by the call's correlation ID should line up with the
    // messages recorded for that call.
    agent.addFeedback({
      correlationId: res.correlationId,
      reward: -1,
    });

    const message = agent.getMessages()[0]!;
    const feedback = agent.getFeedback()[0]!;

    expect(message.correlationId).toBeDefined();
    expect(feedback.correlationId).toEqual(message.correlationId);
  }
);