UNPKG

dd-trace

Version:

Datadog APM tracing client for JavaScript

702 lines (652 loc) 23.2 kB
'use strict' const log = require('../../../../dd-trace/src/log') const MODEL_TYPE_IDENTIFIERS = [ 'foundation-model/', 'custom-model/', 'provisioned-model/', 'imported-module/', 'prompt/', 'endpoint/', 'inference-profile/', 'default-prompt-router/', ] const PROVIDER = { AI21: 'AI21', AMAZON: 'AMAZON', ANTHROPIC: 'ANTHROPIC', COHERE: 'COHERE', META: 'META', STABILITY: 'STABILITY', MISTRAL: 'MISTRAL', } /** * Coerce the chunks into a single response body. * * @param {Array<{ chunk: { bytes: Buffer } }>} chunks * @param {string} modelProvider * @param {string} modelName * @returns {Generation | Record<never, never>} */ function extractTextAndResponseReasonFromStream (chunks, modelProvider, modelName) { const modelProviderUpper = modelProvider.toUpperCase() // streaming unsupported for AMAZON embedding models, COHERE embedding models, STABILITY if ( (modelProviderUpper === PROVIDER.AMAZON && modelName.includes('embed')) || (modelProviderUpper === PROVIDER.COHERE && modelName.includes('embed')) || modelProviderUpper === PROVIDER.STABILITY ) { return {} } let message = '' let inputTokens = 0 let outputTokens = 0 let cacheReadTokens = 0 let cacheWriteTokens = 0 for (const { chunk: { bytes } } of chunks) { const body = JSON.parse(Buffer.from(bytes).toString('utf8')) switch (modelProviderUpper) { case PROVIDER.AMAZON: { if (body?.outputText) { message += body?.outputText inputTokens = body?.inputTextTokenCount outputTokens = body?.totalOutputTextTokenCount } else if (body?.contentBlockDelta?.delta?.text) { message += body.contentBlockDelta.delta.text } break } case PROVIDER.AI21: { const content = body?.choices?.[0]?.delta?.content if (content) { message += content } break } case PROVIDER.ANTHROPIC: { if (body.completion) { message += body.completion } else if (body.delta?.text) { message += body.delta.text } if (body.message?.usage?.input_tokens) inputTokens = body.message.usage.input_tokens if (body.message?.usage?.output_tokens) outputTokens = body.message.usage.output_tokens break } case PROVIDER.COHERE: { if (body?.event_type === 'stream-end') { message = body.response?.text } break } case PROVIDER.META: { message += body?.generation break } case PROVIDER.MISTRAL: { message += body?.outputs?.[0]?.text break } } // by default, it seems newer versions of the AWS SDK include the input/output token counts in the response body const invocationMetrics = body['amazon-bedrock-invocationMetrics'] if (invocationMetrics) { inputTokens = invocationMetrics.inputTokenCount outputTokens = invocationMetrics.outputTokenCount cacheReadTokens = invocationMetrics.cacheReadInputTokenCount cacheWriteTokens = invocationMetrics.cacheWriteInputTokenCount } } return new Generation({ message, role: 'assistant', inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens, }) } class Generation { constructor ({ message = '', finishReason = '', choiceId = '', role, inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens, messages, } = {}) { // stringify message as it could be a single generated message as well as a list of embeddings this.message = typeof message === 'string' ? message : JSON.stringify(message) || '' this.finishReason = finishReason || '' this.choiceId = choiceId || undefined this.role = role this.usage = { inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens, } this.messages = messages ?? [{ content: this.message, role: this.role }] } } class RequestParams { constructor ({ prompt = '', temperature, topP, topK, maxTokens, stopSequences = [], inputType = '', truncate = '', stream = '', n, } = {}) { this.prompt = prompt this.temperature = temperature this.topP = topP this.topK = topK this.maxTokens = maxTokens this.stopSequences = stopSequences || [] this.inputType = inputType || '' this.truncate = truncate || '' this.stream = stream || '' this.n = n } } function parseModelId (modelId) { // Best effort to extract the model provider and model name from the bedrock model ID. // modelId can be a 1/2 period-separated string or a full AWS ARN, based on the following formats: // 1. Base model: "{model_provider}.{model_name}" // 2. Cross-region model: "{region}.{model_provider}.{model_name}" // 3. Other: Prefixed by AWS ARN "arn:aws{+region?}:bedrock:{region}:{account-id}:" // a. Foundation model: ARN prefix + "foundation-model/{region?}.{model_provider}.{model_name}" // b. Custom model: ARN prefix + "custom-model/{model_provider}.{model_name}" // c. Provisioned model: ARN prefix + "provisioned-model/{model-id}" // d. Imported model: ARN prefix + "imported-module/{model-id}" // e. Prompt management: ARN prefix + "prompt/{prompt-id}" // f. Sagemaker: ARN prefix + "endpoint/{model-id}" // g. Inference profile: ARN prefix + "{application-?}inference-profile/{model-id}" // h. Default prompt router: ARN prefix + "default-prompt-router/{prompt-id}" // If model provider cannot be inferred from the modelId formatting, then default to "custom" modelId = modelId.toLowerCase() if (!modelId.startsWith('arn:aws')) { const modelMeta = modelId.split('.') if (modelMeta.length < 2) { return { modelProvider: 'custom', modelName: modelMeta[0] } } return { modelProvider: modelMeta.at(-2), modelName: modelMeta.at(-1) } } for (const identifier of MODEL_TYPE_IDENTIFIERS) { if (!modelId.includes(identifier)) { continue } modelId = modelId.split(identifier).pop() if (['foundation-model/', 'custom-model/'].includes(identifier)) { const modelMeta = modelId.split('.') if (modelMeta.length < 2) { return { modelProvider: 'custom', modelName: modelId } } return { modelProvider: modelMeta.at(-2), modelName: modelMeta.at(-1) } } return { modelProvider: 'custom', modelName: modelId } } return { modelProvider: 'custom', modelName: 'custom' } } function extractRequestParams (params, provider) { const requestBody = JSON.parse(params.body) const modelId = params.modelId switch (provider.toUpperCase()) { case PROVIDER.AI21: { let userPrompt = requestBody.prompt if (modelId.includes('jamba')) { for (const message of requestBody.messages) { if (message.role === 'user') { userPrompt = message.content // Return the content of the most recent user message } } } return new RequestParams({ prompt: userPrompt, temperature: requestBody.temperature, topP: requestBody.top_p, maxTokens: requestBody.max_tokens, stopSequences: requestBody.stop_sequences, }) } case PROVIDER.AMAZON: { const prompt = requestBody.inputText if (modelId.includes('embed')) { return new RequestParams({ prompt }) } else if (prompt !== undefined) { const textGenerationConfig = requestBody.textGenerationConfig || {} return new RequestParams({ prompt, temperature: textGenerationConfig.temperature, topP: textGenerationConfig.topP, maxTokens: textGenerationConfig.maxTokenCount, stopSequences: textGenerationConfig.stopSequences, }) } else if (Array.isArray(requestBody.messages)) { const inferenceConfig = requestBody.inferenceConfig || {} const messages = [] if (Array.isArray(requestBody.system)) { for (const sysMsg of requestBody.system) { messages.push({ content: sysMsg.text, role: 'system', }) } } for (const message of requestBody.messages) { const textBlocks = message.content?.filter(block => block.text) || [] if (textBlocks.length > 0) { messages.push({ content: textBlocks.map(block => block.text).join(''), role: message.role, }) } } return new RequestParams({ prompt: messages, temperature: inferenceConfig.temperature, topP: inferenceConfig.topP, maxTokens: inferenceConfig.maxTokens, stopSequences: inferenceConfig.stopSequences, }) } return new RequestParams({ prompt }) } case PROVIDER.ANTHROPIC: { let prompt = requestBody.prompt if (Array.isArray(requestBody.messages)) { // newer claude models for (let idx = requestBody.messages.length - 1; idx >= 0; idx--) { const message = requestBody.messages[idx] if (message.role === 'user') { prompt = message.content?.filter(block => block.type === 'text') .map(block => block.text) .join('') break } } } return new RequestParams({ prompt, temperature: requestBody.temperature, topP: requestBody.top_p, maxTokens: requestBody.max_tokens_to_sample ?? requestBody.max_tokens, stopSequences: requestBody.stop_sequences, }) } case PROVIDER.COHERE: { if (modelId.includes('embed')) { return new RequestParams({ prompt: requestBody.texts, inputType: requestBody.input_type, truncate: requestBody.truncate, }) } return new RequestParams({ prompt: requestBody.prompt, temperature: requestBody.temperature, topP: requestBody.p, maxTokens: requestBody.max_tokens, stopSequences: requestBody.stop_sequences, stream: requestBody.stream, n: requestBody.num_generations, }) } case PROVIDER.META: { return new RequestParams({ prompt: requestBody.prompt, temperature: requestBody.temperature, topP: requestBody.top_p, maxTokens: requestBody.max_gen_len, }) } case PROVIDER.MISTRAL: { return new RequestParams({ prompt: requestBody.prompt, temperature: requestBody.temperature, topP: requestBody.top_p, maxTokens: requestBody.max_tokens, stopSequences: requestBody.stop, topK: requestBody.top_k, }) } case PROVIDER.STABILITY: { return new RequestParams() } default: { return new RequestParams() } } } function extractTextAndResponseReason (response, provider, modelName) { const body = JSON.parse(Buffer.from(response.body).toString('utf8')) const shouldSetChoiceIds = provider.toUpperCase() === PROVIDER.COHERE && !modelName.includes('embed') try { switch (provider.toUpperCase()) { case PROVIDER.AI21: { if (modelName.includes('jamba')) { const generations = body.choices || [] if (generations.length > 0) { const generation = generations[0] return new Generation({ message: generation.message.content, finishReason: generation.finish_reason, choiceId: shouldSetChoiceIds ? generation.id : undefined, role: generation.message.role, inputTokens: body.usage?.prompt_tokens, outputTokens: body.usage?.completion_tokens, }) } } const completions = body.completions || [] if (completions.length > 0) { const completion = completions[0] return new Generation({ message: completion.data?.text, finishReason: completion?.finishReason, choiceId: shouldSetChoiceIds ? completion?.id : undefined, inputTokens: body.usage?.prompt_tokens, outputTokens: body.usage?.completion_tokens, }) } return new Generation() } case PROVIDER.AMAZON: { if (modelName.includes('embed')) { return new Generation({ message: body.embedding }) } if (body.results) { const results = body.results || [] if (results.length > 0) { const result = results[0] return new Generation({ message: result.outputText, finishReason: result.completionReason, inputTokens: body.inputTextTokenCount, outputTokens: result.tokenCount, }) } } else if (body.output) { const output = body.output || {} return new Generation({ message: output.message?.content[0]?.text ?? 'Unsupported content type', finishReason: body.stopReason, role: output.message?.role, ...buildUsage(body.usage), }) } break } case PROVIDER.ANTHROPIC: { let message = body.completion if (Array.isArray(body.content)) { // newer claude models message = body.content.find(item => item.type === 'text')?.text ?? body.content } else if (body.content) { message = body.content } return new Generation({ message, finishReason: body.stop_reason }) } case PROVIDER.COHERE: { if (modelName.includes('embed')) { const embeddings = body.embeddings || [[]] if (embeddings.length > 0) { return new Generation({ message: embeddings[0] }) } } if (body.text) { return new Generation({ message: body.text, finishReason: body.finish_reason, choiceId: shouldSetChoiceIds ? body.response_id : undefined, }) } const generations = body.generations || [] if (generations.length > 0) { const generation = generations[0] return new Generation({ message: generation.text, finishReason: generation.finish_reason, choiceId: shouldSetChoiceIds ? generation.id : undefined, }) } break } case PROVIDER.META: { return new Generation({ message: body.generation, finishReason: body.stop_reason, inputTokens: body.prompt_token_count, outputTokens: body.generation_token_count, }) } case PROVIDER.MISTRAL: { const mistralGenerations = body.outputs || [] if (mistralGenerations.length > 0) { const generation = mistralGenerations[0] return new Generation({ message: generation.text, finishReason: generation.stop_reason }) } break } case PROVIDER.STABILITY: { return new Generation() } default: { return new Generation() } } } catch { log.warn('Unable to extract text/finishReason from response body. Defaulting to empty text/finishReason.') return new Generation() } return new Generation() } /** * Convert a Converse content-block array to an LLMObs message array. * * @param {string} role * @param {Array<object>} contentBlocks * @returns {{ content?: string, role: string, toolCalls?: Array, toolResults?: Array } | undefined} */ function extractMessagesFromConverseContent (role, contentBlocks) { let content = '' const toolCalls = [] const toolResults = [] for (const block of contentBlocks || []) { if (block == null || typeof block !== 'object') continue if (typeof block.text === 'string') { content += block.text } else if (block.toolUse) { toolCalls.push(buildToolCall(block.toolUse)) } else if (block.toolResult) { toolResults.push(buildToolResult(block.toolResult)) } else { content += `[Unsupported content type: ${getContentBlockType(block)}]` } } if (!content && toolCalls.length === 0 && toolResults.length === 0) return const message = { role } if (content) message.content = content if (toolCalls.length > 0) message.toolCalls = toolCalls if (toolResults.length > 0) message.toolResults = toolResults return message } /** * Resolve a Converse `ContentBlock`'s member type. The block is a key-presence * tagged union (no `type` discriminator), so the active member is its sole own * key. For forward-compat `$unknown` members the real type is the first element * of the `[name, value]` tuple. * * @param {object} block * @returns {string} */ function getContentBlockType (block) { const key = Object.keys(block)[0] if (key === '$unknown') return block.$unknown?.[0] ?? 'unknown' return key ?? 'unknown' } // Always emit at least one output message so downstream tagging has a role to attach to. function toOutputMessages (role, contentBlocks) { const message = extractMessagesFromConverseContent(role, contentBlocks) return message ? [message] : [{ role, content: '' }] } function buildToolCall ({ name, input, toolUseId }) { return { name: name ?? '', arguments: input ?? {}, toolId: toolUseId ?? '', type: 'toolUse' } } function parseToolInput (inputStr) { try { return JSON.parse(inputStr) } catch { log.warn('Failed to parse Converse stream toolUse.input JSON; emitting empty arguments') return {} } } function buildToolResult ({ toolUseId, content }) { const result = (content || []).map(resolveToolResultItem).join('') return { name: '', result, toolId: toolUseId ?? '', type: 'tool_result' } } function resolveToolResultItem (item) { if (typeof item.text === 'string') return item.text if (item.json != null) return JSON.stringify(item.json) return `[Unsupported content type(s): ${getContentBlockType(item)}]` } function buildUsage (usage = {}) { return { inputTokens: usage.inputTokens, outputTokens: usage.outputTokens, cacheReadTokens: usage.cacheReadInputTokens ?? usage.cacheReadInputTokenCount, cacheWriteTokens: usage.cacheWriteInputTokens ?? usage.cacheWriteInputTokenCount, } } /** * Extract tool definitions from a Converse request's `toolConfig.tools`, * mapping Bedrock's `toolSpec` shape to LLMObs `ToolDefinition` shape. * * @param {object} params - Converse request params with optional `toolConfig.tools[].toolSpec`. * @returns {Array<{ name: string, description: string, schema: object }>} */ function extractConverseToolDefinitions (params) { const toolDefinitions = [] for (const tool of params.toolConfig?.tools || []) { const toolSpec = tool?.toolSpec if (!toolSpec?.name) continue toolDefinitions.push({ name: toolSpec.name, description: toolSpec.description ?? '', schema: toolSpec.inputSchema ?? {}, }) } return toolDefinitions } /** * Extract request metadata + rendered input messages from a Converse / * ConverseStream request. * * @param {{ modelId?: string, messages?: Array, system?: Array, inferenceConfig?: object, toolConfig?: object }} params * @returns {RequestParams} */ function extractRequestParamsConverse (params) { const prompt = [] for (const block of params.system || []) { if (typeof block?.text === 'string') prompt.push({ content: block.text, role: 'system' }) } for (const msg of params.messages || []) { if (msg == null || typeof msg !== 'object') continue const message = extractMessagesFromConverseContent(msg.role || 'user', msg.content) if (message) prompt.push(message) } const { temperature, topP, maxTokens, stopSequences } = params.inferenceConfig || {} return new RequestParams({ prompt, temperature, topP, maxTokens, stopSequences }) } /** * Extract output messages + usage from a non-stream Converse response. * * @param {{ output?: { message?: { role?: string, content?: Array } }, stopReason?: string, usage?: object }} response * @returns {Generation} */ function extractTextAndResponseReasonConverse (response) { const outputMessage = response?.output?.message const role = outputMessage?.role || 'assistant' return new Generation({ role, finishReason: response?.stopReason || '', ...buildUsage(response?.usage), messages: toOutputMessages(role, outputMessage?.content), }) } /** * Aggregate Converse stream events into a single output message + usage. * One messageStart / messageStop pair per response, so one message out. * * Stream events describe the same content-block structure as the non-stream * response, spread across start/delta chunks. We reassemble those chunks * into a normalized content-block array and reuse the non-stream extractor. * * @param {Array<object>} chunks - Ordered ConverseStreamOutput events. * @returns {Generation} */ function extractTextAndResponseReasonConverseFromStream (chunks) { let role = 'assistant' let stopReason = '' let usage = {} const blocksByIdx = new Map() for (const chunk of chunks || []) { if (chunk.messageStart?.role) { role = chunk.messageStart.role } else if (chunk.messageStop?.stopReason) { stopReason = chunk.messageStop.stopReason } else if (chunk.metadata?.usage) { usage = chunk.metadata.usage } else if (chunk.contentBlockStart?.start?.toolUse) { const { contentBlockIndex, start: { toolUse } } = chunk.contentBlockStart blocksByIdx.set(contentBlockIndex, { toolUse: { toolUseId: toolUse.toolUseId, name: toolUse.name, inputStr: '' }, }) } else if (chunk.contentBlockDelta) { const { contentBlockIndex, delta } = chunk.contentBlockDelta if (typeof delta?.text === 'string') { const block = blocksByIdx.get(contentBlockIndex) ?? {} block.text = (block.text ?? '') + delta.text blocksByIdx.set(contentBlockIndex, block) } else if (typeof delta?.toolUse?.input === 'string') { const block = blocksByIdx.get(contentBlockIndex) ?? { toolUse: { inputStr: '' } } block.toolUse ??= { inputStr: '' } block.toolUse.inputStr += delta.toolUse.input blocksByIdx.set(contentBlockIndex, block) } } } const contentBlocks = [...blocksByIdx.keys()].sort((a, b) => a - b).map(i => { const block = blocksByIdx.get(i) if (block.toolUse) { const { toolUseId, name, inputStr } = block.toolUse block.toolUse = { toolUseId, name, input: parseToolInput(inputStr) } } return block }) return new Generation({ role, finishReason: stopReason, ...buildUsage(usage), messages: toOutputMessages(role, contentBlocks), }) } module.exports = { Generation, RequestParams, extractTextAndResponseReasonFromStream, parseModelId, extractRequestParams, extractTextAndResponseReason, extractMessagesFromConverseContent, extractConverseToolDefinitions, extractRequestParamsConverse, extractTextAndResponseReasonConverse, extractTextAndResponseReasonConverseFromStream, PROVIDER, }