UNPKG

langsmith

Version:

Client library to connect to the LangSmith Observability and Evaluation Platform.

238 lines (237 loc) 11.5 kB
import { getCurrentRunTree, traceable } from "../../traceable.js";
import {
    convertMessageToTracedFormat,
    setUsageMetadataOnRunTree,
} from "./utils.js";

/**
 * Produces a printable message for any thrown value.
 *
 * Thrown values are not guaranteed to be `Error` instances (or even objects),
 * so `message` is read with optional chaining to avoid a secondary TypeError
 * when `null`/`undefined` was thrown.
 *
 * @param {unknown} error - The caught value.
 * @returns {string} `error.message` when present, otherwise `String(error)`.
 */
const _getErrorMessage = (error) => error?.message ?? String(error);

/**
 * Default input formatter used when `lsConfig.processInputs` is not supplied.
 *
 * - No `prompt` (null/undefined): the params object is returned unchanged.
 * - Array `prompt`: the `prompt` key is replaced by `messages`, each entry
 *   passed through `convertMessageToTracedFormat`.
 * - Any other `prompt` value: the params are returned without the `prompt` key.
 *
 * @param {object} params - Call parameters handed to the middleware.
 * @returns {object} The inputs to record on the run.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const _formatTracedInputs = (params) => {
    const { prompt, ...rest } = params;
    if (prompt == null) {
        return params;
    }
    if (Array.isArray(prompt)) {
        return {
            ...rest,
            messages: prompt.map((message) => convertMessageToTracedFormat(message)),
        };
    }
    return rest;
};

/**
 * Default output formatter used when `lsConfig.processOutputs` is not supplied.
 *
 * @param {object} outputs - Model result (or aggregated stream result).
 * @param {boolean} [includeHttpDetails=false] - When true, the raw `request`
 *   and `response` fields are kept; otherwise they are stripped.
 * @returns {*} The result of `convertMessageToTracedFormat` on the formatted
 *   outputs.
 */
const _formatTracedOutputs = (outputs, includeHttpDetails = false) => {
    let formattedOutputs;
    if (includeHttpDetails) {
        // Keep every field, including the raw request/response.
        formattedOutputs = { ...outputs };
    }
    else {
        // Drop only the raw request/response; every other field is kept.
        const { request: _, response: __, ...messageFields } = outputs;
        formattedOutputs = { ...messageFields };
    }
    if (formattedOutputs.role == null) {
        formattedOutputs.role = formattedOutputs.type ?? "assistant";
    }
    return convertMessageToTracedFormat(formattedOutputs);
};

/**
 * AI SDK middleware that wraps an AI SDK 6 or 5 model and adds LangSmith tracing.
 *
 * @param {object} [config] - Optional middleware configuration.
 * @param {string} [config.name] - Run name; defaults to `"ai.doGenerate"` /
 *   `"ai.doStream"` depending on the wrapped call.
 * @param {string} [config.modelId] - Recorded as `ls_model_name` metadata.
 * @param {object} [config.lsConfig] - Extra run configuration. It is spread
 *   into the run config; its `metadata` is merged over the defaults, and its
 *   `processInputs`, `processOutputs` and `traceRawHttp` fields override the
 *   default input/output formatting.
 * @returns {{ wrapGenerate: Function, wrapStream: Function }} Middleware object.
 */
export function LangSmithMiddleware(config) {
    const { name, modelId, lsConfig } = config ?? {};
    return {
        /**
         * Runs `doGenerate` inside a traceable "llm" run and records usage
         * metadata from the result on the current run tree when one exists.
         */
        wrapGenerate: async ({ doGenerate, params }) => {
            const traceableFunc = traceable(async (_params) => {
                const result = await doGenerate();
                const currentRunTree = getCurrentRunTree(true);
                if (currentRunTree !== undefined) {
                    setUsageMetadataOnRunTree(result, currentRunTree);
                }
                return result;
            }, {
                ...lsConfig,
                name: name ?? "ai.doGenerate",
                run_type: "llm",
                metadata: {
                    ls_integration: "vercel-ai-sdk",
                    ls_model_name: modelId,
                    ai_sdk_method: "ai.doGenerate",
                    ...lsConfig?.metadata,
                },
                processInputs: (inputs) => {
                    const typedInputs = inputs;
                    const inputFormatter = lsConfig?.processInputs ?? _formatTracedInputs;
                    return inputFormatter(typedInputs);
                },
                processOutputs: (outputs) => {
                    const typedOutputs = outputs;
                    if (lsConfig?.processOutputs) {
                        return lsConfig.processOutputs(typedOutputs);
                    }
                    return _formatTracedOutputs(typedOutputs, lsConfig?.traceRawHttp);
                },
            });
            const res = await traceableFunc(params);
            return res;
        },
        /**
         * Runs `doStream`, pipes the resulting stream through a transform that
         * buffers every chunk, and — once the stream is flushed — aggregates
         * the chunks into a single message that is recorded on a child run.
         *
         * A child run is only created when a parent run tree with a
         * `createChild` method is available; otherwise the stream is still
         * wrapped but every tracing call is a no-op.
         */
        wrapStream: async ({ doStream, params }) => {
            const parentRunTree = getCurrentRunTree(true);
            let runTree;
            if (parentRunTree != null &&
                typeof parentRunTree === "object" &&
                typeof parentRunTree.createChild === "function") {
                const inputFormatter = lsConfig?.processInputs ?? _formatTracedInputs;
                const formattedInputs = inputFormatter(params);
                runTree = parentRunTree.createChild({
                    ...lsConfig,
                    name: name ?? "ai.doStream",
                    run_type: "llm",
                    metadata: {
                        ls_integration: "vercel-ai-sdk",
                        ls_model_name: modelId,
                        ai_sdk_method: "ai.doStream",
                        ...lsConfig?.metadata,
                    },
                    inputs: formattedInputs,
                });
            }
            await runTree?.postRun();
            try {
                const { stream, ...rest } = await doStream();
                const chunks = [];
                const transformStream = new TransformStream({
                    async transform(chunk, controller) {
                        if (chunk.type === "tool-input-start" ||
                            chunk.type === "text-start") {
                            // Only necessary to log the first token event
                            if (runTree?.events == null ||
                                (Array.isArray(runTree.events) &&
                                    runTree.events.length === 0)) {
                                runTree?.addEvent({ name: "new_token" });
                            }
                        }
                        else if (chunk.type === "finish") {
                            runTree?.addEvent({ name: "end" });
                        }
                        // Buffer for aggregation in flush(), then forward untouched.
                        chunks.push(chunk);
                        controller.enqueue(chunk);
                    },
                    async flush() {
                        try {
                            const output = chunks.reduce((aggregated, chunk) => {
                                if (chunk.type === "text-delta") {
                                    // Consecutive text deltas are merged into one text block.
                                    if (aggregated.content.at(-1)?.type !== "text") {
                                        aggregated.content.push({ type: "text", text: "" });
                                    }
                                    const contentBlock = aggregated.content.at(-1);
                                    if (chunk.delta != null) {
                                        contentBlock.text += chunk.delta;
                                    }
                                    else if ("textDelta" in chunk &&
                                        chunk.textDelta != null) {
                                        // AI SDK 4 shim
                                        contentBlock.text += chunk.textDelta;
                                    }
                                    return aggregated;
                                }
                                else if (chunk.type === "reasoning-delta") {
                                    // Consecutive reasoning deltas are merged into one block;
                                    // `extras` is taken from the first delta of that block.
                                    if (aggregated.content.at(-1)?.type !== "reasoning") {
                                        aggregated.content.push({
                                            type: "reasoning",
                                            reasoning: "",
                                            extras: chunk.providerMetadata,
                                        });
                                    }
                                    const reasoningBlock = aggregated.content.at(-1);
                                    if (chunk.delta != null) {
                                        reasoningBlock.reasoning += chunk.delta;
                                    }
                                    return aggregated;
                                }
                                else if (chunk.type === "tool-call") {
                                    // Ignore repeated tool-call chunks for an id already recorded.
                                    const matchingToolCall = aggregated.tool_calls.find((call) => call.id === chunk.toolCallId);
                                    if (matchingToolCall != null) {
                                        return aggregated;
                                    }
                                    let chunkArgs = chunk.input;
                                    // Fall back to a string `args` field when `input` is absent.
                                    if (chunkArgs == null &&
                                        "args" in chunk &&
                                        typeof chunk.args === "string") {
                                        chunkArgs = chunk.args;
                                    }
                                    return {
                                        ...aggregated,
                                        tool_calls: [
                                            ...aggregated.tool_calls,
                                            {
                                                id: chunk.toolCallId,
                                                type: "function",
                                                function: {
                                                    name: chunk.toolName,
                                                    arguments: chunkArgs,
                                                },
                                            },
                                        ],
                                    };
                                }
                                else if (chunk.type === "finish") {
                                    if (runTree != null) {
                                        setUsageMetadataOnRunTree(chunk, runTree);
                                    }
                                    return {
                                        ...aggregated,
                                        providerMetadata: chunk.providerMetadata,
                                        finishReason: chunk.finishReason,
                                    };
                                }
                                else {
                                    return aggregated;
                                }
                            }, {
                                content: [],
                                role: "assistant",
                                tool_calls: [],
                            });
                            // Add raw request/response for tracing only (not part of aggregated output)
                            const outputForTracing = {
                                ...output,
                                request: rest.request,
                                response: rest.response,
                            };
                            // A lone text block is flattened to a plain string.
                            if ("content" in outputForTracing &&
                                Array.isArray(outputForTracing.content) &&
                                outputForTracing.content.length === 1 &&
                                outputForTracing.content[0].type === "text") {
                                outputForTracing.content = outputForTracing.content[0].text;
                            }
                            let formattedOutputs;
                            if (lsConfig?.processOutputs) {
                                formattedOutputs = await lsConfig.processOutputs(outputForTracing);
                            }
                            else {
                                formattedOutputs = _formatTracedOutputs(outputForTracing, lsConfig?.traceRawHttp);
                            }
                            await runTree?.end(formattedOutputs);
                        }
                        catch (error) {
                            await runTree?.end(undefined, _getErrorMessage(error));
                            throw error;
                        }
                        finally {
                            await runTree?.patchRun({
                                excludeInputs: true,
                            });
                        }
                    },
                });
                return {
                    stream: stream.pipeThrough(transformStream),
                    ...rest,
                };
            }
            catch (error) {
                // Mirror the flush() path: always attempt to patch the run, even
                // if recording the error on the run fails.
                try {
                    await runTree?.end(undefined, _getErrorMessage(error));
                }
                finally {
                    await runTree?.patchRun({
                        excludeInputs: true,
                    });
                }
                throw error;
            }
        },
    };
}