UNPKG

@twilio-alpha/assistants-eval

Version:

promptfoo extension for writing AI evaluations for Twilio AI Assistants

230 lines (204 loc) 5.63 kB
import {
  AssertionValueFunctionContext,
  AssertionValueFunctionResult,
  GradingResult,
} from 'promptfoo';
import { z } from 'zod';
import { TwilioAgentProvider, TwilioProvider } from '../providers';
import { TwilioProviderResponse } from '../providers/twilio';

/** One entry of `config.expectedTools`; fields that are unset are not checked. */
type ExpectedTool = {
  name?: string;
  input?: string;
  output?: string;
};

/**
 * Assertion context after it has been narrowed to a run of one of the Twilio
 * providers. `config.expectedTools` holds the tool-call matchers that
 * `usedTool` checks.
 */
type TwilioProviderContext = AssertionValueFunctionContext & {
  provider: TwilioAgentProvider | TwilioProvider;
  providerResponse: TwilioProviderResponse;
  config?: {
    expectedTools: ExpectedTool[];
  };
};

/**
 * Type guard: true when the assertion ran against `TwilioAgentProvider` or
 * `TwilioProvider` and a provider response is present.
 *
 * The session id is intentionally NOT checked here, so that `usedTool` can
 * report a missing session id separately from an unsupported provider.
 */
function isTwilioProviderContext(
  context: AssertionValueFunctionContext,
): context is TwilioProviderContext {
  const isTwilioProvider =
    context.provider instanceof TwilioAgentProvider ||
    context.provider instanceof TwilioProvider;
  return isTwilioProvider && context.providerResponse != null;
}

/** Fields shared by every message variant returned by the Messages endpoint. */
const MessageBase = z.object({
  account_sid: z.string(),
  assistant_id: z.string(),
  date_created: z.string(),
  date_updated: z.string(),
  id: z.string(),
  identity: z.string(),
});

/**
 * Normalizes a tool/function name for comparison by collapsing every run of
 * characters outside `[a-zA-Z0-9_-]` into a single underscore.
 *
 * @param name - Raw tool name.
 * @returns The sanitized name.
 */
export function sanitizeFunctionCallName(name: string): string {
  return name.replace(/[^a-zA-Z0-9_-]+/g, '_');
}

/** A session message, discriminated on `role` (`user`, `assistant` or `tool`). */
const MessageSchema = z.discriminatedUnion('role', [
  z
    .object({
      role: z.literal('user'),
      content: z.object({
        content: z.string(),
      }),
      meta: z.object({}),
    })
    .merge(MessageBase),
  z
    .object({
      role: z.literal('assistant'),
      content: z.object({
        content: z.string(),
      }),
      meta: z.object({
        tokens: z.object({
          completionTokens: z.number(),
          promptTokens: z.number(),
          totalTokens: z.number(),
        }),
      }),
    })
    .merge(MessageBase),
  z
    .object({
      role: z.literal('tool'),
      content: z.object({
        input: z.string(),
        output: z.string(),
        name: z.string(),
      }),
      meta: z.object({}),
    })
    .merge(MessageBase),
]);

export type Message = z.infer<typeof MessageSchema>;

/** A message with `role: 'tool'`, i.e. one recorded tool invocation. */
export type ToolMessage = Extract<Message, { role: 'tool' }>;

/** One page of the session Messages endpoint, including its paging metadata. */
const History = z.object({
  messages: z.array(MessageSchema),
  meta: z.object({
    first_page_url: z.string().or(z.null()),
    next_page_url: z.string().or(z.null()),
    previous_page_url: z.string().or(z.null()),
    url: z.string(),
    key: z.literal('messages'),
    page: z.number(),
    page_size: z.number(),
  }),
});

/**
 * Fetches all messages of a session, following `meta.next_page_url` until the
 * API reports no further page.
 *
 * Pagination links are only followed when they resolve to the same origin as
 * the first request, so the Authorization header is never sent to another
 * host. A repeated page URL also stops the loop.
 *
 * @param sessionId - Session whose messages are fetched.
 * @param authorizationHeader - Sent verbatim as the `Authorization` header.
 * @param domain - Base URL that the `/v1/Sessions/...` path is resolved against.
 * @returns The messages of every page, concatenated in the order received.
 * @throws Error when a request returns a non-2xx status; the zod parse error
 *   when a response body does not match `History`.
 */
async function getHistory(
  sessionId: string,
  authorizationHeader: string,
  domain: string,
): Promise<Message[]> {
  const firstPageUrl = new URL(`/v1/Sessions/${sessionId}/Messages`, domain);
  firstPageUrl.searchParams.append('PageSize', '100');

  const messages: Message[] = [];
  const visited = new Set<string>();
  let pageUrl: URL | null = firstPageUrl;

  while (pageUrl && !visited.has(pageUrl.href)) {
    visited.add(pageUrl.href);

    const response = await fetch(pageUrl, {
      headers: {
        Authorization: authorizationHeader,
        'Content-Type': 'application/json',
      },
    });
    // Fail with the HTTP status instead of an opaque schema error on the error body.
    if (!response.ok) {
      throw new Error(
        `Failed to fetch session history from ${pageUrl.href}: ${response.status} ${response.statusText}`,
      );
    }

    const page = History.parse(await response.json());
    messages.push(...page.messages);

    const nextPageUrl = page.meta.next_page_url
      ? new URL(page.meta.next_page_url, domain)
      : null;
    pageUrl =
      nextPageUrl && nextPageUrl.origin === firstPageUrl.origin
        ? nextPageUrl
        : null;
  }

  return messages;
}

/**
 * Returns the tool-invocation messages from `messages`, preserving order.
 * The explicit type predicate narrows the result, so callers can read
 * `content.name`, `content.input` and `content.output`.
 *
 * @param messages - Session messages.
 * @returns Only the messages whose `role` is `'tool'`.
 */
export function findAllToolCalls(messages: Message[]): ToolMessage[] {
  return messages.filter(
    (message): message is ToolMessage => message.role === 'tool',
  );
}

/**
 * Returns the tool calls located between the assistant message whose text
 * equals `response` and the next `user` message after it (or the end of the
 * list when there is none).
 *
 * NOTE(review): this assumes the history lists a reply's tool calls after the
 * assistant message and before the triggering user message (i.e. newest
 * first). Confirm against the Messages endpoint ordering.
 *
 * @param messages - Session messages.
 * @param response - Exact assistant reply text to anchor on.
 * @returns The tool calls for that reply; empty when the reply is not found.
 */
export function findToolCallsForResponse(
  messages: Message[],
  response: string,
): ToolMessage[] {
  const indexOfAiResponse = messages.findIndex(
    (message) =>
      message.role === 'assistant' && message.content.content === response,
  );
  if (indexOfAiResponse === -1) {
    return [];
  }

  // Array.prototype.findIndex has no start-index parameter (its second
  // argument is `thisArg`), so search the tail and convert back to an index.
  const offsetOfUserMessage = messages
    .slice(indexOfAiResponse + 1)
    .findIndex((message) => message.role === 'user');
  const end =
    offsetOfUserMessage === -1
      ? messages.length
      : indexOfAiResponse + 1 + offsetOfUserMessage;

  return findAllToolCalls(messages.slice(indexOfAiResponse, end));
}

/**
 * Reads the Authorization header and base URL that the provider was
 * configured with, so the history request authenticates the same way as the
 * evaluated request.
 */
function getProviderConnection(provider: TwilioAgentProvider | TwilioProvider): {
  authorizationHeader: string;
  url: string;
} {
  if (provider instanceof TwilioAgentProvider) {
    const authorizationHeader =
      // @ts-ignore -- requestOptions is presumably not in the public typings; confirm
      provider.agentProviderInstance?.requestOptions?.headers?.Authorization || '';
    return {
      authorizationHeader,
      url: provider.agentProviderInstance.defaultUrl,
    };
  }

  const authorizationHeader =
    // @ts-ignore -- requestOptions is presumably not in the public typings; confirm
    provider.requestOptions?.headers?.Authorization || '';
  return { authorizationHeader, url: provider.defaultUrl };
}

/**
 * True when `tool` satisfies every criterion set on `expected`.
 *
 * - `name`: the sanitized tool name contains the sanitized expected name.
 * - `input` / `output`: the raw strings contain the expected substrings.
 */
function matchesExpectedTool(tool: ToolMessage, expected: ExpectedTool): boolean {
  if (
    expected.name &&
    !sanitizeFunctionCallName(tool.content.name).includes(
      sanitizeFunctionCallName(expected.name),
    )
  ) {
    return false;
  }
  if (expected.input && !tool.content.input.includes(expected.input)) {
    return false;
  }
  if (expected.output && !tool.content.output.includes(expected.output)) {
    return false;
  }
  return true;
}

/** Builds the component grading result for a single expected tool. */
function gradeExpectedTool(
  tools: ToolMessage[],
  expected: ExpectedTool,
): GradingResult {
  const tool = tools.find((candidate) => matchesExpectedTool(candidate, expected));
  return {
    pass: !!tool,
    score: tool ? 1 : 0,
    reason: tool
      ? `Tool ${JSON.stringify(expected)} found`
      : `Tool ${JSON.stringify(expected)} not found`,
    assertion: {
      type: 'javascript',
      value: tool?.content,
    },
  };
}

/**
 * promptfoo assertion that checks which tools the Twilio assistant invoked
 * during the evaluated session.
 *
 * It fetches the session's message history. For each entry in
 * `context.config.expectedTools`, it looks for a tool call that matches every
 * criterion set on that entry. When no `expectedTools` are configured, the
 * assertion passes only if at least one tool call exists.
 *
 * @param _output - Model output. Unused, because tool calls come from the history.
 * @param context - promptfoo assertion context. It must come from
 *   `TwilioProvider` or `TwilioAgentProvider`, and the provider response must
 *   carry `metadata.sessionId`.
 * @returns A grading result with one component result per expected tool.
 * @throws When fetching or parsing the session history fails.
 */
export async function usedTool(
  _output: string,
  context: AssertionValueFunctionContext,
): Promise<AssertionValueFunctionResult> {
  if (!isTwilioProviderContext(context)) {
    return {
      pass: false,
      score: 0,
      reason:
        'Assertion can only be used with TwilioProvider or TwilioAgentProvider',
    };
  }

  const { provider, providerResponse } = context;
  const sessionId = providerResponse.metadata?.sessionId;
  if (!sessionId) {
    return {
      pass: false,
      score: 0,
      reason: 'Invalid request',
    };
  }

  const { authorizationHeader, url } = getProviderConnection(provider);
  const history = await getHistory(sessionId, authorizationHeader, url);
  const tools = findAllToolCalls(history);
  const expectedTools = context.config?.expectedTools || [];

  if (expectedTools.length === 0) {
    // Without matchers, "used a tool" means at least one tool call exists.
    // Otherwise the assertion would pass vacuously.
    const anyToolUsed = tools.length > 0;
    return {
      pass: anyToolUsed,
      score: anyToolUsed ? 1 : 0,
      reason: anyToolUsed ? 'Tools used' : 'Tools not used',
      componentResults: [],
    };
  }

  const toolTests = expectedTools.map((expected) =>
    gradeExpectedTool(tools, expected),
  );
  const pass = toolTests.every((test) => test.pass);
  return {
    pass,
    score: pass ? 1 : 0,
    reason: pass ? 'Tools used' : 'Tools not used',
    componentResults: toolTests,
  };
}