UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.

1,046 lines 68.7 kB
import { Bedrock, CreateModelCustomizationJobCommand, GetModelCustomizationJobCommand, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
import { BedrockRuntime } from "@aws-sdk/client-bedrock-runtime";
import { S3Client } from "@aws-sdk/client-s3";
import { AbstractDriver, deserializeBinaryFromStorage, getConversationMeta, getMaxTokensLimitBedrock, getModelCapabilities, incrementConversationTurn, isClaudeVersionGTE, LlumiverseError, modelModalitiesToArray, stripBinaryFromConversation, stripHeartbeatsFromConversation, TrainingJobStatus, truncateLargeTextInConversation } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { formatNovaPrompt } from "@llumiverse/core/formatters";
import { LRUCache } from "mnemonist";
import { resolveClaudeThinking } from "../shared/claude-thinking.js";
import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
import { forceUploadFile } from "./s3.js";
import { formatTwelvelabsPegasusPrompt } from "./twelvelabs.js";

// Bounded (4096-entry) LRU cache of streaming-support lookups, keyed by model identifier.
const supportStreamingCache = new LRUCache(4096);

/**
 * Category of a Bedrock model identifier. Used to decide which Bedrock lookup
 * (foundation model, inference profile, custom model, or all of them when
 * unknown) is queried when checking streaming support.
 */
const BedrockModelType = {
    FoundationModel: "foundation-model",
    InferenceProfile: "inference-profile",
    CustomModel: "custom-model",
    Unknown: "unknown",
};

/**
 * Translate a Bedrock Converse `stopReason` into a llumiverse finish reason.
 *
 * Reasons Bedrock may report:
 *   end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
 * Only `end_turn` (-> "stop") and `max_tokens` (-> "length") are translated;
 * any other non-empty reason is returned unchanged.
 *
 * @param {string | undefined} reason - Raw stop reason from the Converse response.
 * @returns {string | undefined} The normalized reason, or `undefined` when
 *   `reason` is missing or empty.
 */
function converseFinishReason(reason) {
    if (!reason) {
        return undefined;
    }
    if (reason === 'end_turn') {
        return "stop";
    }
    if (reason === 'max_tokens') {
        return "length";
    }
    return reason;
}

//Used to get a max_token value when not specified in the model options.
Claude requires it to be set. function maxTokenFallbackClaude(option) { const modelOptions = option.model_options; if (modelOptions && typeof modelOptions.max_tokens === "number") { return modelOptions.max_tokens; } else { let maxSupportedTokens = getMaxTokensLimitBedrock(option.model) ?? 8192; // Should always return a number for claude, 8192 is to satisfy the TypeScript type checker; // Fallback to the default max tokens limit for the model if (option.model.includes('claude-3-7-sonnet') && (modelOptions?.thinking_budget_tokens ?? 0) < 48000) { maxSupportedTokens = 64000; // Claude 3.7 can go up to 128k with a beta header, but when no max tokens is specified, we default to 64k. } return maxSupportedTokens; } } export class BedrockDriver extends AbstractDriver { static PROVIDER = "bedrock"; provider = BedrockDriver.PROVIDER; _executor; _service; _service_region; constructor(options) { super(options); if (!options.region) { throw new Error("No region found. Set the region in the environment's endpoint URL."); } } getExecutor() { if (!this._executor) { this._executor = new BedrockRuntime({ region: this.options.region, credentials: this.options.credentials, }); } return this._executor; } getService(region = this.options.region) { if (!this._service || this._service_region !== region) { this._service = new Bedrock({ region: region, credentials: this.options.credentials, }); this._service_region = region; } return this._service; } async formatPrompt(segments, opts) { if (opts.model.includes("canvas")) { return await formatNovaPrompt(segments, opts.result_schema); } if (opts.model.includes("twelvelabs.pegasus")) { return await formatTwelvelabsPegasusPrompt(segments, opts); } return await formatConversePrompt(segments, opts); } /** * Format AWS Bedrock errors into LlumiverseError with proper status codes and retryability. 
* * AWS SDK errors provide: * - error.name: The exception type (e.g., "ThrottlingException") * - error.$metadata.httpStatusCode: The HTTP status code * - error.$metadata.requestId: The AWS request ID for tracking * - error.$fault: "client" or "server" indicating error category * * @param error - The AWS SDK error * @param context - Context about where the error occurred * @returns A standardized LlumiverseError */ formatLlumiverseError(error, context) { // Check if it's an AWS SDK error with $metadata const awsError = error; const hasMetadata = awsError?.$metadata !== undefined; if (!hasMetadata) { // Not an AWS SDK error, use default handling return super.formatLlumiverseError(error, context); } // Extract AWS-specific fields const errorName = awsError.name || 'UnknownError'; const httpStatusCode = awsError.$metadata?.httpStatusCode; const requestId = awsError.$metadata?.requestId; const fault = awsError.$fault; // "client" or "server" // Extract error message - handle both Error instances and plain objects let message; if (error instanceof Error) { message = error.message; } else if (typeof awsError.message === 'string') { message = awsError.message; } else { message = String(error); } // Build user-facing message with error name and status code let userMessage = message; // Include status code in message if available (for end-user visibility) if (httpStatusCode) { userMessage = `[${httpStatusCode}] ${userMessage}`; } // Prefix with error name if it's meaningful (not just "Error") if (errorName && errorName !== 'Error' && errorName !== 'UnknownError') { userMessage = `${errorName}: ${userMessage}`; } // Add request ID if available (useful for AWS support) if (requestId) { userMessage += ` (Request ID: ${requestId})`; } // Determine retryability based on AWS error types const retryable = this.isBedrockErrorRetryable(errorName, httpStatusCode, fault); return new LlumiverseError(`[${this.provider}] ${userMessage}`, retryable, context, error, httpStatusCode, // Only 
set code if we have numeric status code errorName // Preserve AWS error name ); } /** * Determine if a Bedrock error is retryable based on error type and status. * * Retryable errors: * - ThrottlingException: Rate limit exceeded, retry with backoff * - ServiceUnavailableException: Service temporarily down * - InternalServerException: Server-side error * - ServiceQuotaExceededException: Quota exhausted, may recover * - 5xx status codes: Server errors * - 429, 408 status codes: Rate limit, timeout * * Non-retryable errors: * - ValidationException: Invalid request parameters * - AccessDeniedException: Authentication/authorization failure * - ResourceNotFoundException: Resource doesn't exist * - ConflictException: Resource state conflict * - ResourceInUseException: Resource locked by another operation * - 4xx status codes (except 429, 408): Client errors * * @param errorName - The AWS error name (e.g., "ThrottlingException") * @param httpStatusCode - The HTTP status code if available * @param fault - The fault type ("client" or "server") * @returns True if retryable, false if not retryable, undefined if unknown */ isBedrockErrorRetryable(errorName, httpStatusCode, fault) { // Check specific AWS error types first switch (errorName) { // Retryable errors case 'ThrottlingException': case 'ServiceUnavailableException': case 'InternalServerException': case 'ServiceQuotaExceededException': return true; // Non-retryable errors case 'ValidationException': case 'AccessDeniedException': case 'ResourceNotFoundException': case 'ConflictException': case 'ResourceInUseException': case 'TooManyTagsException': return false; } // If we have HTTP status code, use it if (httpStatusCode !== undefined) { if (httpStatusCode === 429 || httpStatusCode === 408) return true; // Rate limit, timeout if (httpStatusCode === 529) return true; // Overloaded if (httpStatusCode >= 500 && httpStatusCode < 600) return true; // Server errors if (httpStatusCode >= 400 && httpStatusCode < 500) return false; 
// Client errors } // Fall back to fault type if (fault === 'server') return true; if (fault === 'client') return false; // Unknown error type - let consumer decide retry strategy return undefined; } getExtractedExecution(result, _prompt, options) { let resultText = ""; let reasoning = ""; if (result.output?.message?.content) { for (const content of result.output.message.content) { // Get text output if (content.text) { resultText += content.text; } else if (content.reasoningContent) { // Extract reasoning content if include_thoughts is true, or if it's a // reasoning-only model (e.g. DeepSeek R1) that returns no text blocks const claudeOptions = options?.model_options; const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1'); if (claudeOptions?.include_thoughts || isReasoningModel) { if (content.reasoningContent.reasoningText) { reasoning += content.reasoningContent.reasoningText.text; } else if (content.reasoningContent.redactedContent) { // Handle redacted thinking content const redactedData = new TextDecoder().decode(content.reasoningContent.redactedContent); reasoning += `[Redacted thinking: ${redactedData}]`; } } else { this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false"); } } else { // Get content block type const type = Object.keys(content).find(key => key !== '$unknown' && content[key] !== undefined); this.logger.info({ type }, "[Bedrock] Unsupported content response type:"); } } // Add spacing if we have reasoning content if (reasoning) { reasoning += '\n\n'; } } const completionResult = { result: reasoning + resultText ? [{ type: "text", value: reasoning + resultText }] : [], token_usage: { // Bedrock's inputTokens already excludes cache-read tokens, // so prompt_new is inputTokens directly (no subtraction needed). // prompt is the total including cached + cache_write for consistency // with the Vertex Claude driver. prompt_new: result.usage?.inputTokens, prompt: result.usage ? 
(result.usage.inputTokens ?? 0) + (result.usage.cacheReadInputTokens ?? 0) + (result.usage.cacheWriteInputTokens ?? 0) : undefined, result: result.usage?.outputTokens, total: result.usage?.totalTokens, prompt_cached: result.usage?.cacheReadInputTokens ?? undefined, prompt_cache_write: result.usage?.cacheWriteInputTokens ?? undefined, }, finish_reason: converseFinishReason(result.stopReason), }; return completionResult; } ; getExtractedStream(result, _prompt, options, streamingToolBlocks) { let output = ""; let reasoning = ""; let stop_reason = ""; let token_usage; let tool_use; // Check if we should include thoughts (always true for reasoning-only models like DeepSeek R1) const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1'); const shouldIncludeThoughts = isReasoningModel || (options && options.model_options?.include_thoughts); // Handle content block start events (for reasoning blocks and tool use) if (result.contentBlockStart) { if (result.contentBlockStart.start && 'toolUse' in result.contentBlockStart.start && result.contentBlockStart.start.toolUse) { // Register new tool call block and emit an initial chunk so the accumulator can track it by id const toolUseStart = result.contentBlockStart.start.toolUse; const blockIndex = result.contentBlockStart.contentBlockIndex ?? -1; const id = toolUseStart.toolUseId ?? ''; const name = toolUseStart.name ?? 
''; streamingToolBlocks?.set(blockIndex, { id, name }); tool_use = [{ id, tool_name: name, tool_input: '' }]; } else if (result.contentBlockStart.start && 'reasoningContent' in result.contentBlockStart.start && shouldIncludeThoughts) { // Handle redacted content at block start const reasoningStart = result.contentBlockStart.start; if (reasoningStart.reasoningContent?.redactedContent) { const redactedData = new TextDecoder().decode(reasoningStart.reasoningContent.redactedContent); reasoning = `[Redacted thinking: ${redactedData}]`; } } } // Handle content block deltas (text, reasoning, and tool use) if (result.contentBlockDelta) { const delta = result.contentBlockDelta.delta; if (delta?.toolUse) { // Emit tool input chunk; the accumulator in DefaultCompletionStream concatenates these strings const blockIndex = result.contentBlockDelta.contentBlockIndex ?? -1; const toolBlock = streamingToolBlocks?.get(blockIndex); if (toolBlock && delta.toolUse.input !== undefined) { tool_use = [{ id: toolBlock.id, tool_name: '', tool_input: delta.toolUse.input }]; } } else if (delta?.text) { output = delta.text; } else if (delta?.reasoningContent && shouldIncludeThoughts) { if (delta.reasoningContent.text) { reasoning = delta.reasoningContent.text; } else if (delta.reasoningContent.redactedContent) { const redactedData = new TextDecoder().decode(delta.reasoningContent.redactedContent); reasoning = `[Redacted thinking: ${redactedData}]`; } else if (delta.reasoningContent.signature) { // Handle signature updates for reasoning content - end of thinking reasoning = "\n\n"; // Putting logging here so it only triggers once. 
this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false"); } } else if (delta) { // Get content block type const type = Object.keys(delta).find(key => key !== '$unknown' && delta[key] !== undefined); this.logger.info({ type }, "[Bedrock] Unsupported content response type:"); } } // Handle content block stop events if (result.contentBlockStop) { // Clean up tool block tracking entry const blockIndex = result.contentBlockStop.contentBlockIndex ?? -1; streamingToolBlocks?.delete(blockIndex); // Add minimal spacing for reasoning blocks if not already present if (reasoning && !reasoning.endsWith('\n\n') && shouldIncludeThoughts) { reasoning += '\n\n'; } } if (result.messageStop) { stop_reason = result.messageStop.stopReason ?? ""; } if (result.metadata) { token_usage = { prompt_new: result.metadata.usage?.inputTokens, prompt: result.metadata.usage ? (result.metadata.usage.inputTokens ?? 0) + (result.metadata.usage.cacheReadInputTokens ?? 0) + (result.metadata.usage.cacheWriteInputTokens ?? 0) : undefined, result: result.metadata.usage?.outputTokens, total: result.metadata.usage?.totalTokens, prompt_cached: result.metadata.usage?.cacheReadInputTokens ?? undefined, prompt_cache_write: result.metadata.usage?.cacheWriteInputTokens ?? undefined, }; } const completionResult = { result: reasoning + output ? 
[{ type: "text", value: reasoning + output }] : [], token_usage: token_usage, finish_reason: converseFinishReason(stop_reason), tool_use, }; return completionResult; } ; extractRegion(modelString, defaultRegion) { // Match region in full ARN pattern const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/); if (arnMatch) { return arnMatch[1]; } // Match common AWS regions directly in string const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/); if (regionMatch) { return regionMatch[0]; } return defaultRegion; } async getCanStream(model, type) { let canStream = false; let error = null; const region = this.extractRegion(model, this.options.region); if (type === BedrockModelType.FoundationModel || type === BedrockModelType.Unknown) { try { const response = await this.getService(region).getFoundationModel({ modelIdentifier: model }); canStream = response.modelDetails?.responseStreamingSupported ?? false; return canStream; } catch (e) { error = e; } } if (type === BedrockModelType.InferenceProfile || type === BedrockModelType.Unknown) { try { const response = await this.getService(region).getInferenceProfile({ inferenceProfileIdentifier: model }); canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel); return canStream; } catch (e) { error = e; } } if (type === BedrockModelType.CustomModel || type === BedrockModelType.Unknown) { try { const response = await this.getService(region).getCustomModel({ modelIdentifier: model }); canStream = await this.getCanStream(response.baseModelArn ?? 
"", BedrockModelType.FoundationModel); return canStream; } catch (e) { error = e; } } if (error) { console.warn("Error on canStream check for model: " + model + " region detected: " + region, error); } return canStream; } async canStream(options) { // // TwelveLabs Pegasus supports streaming according to the documentation // if (options.model.includes("twelvelabs.pegasus")) { // return true; // } let canStream = supportStreamingCache.get(options.model); if (canStream == null) { let type = BedrockModelType.Unknown; if (options.model.includes("foundation-model")) { type = BedrockModelType.FoundationModel; } else if (options.model.includes("inference-profile")) { type = BedrockModelType.InferenceProfile; } else if (options.model.includes("custom-model")) { type = BedrockModelType.CustomModel; } canStream = await this.getCanStream(options.model, type); supportStreamingCache.set(options.model, canStream); } return canStream; } /** * Build conversation context after streaming completion. * Reconstructs the assistant message from accumulated results and applies stripping. */ buildStreamingConversation(prompt, result, toolUse, options) { // Only handle ConverseRequest prompts (not NovaMessagesPrompt or TwelvelabsPegasusRequest) if (options.model.includes("canvas") || options.model.includes("twelvelabs.pegasus")) { return undefined; } const conversePrompt = prompt; const completionResults = result; // Convert accumulated results to text content for assistant message const textContent = completionResults .map(r => { switch (r.type) { case 'text': return r.value; case 'json': return typeof r.value === 'string' ? 
r.value : JSON.stringify(r.value); case 'image': // Skip images in conversation - they're in the result return ''; default: return String(r.value || ''); } }) .join(''); // Deserialize any base64-encoded binary data back to Uint8Array const incomingConversation = deserializeBinaryFromStorage(options.conversation); // Start with the conversation from options combined with the prompt let conversation = updateConversation(incomingConversation, conversePrompt); // Build assistant message content const messageContent = []; if (textContent) { messageContent.push({ text: textContent }); } // Add tool use blocks if present if (toolUse && toolUse.length > 0) { for (const tool of toolUse) { messageContent.push({ toolUse: { toolUseId: tool.id, name: tool.tool_name, input: tool.tool_input, } }); } } // Add assistant message const assistantMessage = { messages: [{ content: messageContent.length > 0 ? messageContent : [{ text: '' }], role: "assistant" }], modelId: conversePrompt.modelId, }; conversation = updateConversation(conversation, assistantMessage); // Increment turn counter conversation = incrementConversationTurn(conversation); // Apply stripping based on options const currentTurn = getConversationMeta(conversation).turnNumber; const stripOptions = { keepForTurns: options.stripImagesAfterTurns ?? Infinity, currentTurn, textMaxTokens: options.stripTextMaxTokens }; let processedConversation = stripBinaryFromConversation(conversation, stripOptions); processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions); processedConversation = stripHeartbeatsFromConversation(processedConversation, { keepForTurns: options.stripHeartbeatsAfterTurns ?? 
1, currentTurn, }); return processedConversation; } async requestTextCompletion(prompt, options) { // Handle Twelvelabs Pegasus models if (options.model.includes("twelvelabs.pegasus")) { return this.requestTwelvelabsPegasusCompletion(prompt, options); } // Handle other Bedrock models that use Converse API const conversePrompt = prompt; // Deserialize any base64-encoded binary data back to Uint8Array before API call const incomingConversation = deserializeBinaryFromStorage(options.conversation); let conversation = updateConversation(incomingConversation, conversePrompt); const payload = this.preparePayload(conversation, options); const executor = this.getExecutor(); const res = await executor.converse({ ...payload, }); // Strip reasoningContent from assistant messages before storing in conversation // (DeepSeek R1 returns reasoning blocks but rejects them in subsequent user turns) const assistantMsg = res.output?.message ?? { content: [{ text: "" }], role: "assistant" }; if (assistantMsg.content) { assistantMsg.content = assistantMsg.content.filter((c) => !c.reasoningContent); } conversation = updateConversation(conversation, { messages: [assistantMsg], modelId: conversePrompt.modelId, }); // Increment turn counter for deferred stripping conversation = incrementConversationTurn(conversation); let tool_use = undefined; //Get tool requests, we check tool use regardless of finish reason, as you can hit length and still get a valid response. tool_use = res.output?.message?.content?.reduce((tools, c) => { if (c.toolUse) { tools.push({ tool_name: c.toolUse.name ?? "", tool_input: c.toolUse.input, id: c.toolUse.toolUseId ?? "", }); } return tools; }, []); //If no tools were used, set to undefined if (tool_use && tool_use.length === 0) { tool_use = undefined; } // Strip/serialize binary data based on options.stripImagesAfterTurns const currentTurn = getConversationMeta(conversation).turnNumber; const stripOptions = { keepForTurns: options.stripImagesAfterTurns ?? 
Infinity, currentTurn, textMaxTokens: options.stripTextMaxTokens }; let processedConversation = stripBinaryFromConversation(conversation, stripOptions); // Truncate large text content if configured processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions); // Strip old heartbeat status messages processedConversation = stripHeartbeatsFromConversation(processedConversation, { keepForTurns: options.stripHeartbeatsAfterTurns ?? 1, currentTurn, }); const completion = { ...this.getExtractedExecution(res, conversePrompt, options), original_response: options.include_original_response ? res : undefined, conversation: processedConversation, tool_use: tool_use, }; return completion; } async requestTwelvelabsPegasusCompletion(prompt, options) { const executor = this.getExecutor(); const res = await executor.invokeModel({ modelId: options.model, contentType: "application/json", accept: "application/json", body: JSON.stringify(prompt), }); const decoder = new TextDecoder(); const body = decoder.decode(res.body); const result = JSON.parse(body); // Extract the response according to TwelveLabs Pegasus format let finishReason; switch (result.finishReason) { case "stop": finishReason = "stop"; break; case "length": finishReason = "length"; break; default: finishReason = result.finishReason; } return { result: result.message ? [{ type: "text", value: result.message }] : [], finish_reason: finishReason, original_response: options.include_original_response ? 
result : undefined, }; } async requestTwelvelabsPegasusCompletionStream(prompt, options) { const executor = this.getExecutor(); const res = await executor.invokeModelWithResponseStream({ modelId: options.model, contentType: "application/json", accept: "application/json", body: JSON.stringify(prompt), }); if (!res.body) { throw new Error("[Bedrock] Stream not found in response"); } return transformAsyncIterator(res.body, (chunk) => { if (chunk.chunk?.bytes) { const decoder = new TextDecoder(); const body = decoder.decode(chunk.chunk.bytes); try { const result = JSON.parse(body); // Extract streaming response according to TwelveLabs Pegasus format let finishReason; if (result.finishReason) { switch (result.finishReason) { case "stop": finishReason = "stop"; break; case "length": finishReason = "length"; break; default: finishReason = result.finishReason; } } return { result: result.delta || result.message ? [{ type: "text", value: result.delta || result.message || "" }] : [], finish_reason: finishReason, }; } catch (error) { // If JSON parsing fails, return empty chunk return { result: [], }; } } return { result: [], }; }); } async requestTextCompletionStream(prompt, options) { // Handle Twelvelabs Pegasus models if (options.model.includes("twelvelabs.pegasus")) { return this.requestTwelvelabsPegasusCompletionStream(prompt, options); } // Handle other Bedrock models that use Converse API const conversePrompt = prompt; // Include conversation history (same as non-streaming) // Deserialize any base64-encoded binary data back to Uint8Array before API call const incomingConversation = deserializeBinaryFromStorage(options.conversation); const conversation = updateConversation(incomingConversation, conversePrompt); const payload = this.preparePayload(conversation, options); const executor = this.getExecutor(); return executor.converseStream({ ...payload, }).then((res) => { const stream = res.stream; if (!stream) { throw new Error("[Bedrock] Stream not found in response"); 
} const streamingToolBlocks = new Map(); return transformAsyncIterator(stream, (streamSegment) => { return this.getExtractedStream(streamSegment, conversePrompt, options, streamingToolBlocks); }); }).catch((err) => { this.logger.error({ error: err }, "[Bedrock] Failed to stream"); throw err; }); } preparePayload(prompt, options) { const model_options = options.model_options ?? { _option_id: "text-fallback" }; let additionalField = {}; let supportsJSONPrefill = false; // Resolve thinking, effort, and sampling restrictions using shared Claude helper const claudeThinking = resolveClaudeThinking(options.model, options.model_options); const hasSamplingRestriction = claudeThinking.hasSamplingRestriction; if (options.model.includes("amazon")) { supportsJSONPrefill = true; //Titan models also exists but does not support any additional options if (options.model.includes("nova")) { additionalField = { inferenceConfig: { topK: model_options.top_k } }; } } else if (options.model.includes("claude")) { const claude_options = model_options; // Thinking is active when extended (budget set) or adaptive (effort set) thinking is enabled. // JSON prefill is incompatible with active thinking. const thinkingActive = claudeThinking.thinking != null && claudeThinking.thinking.type !== "disabled"; supportsJSONPrefill = !thinkingActive; // Claude 3.7+ supports thinking — use shared helper for reasoning_config if (claudeThinking.supportsThinking) { if (claudeThinking.thinking) { additionalField = { ...additionalField, reasoning_config: claudeThinking.thinking, }; } // For Claude 3.7 with extended thinking + high output, add beta header if (claudeThinking.thinking?.type === "enabled" && options.model.includes("claude-3-7-sonnet") && ((claude_options.max_tokens ?? 0) > 64000 || (claude_options.thinking_budget_tokens ?? 
0) > 64000)) { additionalField = { ...additionalField, anthropic_beta: ["output-128k-2025-02-19"] }; } } // Add effort parameter via output_config (Opus 4.5+, Sonnet 4.6+, all 4.7+) if (claudeThinking.outputConfig) { additionalField = { ...additionalField, output_config: claudeThinking.outputConfig }; } // Claude 4.6 and later versions don't support JSON prefill if (isClaudeVersionGTE(options.model, 4, 6)) { supportsJSONPrefill = false; } // Needs max_tokens to be set if (!model_options.max_tokens) { model_options.max_tokens = maxTokenFallbackClaude(options); } // Only models without sampling restrictions support top_k if (!hasSamplingRestriction) { additionalField = { ...additionalField, top_k: model_options.top_k }; } } else if (options.model.includes("meta")) { //LLaMA models support no additional options } else if (options.model.includes("mistral")) { //7B instruct and 8x7B instruct if (options.model.includes("7b")) { additionalField = { top_k: model_options.top_k }; //Does not support system messages if (prompt.system && prompt.system?.length !== 0) { prompt.messages?.push(converseSystemToMessages(prompt.system)); prompt.system = undefined; prompt.messages = converseConcatMessages(prompt.messages); } } else { //Other models such as Mistral Small,Large and Large 2 //Support no additional fields. } } else if (options.model.includes("ai21")) { //Jamba models support no additional options //Jurassic 2 models do. if (options.model.includes("j2")) { additionalField = { presencePenalty: { scale: model_options.presence_penalty }, frequencyPenalty: { scale: model_options.frequency_penalty }, }; //Does not support system messages if (prompt.system && prompt.system?.length !== 0) { prompt.messages?.push(converseSystemToMessages(prompt.system)); prompt.system = undefined; prompt.messages = converseConcatMessages(prompt.messages); } } } else if (options.model.includes("cohere.command")) { // If last message is "```json", remove it. 
//Command R and R plus if (options.model.includes("cohere.command-r")) { additionalField = { k: model_options.top_k, frequency_penalty: model_options.frequency_penalty, presence_penalty: model_options.presence_penalty, }; } else { // Command non-R additionalField = { k: model_options.top_k }; //Does not support system messages if (prompt.system && prompt.system?.length !== 0) { prompt.messages?.push(converseSystemToMessages(prompt.system)); prompt.system = undefined; prompt.messages = converseConcatMessages(prompt.messages); } } } else if (options.model.includes("palmyra")) { const palmyraOptions = model_options; additionalField = { seed: palmyraOptions?.seed, presence_penalty: palmyraOptions?.presence_penalty, frequency_penalty: palmyraOptions?.frequency_penalty, min_tokens: palmyraOptions?.min_tokens, }; } else if (options.model.includes("deepseek")) { // DeepSeek models: no additional options, no stopSequences, only one of temperature/top_p model_options.stop_sequence = undefined; model_options.top_p = undefined; } else if (options.model.includes("gpt-oss")) { const gptOssOptions = model_options; additionalField = { reasoning_effort: gptOssOptions?.reasoning_effort, }; } //If last message is "```json", add corresponding ``` as a stop sequence. 
if (prompt.messages && prompt.messages.length > 0) { if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") { const stopSeq = model_options.stop_sequence; if (!stopSeq) { model_options.stop_sequence = ["```"]; } else if (!stopSeq.includes("```")) { stopSeq.push("```"); model_options.stop_sequence = stopSeq; } } } const tool_defs = getToolDefinitions(options.tools); // Use prefill when there is a schema and tools are not being used if (supportsJSONPrefill && options.result_schema && !tool_defs) { prompt.messages = converseJSONprefill(prompt.messages); } // Clean undefined values from additionalField since AWS Bedrock requires valid JSON // and will throw an exception for unrecognized parameters const cleanedAdditionalFields = removeUndefinedValues(additionalField); // Models with sampling parameter restrictions don't support temperature/top_p - exclude them from inference config const cleanedModelOptions = removeUndefinedValues({ maxTokens: model_options.max_tokens, ...(hasSamplingRestriction ? {} : { temperature: model_options.temperature, topP: model_options.temperature != null ? undefined : model_options.top_p, }), stopSequences: model_options.stop_sequence, }); //Construct the final request payload // We only add fields that are defined to avoid AWS errors const request = { modelId: options.model, }; if (prompt.messages) { request.messages = prompt.messages; } if (prompt.system) { request.system = prompt.system; } if (Object.keys(cleanedModelOptions).length > 0) { request.inferenceConfig = cleanedModelOptions; } if (Object.keys(cleanedAdditionalFields).length > 0) { request.additionalModelRequestFields = cleanedAdditionalFields; } if (tool_defs?.length) { request.toolConfig = { tools: tool_defs, }; } else if (request.messages && messagesContainToolBlocks(request.messages)) { // Bedrock requires toolConfig when conversation contains toolUse/toolResult blocks. // When no tools are provided (e.g. 
checkpoint summary calls), convert tool blocks // to text representations so the conversation data is preserved while satisfying // Bedrock's API requirements without making tools callable. request.messages = convertToolBlocksToText(request.messages); } // Prompt caching: use three breakpoints so stable system blocks, tool definitions, // and the conversation history prefix can all be reused across Claude turns. if (options.model.includes('claude')) { // Always strip stale markers from prior turns if (request.messages) { request.messages = stripClaudeCachePoints(request.messages); } request.system = stripClaudeCachePointsFromSystem(request.system); if (request.toolConfig?.tools) { request.toolConfig = { ...request.toolConfig, tools: stripClaudeCachePointsFromTools(request.toolConfig.tools), }; } const claudeOptions = model_options; const cacheEnabled = claudeOptions?.cache_enabled === true; if (cacheEnabled) { const cacheTtl = claudeOptions?.cache_ttl; const cachePointBlock = { type: 'default', ...(cacheTtl && { ttl: cacheTtl }) }; if (request.system && request.system.length > 0) { request.system = [...request.system, { cachePoint: cachePointBlock }]; } if (request.toolConfig?.tools && request.toolConfig.tools.length > 0) { request.toolConfig.tools = [ ...request.toolConfig.tools, { cachePoint: cachePointBlock }, ]; } if (request.messages && request.messages.length >= 4) { const pivotMsg = request.messages[request.messages.length - 2]; if (pivotMsg.content && Array.isArray(pivotMsg.content) && pivotMsg.content.length > 0) { pivotMsg.content = [...pivotMsg.content, { cachePoint: cachePointBlock }]; } } } } return request; } isImageModel(model) { return model.includes("titan-image") || model.includes("stable-diffusion") || model.includes("nova-canvas"); } async requestImageGeneration(prompt, options) { if (options.model_options?._option_id !== undefined && options.model_options?._option_id !== "bedrock-nova-canvas") { this.logger.debug({ options: 
options.model_options }, "Unexpected option id"); } const model_options = options.model_options; const executor = this.getExecutor(); const taskType = model_options.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE; this.logger.info("Task type: " + taskType); if (typeof prompt === "string") { throw new Error("Bad prompt format"); } const payload = await formatNovaImageGenerationPayload(taskType, prompt, options); const res = await executor.invokeModel({ modelId: options.model, contentType: "application/json", accept: "application/json", body: JSON.stringify(payload), }, { requestTimeout: 60000 * 5 }); const decoder = new TextDecoder(); const body = decoder.decode(res.body); const bedrockResult = JSON.parse(body); return { error: bedrockResult.error, result: bedrockResult.images.map((image) => ({ type: "image", value: image })) }; } async startTraining(dataset, options) { //convert options.params to Record<string, string> const params = {}; for (const [key, value] of Object.entries(options.params || {})) { params[key] = String(value); } if (!this.options.training_bucket) { throw new Error("Training cannot nbe used since the 'training_bucket' property was not specified in driver options"); } const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials }); const stream = await dataset.getStream(); const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name); const service = this.getService(); const response = await service.send(new CreateModelCustomizationJobCommand({ jobName: options.name + "-job", customModelName: options.name, roleArn: this.options.training_role_arn || undefined, baseModelIdentifier: options.model, clientRequestToken: "llumiverse-" + Date.now(), trainingDataConfig: { s3Uri: `s3://${upload.Bucket}/${upload.Key}`, }, outputDataConfig: undefined, hyperParameters: params, //TODO not supported? 
//customizationType: "FINE_TUNING", })); const job = await service.send(new GetModelCustomizationJobCommand({ jobIdentifier: response.jobArn })); return jobInfo(job, response.jobArn); } async cancelTraining(jobId) { const service = this.getService(); await service.send(new StopModelCustomizationJobCommand({ jobIdentifier: jobId })); const job = await service.send(new GetModelCustomizationJobCommand({ jobIdentifier: jobId })); return jobInfo(job, jobId); } async getTrainingJob(jobId) { const service = this.getService(); const job = await service.send(new GetModelCustomizationJobCommand({ jobIdentifier: jobId })); return jobInfo(job, jobId); } // ===================== management API ================== async validateConnection() { const service = this.getService(); this.logger.debug("[Bedrock] validating connection", service.config.credentials.name); //return true as if the client has been initialized, it means the connection is valid return true; } async listTrainableModels() { this.logger.debug("[Bedrock] listing trainable models"); return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false); } async listModels() { this.logger.debug("[Bedrock] listing models"); // exclude trainable models since they are not executable // exclude embedding models, not to be used for typical completions. const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false; return this._listModels(filter); } async _listModels(foundationFilter) { const service = this.getService(); const [foundationModelsList, customModelsList, inferenceProfilesList] = await Promise.all([ service.listFoundationModels({}).catch(() => { this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions."); return undefined; }), service.listCustomModels({}).catch(() => { this.logger.warn("[Bedrock] Can't list custom models. 
Check if the user has the right permissions."); return undefined; }), service.listInferenceProfiles({}).catch(() => { this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions."); return undefined; }), ]); if (!foundationModelsList?.modelSummaries) { throw new Error("Foundation models not found"); } let foundationModels = foundationModelsList.modelSummaries || []; if (foundationFilter) { foundationModels = foundationModels.filter(foundationFilter); } const supportedPublishers = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek", "writer", "openai", "twelvelabs", "qwen"]; const unsupportedModelsByPublisher = { amazon: ["titan-image-generator", "nova-reel", "nova-sonic", "rerank"], anthr