@mastra/core
Version:
Mastra is a framework for building AI-powered applications and agents with a modern TypeScript stack.
423 lines (419 loc) • 12.9 kB
JavaScript
import { createStep, createWorkflow } from '../../chunk-Z6QCWTTO.js';
import { saveScorePayloadSchema } from '../../chunk-UWTYVVVZ.js';
import { convertMessages } from '../../chunk-Q6LWNLAJ.js';
import { MastraError } from '../../chunk-PZUZNPFM.js';
import pMap from 'p-map';
import z from 'zod';
// src/scores/scoreTraces/scoreTraces.ts
/**
 * Start the internal "__batch-scoring-traces" workflow for a set of targets.
 *
 * Errors raised while creating or starting the run are not rethrown. They are
 * wrapped in a MastraError (domain SCORER, category SYSTEM), passed to the
 * logger's `trackException`, and then logged at error level.
 *
 * @param {object} params
 * @param {string} params.scorerName - Name of the scorer the workflow should run.
 * @param {Array<{traceId: string, spanId?: string}>} params.targets - Traces/spans to score.
 * @param {object} params.mastra - Mastra instance providing the internal workflow and logger.
 * @returns {Promise<void>}
 */
async function scoreTraces({
  scorerName,
  targets,
  mastra
}) {
  // Looked up outside the try: a failure here propagates to the caller.
  const batchWorkflow = mastra.__getInternalWorkflow("__batch-scoring-traces");
  try {
    const workflowRun = await batchWorkflow.createRunAsync();
    await workflowRun.start({ inputData: { targets, scorerName } });
  } catch (cause) {
    const details = {
      scorerName,
      targets: JSON.stringify(targets)
    };
    const failure = new MastraError(
      {
        category: "SYSTEM",
        domain: "SCORER",
        id: "MASTRA_SCORER_FAILED_TO_RUN_TRACE_SCORING",
        details
      },
      cause
    );
    mastra.getLogger()?.trackException(failure);
    mastra.getLogger()?.error(failure.toString());
  }
}
// src/scores/scoreTraces/utils.ts
/**
 * Index a flat list of spans from one trace into a parent/child tree.
 *
 * A span is treated as a root when its `parentSpanId` is null OR undefined.
 * This matches `validateTrace`, which also accepts spans without a parent id.
 * Each parent's children, and the list of roots, are sorted by `startedAt`
 * ascending.
 *
 * @param {Array<object>} spans - Spans carrying `spanId`, `parentSpanId` and `startedAt`.
 * @returns {{spanMap: Map<string, object>, childrenMap: Map<string, object[]>, rootSpans: object[]}}
 *   `spanMap` maps spanId -> span, `childrenMap` maps parentSpanId -> sorted children,
 *   and `rootSpans` lists the sorted root spans.
 */
function buildSpanTree(spans) {
  const spanMap = /* @__PURE__ */ new Map();
  const childrenMap = /* @__PURE__ */ new Map();
  const rootSpans = [];
  for (const span of spans) {
    spanMap.set(span.spanId, span);
    // `== null` matches both null and undefined. A strict `=== null` check would
    // file spans with an omitted parentSpanId under an `undefined` key, where
    // nothing ever reaches them.
    if (span.parentSpanId == null) {
      rootSpans.push(span);
      continue;
    }
    const siblings = childrenMap.get(span.parentSpanId);
    if (siblings) {
      siblings.push(span);
    } else {
      childrenMap.set(span.parentSpanId, [span]);
    }
  }
  const startTime = (span) => new Date(span.startedAt).getTime();
  const byStartTime = (a, b) => startTime(a) - startTime(b);
  for (const children of childrenMap.values()) {
    children.sort(byStartTime);
  }
  rootSpans.sort(byStartTime);
  return { spanMap, childrenMap, rootSpans };
}
/**
 * Collect the direct children of `parentSpanId` whose `spanType` equals `spanType`.
 * The result keeps the order stored in `spanTree.childrenMap`. A parent with no
 * recorded children yields an empty array.
 */
function getChildrenOfType(spanTree, parentSpanId, spanType) {
  const matches = [];
  for (const child of spanTree.childrenMap.get(parentSpanId) || []) {
    if (child.spanType === spanType) {
      matches.push(child);
    }
  }
  return matches;
}
/**
 * Flatten message content into a plain string.
 *
 * String content is returned unchanged. For array content, each part is reduced
 * to its `type` and `text` and wrapped in a throwaway user message. That message
 * is passed through `convertMessages(...).to("AIV4.UI")`, and the first converted
 * message's `content` is returned. If it is missing or falsy, "" is returned.
 */
function normalizeMessageContent(content) {
  if (typeof content !== "string") {
    const parts = content.map(({ type, text }) => ({ type, text }));
    const [first] = convertMessages({ id: "temp", role: "user", parts }).to("AIV4.UI");
    return first?.content || "";
  }
  return content;
}
/**
 * Convert a span-recorded message into an AIV4 UI message.
 *
 * String content is forwarded as-is. Array content is reduced to `{ type, text }`
 * parts. The converted message's `id` is blanked because spans do not store
 * message ids, and its `createdAt` is set to `new Date(createdAt)`.
 *
 * @param {{role: string, content: string | Array<{type: string, text?: string}>}} message
 * @param {string | number | Date} createdAt - Timestamp to stamp onto the message.
 * @returns {object} The first message returned by `convertMessages(...).to("AIV4.UI")`.
 * @throws {Error} "Failed to convert message" when conversion yields nothing.
 */
function convertToUIMessage(message, createdAt) {
  const { role, content } = message;
  const messageInput = typeof content === "string"
    ? { id: "temp", role, content }
    : { id: "temp", role, parts: content.map((part) => ({ type: part.type, text: part.text })) };
  const [converted] = convertMessages(messageInput).to("AIV4.UI");
  if (!converted) {
    throw new Error("Failed to convert message");
  }
  // Spans carry no message ids; the timestamp comes from the span instead.
  return { ...converted, id: "", createdAt: new Date(createdAt) };
}
/**
 * Build the scorer's input messages from an agent span's `input`.
 *
 * - string: a single user message whose content is that text
 * - array: each element is converted with `convertToUIMessage`
 * - object with a `messages` array: each of those is converted
 * - anything else: empty array
 * All converted messages are stamped with the span's `startedAt`.
 */
function extractInputMessages(agentSpan) {
  const { input, startedAt } = agentSpan;
  if (typeof input === "string") {
    const userMessage = {
      role: "user",
      content: input,
      createdAt: new Date(startedAt),
      parts: [{ type: "text", text: input }],
      experimental_attachments: []
    };
    return [userMessage];
  }
  let rawMessages = [];
  if (Array.isArray(input)) {
    rawMessages = input;
  } else if (input && typeof input === "object" && Array.isArray(input.messages)) {
    rawMessages = input.messages;
  }
  return rawMessages.map((msg) => convertToUIMessage(msg, startedAt));
}
/**
 * Return the system messages from an LLM span's `input.messages`, in order.
 * Each message's content is flattened with `normalizeMessageContent`. A span
 * without input messages yields an empty array.
 */
function extractSystemMessages(llmSpan) {
  const messages = llmSpan.input?.messages || [];
  return messages.flatMap((msg) =>
    msg.role === "system"
      ? [{ role: "system", content: normalizeMessageContent(msg.content) }]
      : []
  );
}
/**
 * Return the LLM span's non-system history as UI messages.
 *
 * Any message whose flattened content equals `currentInputContent` is dropped,
 * because that text is already reported as the current input. Remaining
 * messages are converted with `convertToUIMessage` and stamped with the span's
 * `startedAt`.
 */
function extractRememberedMessages(llmSpan, currentInputContent) {
  const history = llmSpan.input?.messages || [];
  // The role check short-circuits, so system messages are never normalized.
  const isRemembered = (msg) =>
    msg.role !== "system" && normalizeMessageContent(msg.content) !== currentInputContent;
  return history
    .filter(isRemembered)
    .map((msg) => convertToUIMessage(msg, llmSpan.startedAt));
}
/**
 * Rebuild tool invocations from the direct `tool_call` children of `parentSpanId`.
 *
 * Each invocation is reported as a finished call (`state: "result"`). The tool
 * name comes from the span's `attributes.toolId`. A missing or falsy
 * input/output becomes `{}`.
 */
function reconstructToolInvocations(spanTree, parentSpanId) {
  const invocations = [];
  for (const toolSpan of getChildrenOfType(spanTree, parentSpanId, "tool_call" /* TOOL_CALL */)) {
    invocations.push({
      state: "result",
      toolName: toolSpan.attributes?.toolId,
      args: toolSpan.input || {},
      result: toolSpan.output || {}
    });
  }
  return invocations;
}
/**
 * Assemble the `parts` array for an assistant message.
 *
 * Every tool invocation becomes a "tool-invocation" part, in order. A trailing
 * "text" part is appended only when `textContent` has non-whitespace characters;
 * the text itself is kept untrimmed.
 */
function createMessageParts(toolInvocations, textContent) {
  const parts = Array.from(toolInvocations, (toolInvocation) => ({
    type: "tool-invocation",
    toolInvocation
  }));
  const hasVisibleText = textContent.trim() !== "";
  if (hasVisibleText) {
    parts.push({ type: "text", text: textContent });
  }
  return parts;
}
/**
 * Check that a trace is structurally usable before transformation.
 *
 * @param {object} trace - Trace record expected to contain a non-empty `spans` array.
 * @throws {Error} When the trace is missing, `spans` is not an array, `spans` is
 *   empty, or a span names a truthy `parentSpanId` that does not exist in the trace.
 */
function validateTrace(trace) {
  if (!trace) {
    throw new Error("Trace is null or undefined");
  }
  const { spans } = trace;
  if (!Array.isArray(spans)) {
    throw new Error("Trace must have a spans array");
  }
  if (spans.length === 0) {
    throw new Error("Trace has no spans");
  }
  const knownIds = new Set(spans.map((span) => span.spanId));
  // A falsy parentSpanId means "no parent", so it is not checked.
  const orphan = spans.find((span) => span.parentSpanId && !knownIds.has(span.parentSpanId));
  if (orphan) {
    throw new Error(`Span ${orphan.spanId} references non-existent parent ${orphan.parentSpanId}`);
  }
}
/**
 * Return the first direct `model_generation` child of the root agent span.
 * Children are taken in the order held by `spanTree.childrenMap`.
 *
 * @throws {Error} When the agent span has no direct model generation child.
 */
function findPrimaryLLMSpan(spanTree, rootAgentSpan) {
  const generations = getChildrenOfType(spanTree, rootAgentSpan.spanId, "model_generation" /* MODEL_GENERATION */);
  if (generations.length === 0) {
    throw new Error("No model generation span found in trace");
  }
  const [primary] = generations;
  return primary;
}
/**
 * Validate a trace, build its span tree, and locate the root `agent_run` span.
 *
 * @returns {{spanTree: object, rootAgentSpan: object}}
 * @throws {Error} Whatever `validateTrace` throws, or when no root span has type "agent_run".
 */
function prepareTraceForTransformation(trace) {
  validateTrace(trace);
  const spanTree = buildSpanTree(trace.spans);
  const isAgentRun = (span) => span.spanType === "agent_run";
  const rootAgentSpan = spanTree.rootSpans.find(isAgentRun);
  if (!rootAgentSpan) {
    throw new Error("No root agent_run span found in trace");
  }
  return { spanTree, rootAgentSpan };
}
/**
 * Rebuild an agent scorer's `{ input, output }` pair from a stored trace.
 *
 * It uses the root `agent_run` span and that span's first direct
 * `model_generation` child.
 * - input: messages from the agent span's input, the LLM span's non-system
 *   history (minus messages equal to the first input message's content), the
 *   LLM span's system messages, and an empty `taggedSystemMessages` (not yet
 *   supported).
 * - output: one assistant message carrying the agent's `output.text` plus tool
 *   invocations rebuilt from the agent span's direct `tool_call` children.
 *   These are exposed both as message parts and via the legacy
 *   `toolInvocations` field.
 *
 * @param {object} trace - Trace record with a `spans` array.
 * @returns {{input: object, output: object[]}}
 * @throws {Error} When the trace is invalid, has no root agent span, the agent span
 *   has no output, or no model generation span exists under it.
 */
function transformTraceToScorerInputAndOutput(trace) {
  const { spanTree, rootAgentSpan: agentSpan } = prepareTraceForTransformation(trace);
  const agentOutput = agentSpan.output;
  if (!agentOutput) {
    throw new Error("Root agent span has no output");
  }
  const llmSpan = findPrimaryLLMSpan(spanTree, agentSpan);
  const inputMessages = extractInputMessages(agentSpan);
  const systemMessages = extractSystemMessages(llmSpan);
  const rememberedMessages = extractRememberedMessages(llmSpan, inputMessages[0]?.content || "");
  const toolInvocations = reconstructToolInvocations(spanTree, agentSpan.spanId);
  const responseText = agentOutput.text || "";
  // Traces do not record tool call ids, so these only approximate UIMessageWithMetadata.
  return {
    input: {
      inputMessages,
      rememberedMessages,
      systemMessages,
      taggedSystemMessages: {}
    },
    output: [
      {
        role: "assistant",
        content: responseText,
        // Prefer the span's end time; fall back to its start time.
        createdAt: new Date(agentSpan.endedAt || agentSpan.startedAt),
        parts: createMessageParts(toolInvocations, responseText),
        experimental_attachments: [],
        // Tool invocations are being deprecated however we need to support it for now
        toolInvocations
      }
    ]
  };
}
// src/scores/scoreTraces/scoreTracesWorkflow.ts
/**
 * Internal workflow step that scores a batch of trace/span targets with one scorer.
 *
 * Storage and the scorer named by `inputData.scorerName` are resolved from
 * `mastra`. If storage is missing, or `getScorerByName` throws, a MastraError is
 * logged and tracked, and the step returns without scoring anything. Otherwise
 * every target goes to `runScorerOnTarget` through `pMap`, with at most 3 running
 * at once. A failure on one target is logged/tracked and does not stop the rest.
 * The step resolves to `undefined`.
 */
const getTraceStep = createStep({
  id: "__process-trace-scoring",
  inputSchema: z.object({
    targets: z.array(
      z.object({
        traceId: z.string(),
        spanId: z.string().optional()
      })
    ),
    scorerName: z.string()
  }),
  outputSchema: z.any(),
  execute: async ({ inputData, tracingContext, mastra }) => {
    const logger = mastra.getLogger();
    // Log first, then forward to exception tracking; both are no-ops without a logger.
    const report = (mastraError) => {
      logger?.error(mastraError.toString());
      logger?.trackException(mastraError);
    };
    if (!logger) {
      console.warn(
        "[scoreTracesWorkflow] Logger not initialized: no debug or error logs will be recorded for scoring traces."
      );
    }

    const storage = mastra.getStorage();
    if (!storage) {
      report(
        new MastraError({
          id: "MASTRA_STORAGE_NOT_FOUND_FOR_TRACE_SCORING",
          domain: "STORAGE" /* STORAGE */,
          category: "SYSTEM" /* SYSTEM */,
          text: "Storage not found for trace scoring",
          details: {
            scorerName: inputData.scorerName
          }
        })
      );
      return;
    }

    let scorer;
    try {
      scorer = mastra.getScorerByName(inputData.scorerName);
    } catch (lookupError) {
      report(
        new MastraError(
          {
            id: "MASTRA_SCORER_NOT_FOUND_FOR_TRACE_SCORING",
            domain: "SCORER" /* SCORER */,
            category: "SYSTEM" /* SYSTEM */,
            text: "Scorer not found for trace scoring",
            details: {
              scorerName: inputData.scorerName
            }
          },
          lookupError
        )
      );
      return;
    }

    // Per-target failures are reported and swallowed so the batch keeps going.
    const scoreTarget = async (target) => {
      try {
        await runScorerOnTarget({ storage, scorer, target, tracingContext });
      } catch (runError) {
        report(
          new MastraError(
            {
              id: "MASTRA_SCORER_FAILED_TO_RUN_SCORER_ON_TRACE",
              domain: "SCORER" /* SCORER */,
              category: "SYSTEM" /* SYSTEM */,
              details: {
                scorerName: scorer.name,
                spanId: target.spanId || "",
                traceId: target.traceId
              }
            },
            runError
          )
        );
      }
    };
    await pMap(inputData.targets, scoreTarget, { concurrency: 3 });
  }
});
/**
 * Score a single trace/span target and persist the result.
 *
 * Loads the trace from storage and picks the span to score. That is the span
 * matching `target.spanId` when one is given, otherwise the trace's root span
 * (a span whose `parentSpanId` is null or undefined). The scorer runs on a
 * payload from `buildScorerRun`. The result is saved through
 * `validateAndSaveScore` and then linked onto the span with `attachScoreToSpan`.
 *
 * @param {object} params
 * @param {object} params.storage - Storage exposing `getAITrace` (plus what the save/attach helpers use).
 * @param {object} params.scorer - Scorer with `name`, `description`, `type` and async `run`.
 * @param {{traceId: string, spanId?: string}} params.target - What to score.
 * @param {object} params.tracingContext - Forwarded into the scorer run payload.
 * @returns {Promise<void>}
 * @throws {Error} When the trace or the target span cannot be found, or when scoring/saving fails.
 */
async function runScorerOnTarget({
  storage,
  scorer,
  target,
  tracingContext
}) {
  const trace = await storage.getAITrace(target.traceId);
  if (!trace) {
    throw new Error(`Trace not found for scoring, traceId: ${target.traceId}`);
  }
  // Without an explicit spanId, score the root span. `== null` treats a
  // parentSpanId that was omitted (undefined) the same as an explicit null;
  // `=== null` would miss such a root and report "Span not found".
  const span = target.spanId
    ? trace.spans.find((candidate) => candidate.spanId === target.spanId)
    : trace.spans.find((candidate) => candidate.parentSpanId == null);
  if (!span) {
    throw new Error(
      `Span not found for scoring, traceId: ${target.traceId}, spanId: ${target.spanId ?? "Not provided"}`
    );
  }
  const scorerRun = buildScorerRun({
    scorerType: scorer.type === "agent" ? "agent" : void 0,
    tracingContext,
    trace,
    targetSpan: span
  });
  const result = await scorer.run(scorerRun);
  const scorerResult = {
    ...result,
    scorer: {
      id: scorer.name,
      name: scorer.name,
      description: scorer.description
    },
    // NOTE(review): spanId mirrors the request (undefined when scoring the root
    // by default), while `entity.spanId` always holds the scored span's id.
    traceId: target.traceId,
    spanId: target.spanId,
    entityId: span.name,
    entityType: span.spanType,
    entity: { traceId: span.traceId, spanId: span.spanId },
    source: "TEST",
    scorerId: scorer.name
  };
  const savedScoreRecord = await validateAndSaveScore({ storage, scorerResult });
  await attachScoreToSpan({ storage, span, scoreRecord: savedScoreRecord });
}
/**
 * Validate a scorer result against `saveScorePayloadSchema` and persist it.
 * A validation error from `parse` surfaces as a rejected promise.
 *
 * @returns {Promise<object>} The `score` record returned by `storage.saveScore`.
 */
async function validateAndSaveScore({ storage, scorerResult }) {
  const { score } = await storage.saveScore(saveScorePayloadSchema.parse(scorerResult));
  return score;
}
/**
 * Build the payload passed to `scorer.run`.
 *
 * Agent scorers get input/output rebuilt from the whole trace via
 * `transformTraceToScorerInputAndOutput`. Every other scorer type gets the
 * target span's raw `input`/`output`. The tracing context is attached in both
 * cases.
 *
 * @returns {{input: *, output: *, tracingContext: *}}
 */
function buildScorerRun({
  scorerType,
  tracingContext,
  trace,
  targetSpan
}) {
  const { input, output } = scorerType === "agent"
    ? transformTraceToScorerInputAndOutput(trace)
    : { input: targetSpan.input, output: targetSpan.output };
  return { input, output, tracingContext };
}
/**
 * Append a "score" link describing `scoreRecord` to the span's links.
 * The full updated list is written back with `storage.updateAISpan`; the
 * in-memory `span` object is not mutated.
 */
async function attachScoreToSpan({
  storage,
  span,
  scoreRecord
}) {
  const priorLinks = span.links || [];
  const { id: scoreId, scorer, score, createdAt } = scoreRecord;
  const scoreLink = {
    type: "score",
    scoreId,
    scorerName: scorer.name,
    score,
    createdAt
  };
  await storage.updateAISpan({
    spanId: span.spanId,
    traceId: span.traceId,
    updates: { links: priorLinks.concat([scoreLink]) }
  });
}
// Input accepted by the batch trace-scoring workflow: the targets to score plus the scorer name.
const scoreTracesWorkflowInputSchema = z.object({
  targets: z.array(
    z.object({
      traceId: z.string(),
      spanId: z.string().optional()
    })
  ),
  scorerName: z.string()
});

/**
 * Internal "__batch-scoring-traces" workflow. It chains the single
 * `getTraceStep` and is committed immediately below. Its tracing policy sets
 * `internal` to 15, annotated in the build as ALL.
 */
const scoreTracesWorkflow = createWorkflow({
  id: "__batch-scoring-traces",
  inputSchema: scoreTracesWorkflowInputSchema,
  outputSchema: z.any(),
  steps: [getTraceStep],
  options: {
    tracingPolicy: {
      internal: 15 /* ALL */
    }
  }
});
scoreTracesWorkflow.then(getTraceStep).commit();
export { scoreTraces, scoreTracesWorkflow };
//# sourceMappingURL=index.js.map
//# sourceMappingURL=index.js.map