UNPKG

langsmith

Version:

Client library to connect to the LangSmith Observability and Evaluation Platform.

295 lines (294 loc) 12.9 kB
import { getCurrentRunTree, traceable } from "../../traceable.js";
import { extractInputTokenDetails, extractOutputTokenDetails, } from "../../utils/vercel.js";
import { convertMessageToTracedFormat } from "./utils.js";

/**
 * Normalizes AI SDK call params into LangSmith trace inputs.
 *
 * If `prompt` is an array of messages, each message is converted to the
 * traced format and returned under a `messages` key; a non-array prompt is
 * dropped from the traced inputs; a nullish prompt leaves params untouched.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const _formatTracedInputs = (params) => {
    const { prompt, ...rest } = params;
    if (prompt == null) {
        return params;
    }
    if (Array.isArray(prompt)) {
        return {
            ...rest,
            messages: prompt.map((message) => convertMessageToTracedFormat(message)),
        };
    }
    return rest;
};

/**
 * Normalizes a model result into LangSmith trace outputs.
 *
 * @param outputs - Raw result fields from `doGenerate`/`doStream`.
 * @param includeHttpDetails - When true, keep raw `request`/`response` fields.
 */
const _formatTracedOutputs = (outputs, includeHttpDetails = false) => {
    let formattedOutputs;
    if (includeHttpDetails) {
        // Include all fields, including raw request/response
        formattedOutputs = { ...outputs };
    } else {
        // Strip raw HTTP request/response from the traced payload.
        // NOTE(review): `usage` is intentionally kept here — it is read by
        // downstream token-detail extraction; confirm if it should also be stripped.
        const { request: _, response: __, ...messageFields } = outputs;
        formattedOutputs = { ...messageFields };
    }
    if (formattedOutputs.role == null) {
        // Fall back to the message `type` as the role, then "assistant".
        formattedOutputs.role = formattedOutputs.type ?? "assistant";
    }
    return convertMessageToTracedFormat(formattedOutputs);
};

/**
 * Reads token usage off a model result (AI SDK 4/5/6 shapes) and writes
 * LangSmith `usage_metadata` onto the given run tree's metadata.
 *
 * Detection order:
 *  - AI SDK 6: `usage.inputTokens` is an object with a numeric `total`.
 *  - AI SDK 5: `usage.inputTokens` is a plain number.
 *  - AI SDK 4: falls back to `promptTokens`/`completionTokens`.
 */
const setUsageMetadataOnRunTree = (result, runTree) => {
    if (result.usage == null || typeof result.usage !== "object") {
        return;
    }
    const usage = result.usage;
    let inputTokens;
    let outputTokens;
    let totalTokens;
    if (typeof usage.inputTokens === "object" && usage.inputTokens?.total != null) {
        // AI SDK 6 detected: token counts are nested under `.total`.
        inputTokens = usage.inputTokens.total;
        if (typeof usage.outputTokens === "object" && usage.outputTokens?.total != null) {
            outputTokens = usage.outputTokens.total;
        }
    } else if (typeof usage.inputTokens === "number") {
        // AI SDK 5 detected: flat numeric token counts.
        inputTokens = usage.inputTokens;
        if (typeof usage.outputTokens === "number") {
            outputTokens = usage.outputTokens;
        }
    } else {
        // AI SDK 4 fallback: prompt/completion naming.
        if (typeof usage.promptTokens === "number") {
            inputTokens = usage.promptTokens;
        }
        if (typeof usage.completionTokens === "number") {
            outputTokens = usage.completionTokens;
        }
    }
    // Shared across all SDK versions: prefer the reported total, otherwise
    // derive it from input + output when both are present. (Hoisted out of
    // the three branches above, where it was duplicated verbatim.)
    totalTokens = result.usage?.totalTokens;
    if (typeof totalTokens !== "number" &&
        typeof inputTokens === "number" &&
        typeof outputTokens === "number") {
        totalTokens = inputTokens + outputTokens;
    }
    const langsmithUsage = {
        input_tokens: inputTokens,
        output_tokens: outputTokens,
        total_tokens: totalTokens,
    };
    const inputTokenDetails = extractInputTokenDetails(result.usage, result.providerMetadata);
    const outputTokenDetails = extractOutputTokenDetails(result.usage, result.providerMetadata);
    // Merge (never clobber) existing extra/metadata on the run tree.
    runTree.extra = {
        ...runTree.extra,
        metadata: {
            ...runTree.extra?.metadata,
            usage_metadata: {
                ...langsmithUsage,
                input_token_details: { ...inputTokenDetails },
                output_token_details: { ...outputTokenDetails },
            },
        },
    };
};

/**
 * AI SDK middleware that wraps an AI SDK 6 or 5 model and adds LangSmith tracing.
 *
 * @param config - Optional `{ name, modelId, lsConfig }`; `lsConfig` may carry
 *   LangSmith traceable options plus `processInputs`/`processOutputs`/`traceRawHttp`.
 * @returns A middleware object exposing `wrapGenerate` and `wrapStream`.
 */
export function LangSmithMiddleware(config) {
    const { name, modelId, lsConfig } = config ?? {};
    return {
        wrapGenerate: async ({ doGenerate, params }) => {
            // Wrap the generate call in a traceable so LangSmith records an
            // "llm" run with formatted inputs/outputs and usage metadata.
            const traceableFunc = traceable(async (_params) => {
                const result = await doGenerate();
                const currentRunTree = getCurrentRunTree(true);
                if (currentRunTree !== undefined) {
                    setUsageMetadataOnRunTree(result, currentRunTree);
                }
                return result;
            }, {
                ...lsConfig,
                name: name ?? "ai.doGenerate",
                run_type: "llm",
                metadata: {
                    ls_model_name: modelId,
                    ai_sdk_method: "ai.doGenerate",
                    ...lsConfig?.metadata,
                },
                processInputs: (inputs) => {
                    const typedInputs = inputs;
                    // Caller-supplied input formatter wins over the default.
                    const inputFormatter = lsConfig?.processInputs ?? _formatTracedInputs;
                    return inputFormatter(typedInputs);
                },
                processOutputs: (outputs) => {
                    const typedOutputs = outputs;
                    if (lsConfig?.processOutputs) {
                        return lsConfig.processOutputs(typedOutputs);
                    }
                    return _formatTracedOutputs(typedOutputs, lsConfig?.traceRawHttp);
                },
            });
            const res = await traceableFunc(params);
            return res;
        },
        wrapStream: async ({ doStream, params }) => {
            // Streaming runs are managed manually (no traceable wrapper) so the
            // run can be ended once the stream flushes.
            const parentRunTree = getCurrentRunTree(true);
            let runTree;
            if (parentRunTree != null &&
                typeof parentRunTree === "object" &&
                typeof parentRunTree.createChild === "function") {
                const inputFormatter = lsConfig?.processInputs ?? _formatTracedInputs;
                const formattedInputs = inputFormatter(params);
                runTree = parentRunTree?.createChild({
                    ...lsConfig,
                    name: name ?? "ai.doStream",
                    run_type: "llm",
                    metadata: {
                        ls_model_name: modelId,
                        ai_sdk_method: "ai.doStream",
                        ...lsConfig?.metadata,
                    },
                    inputs: formattedInputs,
                });
            }
            await runTree?.postRun();
            try {
                const { stream, ...rest } = await doStream();
                const chunks = [];
                // Tee every chunk into `chunks` while passing it through untouched;
                // aggregate and close the run when the stream flushes.
                const transformStream = new TransformStream({
                    async transform(chunk, controller) {
                        if (chunk.type === "tool-input-start" || chunk.type === "text-start") {
                            // Only necessary to log the first token event
                            if (runTree?.events == null ||
                                (Array.isArray(runTree.events) && runTree.events.length === 0)) {
                                runTree?.addEvent({ name: "new_token" });
                            }
                        } else if (chunk.type === "finish") {
                            runTree?.addEvent({ name: "end" });
                        }
                        chunks.push(chunk);
                        controller.enqueue(chunk);
                    },
                    async flush() {
                        try {
                            // Fold the chunk sequence into a single chat-message-shaped output.
                            const output = chunks.reduce((aggregated, chunk) => {
                                if (chunk.type === "text-delta") {
                                    if (chunk.delta != null) {
                                        return {
                                            ...aggregated,
                                            content: aggregated.content + chunk.delta,
                                        };
                                    } else if ("textDelta" in chunk && chunk.textDelta != null) {
                                        // AI SDK 4 shim
                                        return {
                                            ...aggregated,
                                            content: aggregated.content + chunk.textDelta,
                                        };
                                    } else {
                                        return aggregated;
                                    }
                                } else if (chunk.type === "tool-call") {
                                    // Deduplicate by tool call id.
                                    const matchingToolCall = aggregated.tool_calls.find((call) => call.id === chunk.toolCallId);
                                    if (matchingToolCall != null) {
                                        return aggregated;
                                    }
                                    let chunkArgs = chunk.input;
                                    if (chunkArgs == null && "args" in chunk && typeof chunk.args === "string") {
                                        chunkArgs = chunk.args;
                                    }
                                    return {
                                        ...aggregated,
                                        tool_calls: [
                                            ...aggregated.tool_calls,
                                            {
                                                id: chunk.toolCallId,
                                                type: "function",
                                                function: {
                                                    name: chunk.toolName,
                                                    arguments: chunkArgs,
                                                },
                                            },
                                        ],
                                    };
                                } else if (chunk.type === "finish") {
                                    if (runTree != null) {
                                        setUsageMetadataOnRunTree(chunk, runTree);
                                    }
                                    return {
                                        ...aggregated,
                                        providerMetadata: chunk.providerMetadata,
                                        finishReason: chunk.finishReason,
                                    };
                                } else {
                                    return aggregated;
                                }
                            }, {
                                content: "",
                                role: "assistant",
                                tool_calls: [],
                            });
                            // Add raw request/response for tracing only (not part of aggregated output)
                            const outputForTracing = {
                                ...output,
                                request: rest.request,
                                response: rest.response,
                            };
                            let formattedOutputs;
                            if (lsConfig?.processOutputs) {
                                formattedOutputs = await lsConfig.processOutputs(outputForTracing);
                            } else {
                                formattedOutputs = _formatTracedOutputs(outputForTracing, lsConfig?.traceRawHttp);
                            }
                            await runTree?.end(formattedOutputs);
                        } catch (error) {
                            // `error?.` guards against non-object thrown values
                            // (a bare `error.message` throws on null/undefined).
                            await runTree?.end(undefined, error?.message ?? String(error));
                            throw error;
                        } finally {
                            await runTree?.patchRun({
                                excludeInputs: true,
                            });
                        }
                    },
                });
                return {
                    stream: stream.pipeThrough(transformStream),
                    ...rest,
                };
            } catch (error) {
                await runTree?.end(undefined, error?.message ?? String(error));
                await runTree?.patchRun({
                    excludeInputs: true,
                });
                throw error;
            }
        },
    };
}