@twilio-alpha/assistants-eval
Version:
promptfoo extension for writing AI evaluations for Twilio AI Assistants
230 lines (204 loc) • 5.63 kB
text/typescript
import {
AssertionValueFunctionContext,
AssertionValueFunctionResult,
GradingResult,
} from 'promptfoo';
import { z } from 'zod';
import { TwilioAgentProvider, TwilioProvider } from '../providers';
import { TwilioProviderResponse } from '../providers/twilio';
/**
 * A single tool-call expectation supplied through the assertion config.
 * Every field is optional; only the fields that are set are compared against
 * the recorded tool calls.
 */
type ExpectedTool = {
  name?: string;
  input?: string;
  output?: string;
};

/**
 * Assertion context narrowed to runs that were executed by one of the Twilio
 * providers, carrying the Twilio-specific provider response and the optional
 * list of expected tool calls.
 */
type TwilioProviderContext = AssertionValueFunctionContext & {
  provider: TwilioAgentProvider | TwilioProvider;
  providerResponse: TwilioProviderResponse;
  config?: { expectedTools: ExpectedTool[] };
};
/**
 * Type guard that checks whether an assertion context was produced by a
 * Twilio provider and carries a session id in the response metadata.
 *
 * @param context - The assertion context handed over by promptfoo.
 * @returns `true` only when the provider is a `TwilioAgentProvider` or a
 *   `TwilioProvider` and `providerResponse.metadata.sessionId` is truthy.
 */
function isTwilioProviderContext(
  context: AssertionValueFunctionContext,
): context is TwilioProviderContext {
  const isTwilioProvider =
    context.provider instanceof TwilioAgentProvider ||
    context.provider instanceof TwilioProvider;
  if (!isTwilioProvider) {
    return false;
  }
  // A type predicate must yield a real boolean; returning the session id
  // itself would leak a string (or undefined) to callers.
  return Boolean(context.providerResponse?.metadata?.sessionId);
}
/**
 * Fields shared by every message variant in the session history payload.
 * All of them are validated as plain strings; this schema is merged into each
 * role-specific variant of the message schema.
 */
const MessageBase = z.object({
account_sid: z.string(),
assistant_id: z.string(),
// Timestamps are only validated as strings here; no date parsing is applied.
date_created: z.string(),
date_updated: z.string(),
id: z.string(),
identity: z.string(),
});
/**
 * Normalizes a tool/function name so that two spellings of the same name can
 * be compared: every run of characters other than ASCII letters, digits,
 * underscore and hyphen is collapsed into a single underscore.
 *
 * @param name - The raw tool name.
 * @returns The name with each disallowed run replaced by `_`.
 */
export function sanitizeFunctionCallName(name: string): string {
  const disallowedRun = /[^a-zA-Z0-9_-]+/g;
  const sanitized = name.replace(disallowedRun, '_');
  return sanitized;
}
/** Message sent by the user: plain text content and an empty `meta` object. */
const UserMessageSchema = z
  .object({
    role: z.literal('user'),
    content: z.object({ content: z.string() }),
    meta: z.object({}),
  })
  .merge(MessageBase);

/** Token usage counters attached to an assistant message. */
const TokenUsageSchema = z.object({
  completionTokens: z.number(),
  promptTokens: z.number(),
  totalTokens: z.number(),
});

/** Message produced by the assistant: plain text content plus token usage. */
const AssistantMessageSchema = z
  .object({
    role: z.literal('assistant'),
    content: z.object({ content: z.string() }),
    meta: z.object({ tokens: TokenUsageSchema }),
  })
  .merge(MessageBase);

/** Record of a tool invocation: the tool name with its input and output. */
const ToolMessageSchema = z
  .object({
    role: z.literal('tool'),
    content: z.object({
      input: z.string(),
      output: z.string(),
      name: z.string(),
    }),
    meta: z.object({}),
  })
  .merge(MessageBase);

/**
 * Schema for one entry of the session message history, discriminated by
 * `role` (`user`, `assistant` or `tool`).
 */
const MessageSchema = z.discriminatedUnion('role', [
  UserMessageSchema,
  AssistantMessageSchema,
  ToolMessageSchema,
]);
/**
 * A validated session history message, inferred from `MessageSchema`.
 * It is a union discriminated by `role`; only the `tool` variant exposes
 * `content.name`, `content.input` and `content.output`.
 */
export type Message = z.infer<typeof MessageSchema>;
/** A pagination link that is either a URL string or `null`. */
const NullablePageUrl = z.union([z.string(), z.null()]);

/**
 * Schema of the paginated response returned when listing the messages of a
 * session: the messages themselves plus the paging metadata.
 */
const History = z.object({
  messages: z.array(MessageSchema),
  meta: z.object({
    first_page_url: NullablePageUrl,
    next_page_url: NullablePageUrl,
    previous_page_url: NullablePageUrl,
    url: z.string(),
    key: z.literal('messages'),
    page: z.number(),
    page_size: z.number(),
  }),
});
/**
 * Fetches and validates the message history of a session.
 *
 * Only a single page of up to 100 messages is requested; `next_page_url` is
 * not followed, so longer sessions are returned truncated to that page.
 *
 * @param sessionId - Identifier of the session whose messages are listed.
 * @param authorizationHeader - Value sent verbatim as the `Authorization`
 *   header.
 * @param domain - Base URL the `/v1/Sessions/...` path is resolved against.
 * @returns The messages of the response after schema validation.
 * @throws Error when the server answers with a non-2xx status.
 * @throws When the URL cannot be built, the request fails, the body is not
 *   JSON, or the body does not match the `History` schema.
 */
async function getHistory(
  sessionId: string,
  authorizationHeader: string,
  domain: string,
) {
  const url = new URL(`/v1/Sessions/${sessionId}/Messages`, domain);
  url.searchParams.append('PageSize', '100');

  const response = await fetch(url, {
    headers: {
      Authorization: authorizationHeader,
      'Content-Type': 'application/json',
    },
  });

  // Fail with the HTTP status instead of letting an error payload reach the
  // schema parser, where it would surface as an unrelated validation error.
  if (!response.ok) {
    throw new Error(
      `Failed to fetch message history for session ${sessionId}: ` +
        `${response.status} ${response.statusText}`,
    );
  }

  const body: unknown = await response.json();
  const result = History.parse(body);
  return result.messages;
}
/**
 * Returns every tool-call message from a message list, preserving order.
 *
 * The filter uses an explicit type predicate so the result is typed as the
 * `tool` variant of `Message` (with `content.name`, `content.input` and
 * `content.output`) regardless of whether the compiler can infer predicates.
 *
 * @param messages - Session history messages.
 * @returns The messages whose `role` is `'tool'`.
 */
export function findAllToolCalls(
  messages: Message[],
): Extract<Message, { role: 'tool' }>[] {
  return messages.filter(
    (message): message is Extract<Message, { role: 'tool' }> =>
      message.role === 'tool',
  );
}
/**
 * Returns the tool calls that belong to one assistant response.
 *
 * The window starts at the first assistant message whose text equals
 * `response` and extends to (but excludes) the next user message that follows
 * it in the array; when no user message follows, it extends to the end.
 * NOTE(review): this presumes tool calls for a response are listed after it
 * in `messages` (e.g. newest-first ordering) — confirm against the API.
 *
 * @param messages - Session history messages.
 * @param response - Exact text of the assistant response to look up.
 * @returns The tool-call messages inside that window; an empty list when no
 *   assistant message matches `response`.
 */
export function findToolCallsForResponse(
  messages: Message[],
  response: string,
) {
  const responseIndex = messages.findIndex(
    (message) =>
      message.role === 'assistant' && message.content.content === response,
  );
  if (responseIndex === -1) {
    return findAllToolCalls([]);
  }

  // Array.prototype.findIndex has no "start index" parameter (its second
  // argument is `thisArg`), so the tail after the response is searched
  // explicitly and the offset translated back to an absolute index.
  const tailStart = responseIndex + 1;
  const userOffset = messages
    .slice(tailStart)
    .findIndex((message) => message.role === 'user');
  // Without a following user message the window runs to the end; passing -1
  // to slice would silently drop the last element instead.
  const windowEnd =
    userOffset === -1 ? messages.length : tailStart + userOffset;

  return findAllToolCalls(messages.slice(responseIndex, windowEnd));
}
/**
 * Decides whether one recorded tool call satisfies one expectation.
 *
 * Only the expectation fields that are set (truthy) are checked: the name is
 * compared after sanitizing both sides, and all three fields use substring
 * matching rather than equality.
 */
function toolCallMatches(
  call: { name: string; input: string; output: string },
  expected: { name?: string; input?: string; output?: string },
): boolean {
  if (
    expected.name &&
    !sanitizeFunctionCallName(call.name).includes(
      sanitizeFunctionCallName(expected.name),
    )
  ) {
    return false;
  }
  if (expected.input && !call.input.includes(expected.input)) {
    return false;
  }
  if (expected.output && !call.output.includes(expected.output)) {
    return false;
  }
  return true;
}

/**
 * promptfoo assertion that passes when every tool listed in
 * `context.config.expectedTools` was called during the session.
 *
 * The session's message history is fetched with the provider's own base URL
 * and `Authorization` header, and each expectation is matched against all
 * tool calls found in it. With no expectations configured the assertion
 * passes.
 *
 * @param _output - The model output; not used by this assertion.
 * @param context - The assertion context; must originate from a
 *   `TwilioProvider` or `TwilioAgentProvider` run with a session id.
 * @returns An overall grading result with one component result per expected
 *   tool. Fails without fetching anything when the context is not a Twilio
 *   one, the session id is missing, or the provider exposes no base URL.
 * @throws Whatever `getHistory` throws (request or validation failures).
 */
export async function usedTool(
  _output: string,
  context: AssertionValueFunctionContext,
): Promise<AssertionValueFunctionResult> {
  if (!isTwilioProviderContext(context)) {
    return {
      pass: false,
      score: 0,
      reason:
        'Assertion can only be used in with TwilioProvider or TwilioAgentProvider',
    };
  }

  const { provider, providerResponse } = context;
  let authorizationHeader = '';
  let url = '';
  // NOTE(review): the @ts-ignore directives below presumably exist because
  // `requestOptions` is not part of the providers' public typings — confirm.
  if (provider instanceof TwilioAgentProvider) {
    authorizationHeader =
      // @ts-ignore
      provider.agentProviderInstance?.requestOptions?.headers?.Authorization ||
      '';
    url = provider.agentProviderInstance?.defaultUrl || '';
  } else if (provider instanceof TwilioProvider) {
    // @ts-ignore
    authorizationHeader = provider.requestOptions?.headers?.Authorization || '';
    url = provider.defaultUrl;
  }

  const sessionId = providerResponse.metadata?.sessionId;
  if (!sessionId) {
    return {
      pass: false,
      score: 0,
      reason: 'Invalid request',
    };
  }

  // Without a base URL the history endpoint cannot be addressed; report a
  // failed assertion rather than letting URL construction throw.
  if (!url) {
    return {
      pass: false,
      score: 0,
      reason: 'Unable to determine the Twilio API base URL from the provider',
    };
  }

  const history = await getHistory(sessionId, authorizationHeader, url);
  const tools = findAllToolCalls(history);
  const expectedTools = context.config?.expectedTools || [];

  const toolTests = expectedTools.map((expectedTool): GradingResult => {
    const tool = tools.find((t) => toolCallMatches(t.content, expectedTool));
    return {
      pass: !!tool,
      score: tool ? 1 : 0,
      reason: tool
        ? `Tool ${JSON.stringify(expectedTool)} found`
        : `Tool ${JSON.stringify(expectedTool)} not found`,
      assertion: {
        type: 'javascript',
        value: tool?.content,
      },
    };
  });

  const pass = toolTests.every((test) => test.pass);
  return {
    pass,
    score: pass ? 1 : 0,
    reason: pass ? 'Tools used' : 'Tools not used',
    componentResults: toolTests,
  };
}