@llumiverse/drivers
Version:
LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.
1,046 lines • 68.7 kB
JavaScript
import { Bedrock, CreateModelCustomizationJobCommand, GetModelCustomizationJobCommand, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
import { BedrockRuntime } from "@aws-sdk/client-bedrock-runtime";
import { S3Client } from "@aws-sdk/client-s3";
import { AbstractDriver, deserializeBinaryFromStorage, getConversationMeta, getMaxTokensLimitBedrock, getModelCapabilities, incrementConversationTurn, isClaudeVersionGTE, LlumiverseError, modelModalitiesToArray, stripBinaryFromConversation, stripHeartbeatsFromConversation, TrainingJobStatus, truncateLargeTextInConversation } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { formatNovaPrompt } from "@llumiverse/core/formatters";
import { LRUCache } from "mnemonist";
import { resolveClaudeThinking } from "../shared/claude-thinking.js";
import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
import { forceUploadFile } from "./s3.js";
import { formatTwelvelabsPegasusPrompt } from "./twelvelabs.js";
// Memo of "does this model support streaming?", keyed by the model string passed to canStream(),
// so the Bedrock control-plane lookups in getCanStream() are not repeated for a cached model.
const supportStreamingCache = new LRUCache(4096);
/**
 * Kind of Bedrock model identifier. Used to decide which control-plane lookup
 * (foundation model, inference profile, custom model) getCanStream() attempts;
 * `Unknown` means "try each lookup in turn".
 */
const BedrockModelType = {
    FoundationModel: "foundation-model",
    InferenceProfile: "inference-profile",
    CustomModel: "custom-model",
    Unknown: "unknown",
};
/**
 * Map a Bedrock Converse `stopReason` onto llumiverse's finish reasons.
 *
 * "end_turn" becomes "stop" and "max_tokens" becomes "length". Every other reason
 * (tool_use, stop_sequence, guardrail_intervened, content_filtered, ...) is passed through as-is.
 *
 * @param reason - Raw stop reason from Bedrock; may be empty or undefined.
 * @returns The normalized finish reason, or undefined when no reason is given.
 */
function converseFinishReason(reason) {
    if (!reason) {
        return undefined;
    }
    if (reason === 'end_turn') {
        return "stop";
    }
    if (reason === 'max_tokens') {
        return "length";
    }
    return reason;
}
/**
 * Pick a max_tokens value for Claude, which requires one to be set.
 *
 * An explicit numeric `model_options.max_tokens` is returned unchanged. Otherwise the model's
 * Bedrock limit from getMaxTokensLimitBedrock() is used (8192 if it returns nothing).
 * The exception is Claude 3.7 Sonnet with a thinking budget below 48000: it gets 64000. Claude 3.7
 * can reach 128k only with a beta header, so 64k is the default when nothing was requested.
 *
 * @param option - Execution options (`model`, optional `model_options`).
 * @returns The max_tokens value to send.
 */
function maxTokenFallbackClaude(option) {
    const modelOptions = option.model_options;
    if (typeof modelOptions?.max_tokens === "number") {
        return modelOptions.max_tokens;
    }
    // The 8192 default only satisfies the type checker; a Claude model should always have a limit.
    const modelLimit = getMaxTokensLimitBedrock(option.model) ?? 8192;
    const isSonnet37 = option.model.includes('claude-3-7-sonnet');
    if (isSonnet37 && (modelOptions?.thinking_budget_tokens ?? 0) < 48000) {
        return 64000;
    }
    return modelLimit;
}
export class BedrockDriver extends AbstractDriver {
static PROVIDER = "bedrock";
provider = BedrockDriver.PROVIDER;
_executor;
_service;
_service_region;
constructor(options) {
super(options);
if (!options.region) {
throw new Error("No region found. Set the region in the environment's endpoint URL.");
}
}
getExecutor() {
if (!this._executor) {
this._executor = new BedrockRuntime({
region: this.options.region,
credentials: this.options.credentials,
});
}
return this._executor;
}
getService(region = this.options.region) {
if (!this._service || this._service_region !== region) {
this._service = new Bedrock({
region: region,
credentials: this.options.credentials,
});
this._service_region = region;
}
return this._service;
}
async formatPrompt(segments, opts) {
if (opts.model.includes("canvas")) {
return await formatNovaPrompt(segments, opts.result_schema);
}
if (opts.model.includes("twelvelabs.pegasus")) {
return await formatTwelvelabsPegasusPrompt(segments, opts);
}
return await formatConversePrompt(segments, opts);
}
/**
 * Convert an AWS Bedrock SDK error into a LlumiverseError.
 * The result carries the HTTP status, the AWS error name and a retryability hint.
 *
 * AWS SDK errors expose:
 * - `name`: exception type (e.g. "ThrottlingException")
 * - `$metadata.httpStatusCode`: HTTP status code
 * - `$metadata.requestId`: AWS request ID (useful for AWS support)
 * - `$fault`: "client" or "server"
 *
 * Values without `$metadata` are handed to the base implementation.
 *
 * @param error - The thrown value, usually an AWS SDK error.
 * @param context - Context about where the error occurred.
 * @returns A LlumiverseError whose message reads
 *   "[bedrock] <Name>: [<status>] <message> (Request ID: <id>)". Each optional part appears
 *   only when known, and the name is omitted when it is the generic "Error".
 */
formatLlumiverseError(error, context) {
    const awsError = error;
    if (awsError?.$metadata === undefined) {
        // Not an AWS SDK error: generic handling.
        return super.formatLlumiverseError(error, context);
    }
    const errorName = awsError.name || 'UnknownError';
    const { httpStatusCode, requestId } = awsError.$metadata ?? {};
    const fault = awsError.$fault;
    // Error instances and plain objects carrying a string `message` are both supported.
    const rawMessage = error instanceof Error
        ? error.message
        : (typeof awsError.message === 'string' ? awsError.message : String(error));
    const isMeaningfulName = errorName !== 'Error' && errorName !== 'UnknownError';
    const namePrefix = isMeaningfulName ? `${errorName}: ` : '';
    const statusPrefix = httpStatusCode ? `[${httpStatusCode}] ` : '';
    const requestSuffix = requestId ? ` (Request ID: ${requestId})` : '';
    const userMessage = `${namePrefix}${statusPrefix}${rawMessage}${requestSuffix}`;
    const retryable = this.isBedrockErrorRetryable(errorName, httpStatusCode, fault);
    // httpStatusCode becomes the numeric code; errorName keeps the AWS exception type.
    return new LlumiverseError(`[${this.provider}] ${userMessage}`, retryable, context, error, httpStatusCode, errorName);
}
/**
* Determine if a Bedrock error is retryable based on error type and status.
*
* Retryable errors:
* - ThrottlingException: Rate limit exceeded, retry with backoff
* - ServiceUnavailableException: Service temporarily down
* - InternalServerException: Server-side error
* - ServiceQuotaExceededException: Quota exhausted, may recover
* - 5xx status codes: Server errors
* - 429, 408 status codes: Rate limit, timeout
*
* Non-retryable errors:
* - ValidationException: Invalid request parameters
* - AccessDeniedException: Authentication/authorization failure
* - ResourceNotFoundException: Resource doesn't exist
* - ConflictException: Resource state conflict
* - ResourceInUseException: Resource locked by another operation
* - 4xx status codes (except 429, 408): Client errors
*
* @param errorName - The AWS error name (e.g., "ThrottlingException")
* @param httpStatusCode - The HTTP status code if available
* @param fault - The fault type ("client" or "server")
* @returns True if retryable, false if not retryable, undefined if unknown
*/
isBedrockErrorRetryable(errorName, httpStatusCode, fault) {
// Check specific AWS error types first
switch (errorName) {
// Retryable errors
case 'ThrottlingException':
case 'ServiceUnavailableException':
case 'InternalServerException':
case 'ServiceQuotaExceededException':
return true;
// Non-retryable errors
case 'ValidationException':
case 'AccessDeniedException':
case 'ResourceNotFoundException':
case 'ConflictException':
case 'ResourceInUseException':
case 'TooManyTagsException':
return false;
}
// If we have HTTP status code, use it
if (httpStatusCode !== undefined) {
if (httpStatusCode === 429 || httpStatusCode === 408)
return true; // Rate limit, timeout
if (httpStatusCode === 529)
return true; // Overloaded
if (httpStatusCode >= 500 && httpStatusCode < 600)
return true; // Server errors
if (httpStatusCode >= 400 && httpStatusCode < 500)
return false; // Client errors
}
// Fall back to fault type
if (fault === 'server')
return true;
if (fault === 'client')
return false;
// Unknown error type - let consumer decide retry strategy
return undefined;
}
getExtractedExecution(result, _prompt, options) {
let resultText = "";
let reasoning = "";
if (result.output?.message?.content) {
for (const content of result.output.message.content) {
// Get text output
if (content.text) {
resultText += content.text;
}
else if (content.reasoningContent) {
// Extract reasoning content if include_thoughts is true, or if it's a
// reasoning-only model (e.g. DeepSeek R1) that returns no text blocks
const claudeOptions = options?.model_options;
const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1');
if (claudeOptions?.include_thoughts || isReasoningModel) {
if (content.reasoningContent.reasoningText) {
reasoning += content.reasoningContent.reasoningText.text;
}
else if (content.reasoningContent.redactedContent) {
// Handle redacted thinking content
const redactedData = new TextDecoder().decode(content.reasoningContent.redactedContent);
reasoning += `[Redacted thinking: ${redactedData}]`;
}
}
else {
this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
}
}
else {
// Get content block type
const type = Object.keys(content).find(key => key !== '$unknown' && content[key] !== undefined);
this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
}
}
// Add spacing if we have reasoning content
if (reasoning) {
reasoning += '\n\n';
}
}
const completionResult = {
result: reasoning + resultText ? [{ type: "text", value: reasoning + resultText }] : [],
token_usage: {
// Bedrock's inputTokens already excludes cache-read tokens,
// so prompt_new is inputTokens directly (no subtraction needed).
// prompt is the total including cached + cache_write for consistency
// with the Vertex Claude driver.
prompt_new: result.usage?.inputTokens,
prompt: result.usage ? (result.usage.inputTokens ?? 0) + (result.usage.cacheReadInputTokens ?? 0) + (result.usage.cacheWriteInputTokens ?? 0) : undefined,
result: result.usage?.outputTokens,
total: result.usage?.totalTokens,
prompt_cached: result.usage?.cacheReadInputTokens ?? undefined,
prompt_cache_write: result.usage?.cacheWriteInputTokens ?? undefined,
},
finish_reason: converseFinishReason(result.stopReason),
};
return completionResult;
}
;
getExtractedStream(result, _prompt, options, streamingToolBlocks) {
let output = "";
let reasoning = "";
let stop_reason = "";
let token_usage;
let tool_use;
// Check if we should include thoughts (always true for reasoning-only models like DeepSeek R1)
const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1');
const shouldIncludeThoughts = isReasoningModel || (options && options.model_options?.include_thoughts);
// Handle content block start events (for reasoning blocks and tool use)
if (result.contentBlockStart) {
if (result.contentBlockStart.start && 'toolUse' in result.contentBlockStart.start && result.contentBlockStart.start.toolUse) {
// Register new tool call block and emit an initial chunk so the accumulator can track it by id
const toolUseStart = result.contentBlockStart.start.toolUse;
const blockIndex = result.contentBlockStart.contentBlockIndex ?? -1;
const id = toolUseStart.toolUseId ?? '';
const name = toolUseStart.name ?? '';
streamingToolBlocks?.set(blockIndex, { id, name });
tool_use = [{ id, tool_name: name, tool_input: '' }];
}
else if (result.contentBlockStart.start && 'reasoningContent' in result.contentBlockStart.start && shouldIncludeThoughts) {
// Handle redacted content at block start
const reasoningStart = result.contentBlockStart.start;
if (reasoningStart.reasoningContent?.redactedContent) {
const redactedData = new TextDecoder().decode(reasoningStart.reasoningContent.redactedContent);
reasoning = `[Redacted thinking: ${redactedData}]`;
}
}
}
// Handle content block deltas (text, reasoning, and tool use)
if (result.contentBlockDelta) {
const delta = result.contentBlockDelta.delta;
if (delta?.toolUse) {
// Emit tool input chunk; the accumulator in DefaultCompletionStream concatenates these strings
const blockIndex = result.contentBlockDelta.contentBlockIndex ?? -1;
const toolBlock = streamingToolBlocks?.get(blockIndex);
if (toolBlock && delta.toolUse.input !== undefined) {
tool_use = [{ id: toolBlock.id, tool_name: '', tool_input: delta.toolUse.input }];
}
}
else if (delta?.text) {
output = delta.text;
}
else if (delta?.reasoningContent && shouldIncludeThoughts) {
if (delta.reasoningContent.text) {
reasoning = delta.reasoningContent.text;
}
else if (delta.reasoningContent.redactedContent) {
const redactedData = new TextDecoder().decode(delta.reasoningContent.redactedContent);
reasoning = `[Redacted thinking: ${redactedData}]`;
}
else if (delta.reasoningContent.signature) {
// Handle signature updates for reasoning content - end of thinking
reasoning = "\n\n";
// Putting logging here so it only triggers once.
this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
}
}
else if (delta) {
// Get content block type
const type = Object.keys(delta).find(key => key !== '$unknown' && delta[key] !== undefined);
this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
}
}
// Handle content block stop events
if (result.contentBlockStop) {
// Clean up tool block tracking entry
const blockIndex = result.contentBlockStop.contentBlockIndex ?? -1;
streamingToolBlocks?.delete(blockIndex);
// Add minimal spacing for reasoning blocks if not already present
if (reasoning && !reasoning.endsWith('\n\n') && shouldIncludeThoughts) {
reasoning += '\n\n';
}
}
if (result.messageStop) {
stop_reason = result.messageStop.stopReason ?? "";
}
if (result.metadata) {
token_usage = {
prompt_new: result.metadata.usage?.inputTokens,
prompt: result.metadata.usage ? (result.metadata.usage.inputTokens ?? 0) + (result.metadata.usage.cacheReadInputTokens ?? 0) + (result.metadata.usage.cacheWriteInputTokens ?? 0) : undefined,
result: result.metadata.usage?.outputTokens,
total: result.metadata.usage?.totalTokens,
prompt_cached: result.metadata.usage?.cacheReadInputTokens ?? undefined,
prompt_cache_write: result.metadata.usage?.cacheWriteInputTokens ?? undefined,
};
}
const completionResult = {
result: reasoning + output ? [{ type: "text", value: reasoning + output }] : [],
token_usage: token_usage,
finish_reason: converseFinishReason(stop_reason),
tool_use,
};
return completionResult;
}
;
extractRegion(modelString, defaultRegion) {
// Match region in full ARN pattern
const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
if (arnMatch) {
return arnMatch[1];
}
// Match common AWS regions directly in string
const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
if (regionMatch) {
return regionMatch[0];
}
return defaultRegion;
}
async getCanStream(model, type) {
let canStream = false;
let error = null;
const region = this.extractRegion(model, this.options.region);
if (type === BedrockModelType.FoundationModel || type === BedrockModelType.Unknown) {
try {
const response = await this.getService(region).getFoundationModel({
modelIdentifier: model
});
canStream = response.modelDetails?.responseStreamingSupported ?? false;
return canStream;
}
catch (e) {
error = e;
}
}
if (type === BedrockModelType.InferenceProfile || type === BedrockModelType.Unknown) {
try {
const response = await this.getService(region).getInferenceProfile({
inferenceProfileIdentifier: model
});
canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
return canStream;
}
catch (e) {
error = e;
}
}
if (type === BedrockModelType.CustomModel || type === BedrockModelType.Unknown) {
try {
const response = await this.getService(region).getCustomModel({
modelIdentifier: model
});
canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
return canStream;
}
catch (e) {
error = e;
}
}
if (error) {
console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
}
return canStream;
}
async canStream(options) {
// // TwelveLabs Pegasus supports streaming according to the documentation
// if (options.model.includes("twelvelabs.pegasus")) {
// return true;
// }
let canStream = supportStreamingCache.get(options.model);
if (canStream == null) {
let type = BedrockModelType.Unknown;
if (options.model.includes("foundation-model")) {
type = BedrockModelType.FoundationModel;
}
else if (options.model.includes("inference-profile")) {
type = BedrockModelType.InferenceProfile;
}
else if (options.model.includes("custom-model")) {
type = BedrockModelType.CustomModel;
}
canStream = await this.getCanStream(options.model, type);
supportStreamingCache.set(options.model, canStream);
}
return canStream;
}
/**
* Build conversation context after streaming completion.
* Reconstructs the assistant message from accumulated results and applies stripping.
*/
buildStreamingConversation(prompt, result, toolUse, options) {
// Only handle ConverseRequest prompts (not NovaMessagesPrompt or TwelvelabsPegasusRequest)
if (options.model.includes("canvas") || options.model.includes("twelvelabs.pegasus")) {
return undefined;
}
const conversePrompt = prompt;
const completionResults = result;
// Convert accumulated results to text content for assistant message
const textContent = completionResults
.map(r => {
switch (r.type) {
case 'text':
return r.value;
case 'json':
return typeof r.value === 'string' ? r.value : JSON.stringify(r.value);
case 'image':
// Skip images in conversation - they're in the result
return '';
default:
return String(r.value || '');
}
})
.join('');
// Deserialize any base64-encoded binary data back to Uint8Array
const incomingConversation = deserializeBinaryFromStorage(options.conversation);
// Start with the conversation from options combined with the prompt
let conversation = updateConversation(incomingConversation, conversePrompt);
// Build assistant message content
const messageContent = [];
if (textContent) {
messageContent.push({ text: textContent });
}
// Add tool use blocks if present
if (toolUse && toolUse.length > 0) {
for (const tool of toolUse) {
messageContent.push({
toolUse: {
toolUseId: tool.id,
name: tool.tool_name,
input: tool.tool_input,
}
});
}
}
// Add assistant message
const assistantMessage = {
messages: [{
content: messageContent.length > 0 ? messageContent : [{ text: '' }],
role: "assistant"
}],
modelId: conversePrompt.modelId,
};
conversation = updateConversation(conversation, assistantMessage);
// Increment turn counter
conversation = incrementConversationTurn(conversation);
// Apply stripping based on options
const currentTurn = getConversationMeta(conversation).turnNumber;
const stripOptions = {
keepForTurns: options.stripImagesAfterTurns ?? Infinity,
currentTurn,
textMaxTokens: options.stripTextMaxTokens
};
let processedConversation = stripBinaryFromConversation(conversation, stripOptions);
processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions);
processedConversation = stripHeartbeatsFromConversation(processedConversation, {
keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
currentTurn,
});
return processedConversation;
}
async requestTextCompletion(prompt, options) {
// Handle Twelvelabs Pegasus models
if (options.model.includes("twelvelabs.pegasus")) {
return this.requestTwelvelabsPegasusCompletion(prompt, options);
}
// Handle other Bedrock models that use Converse API
const conversePrompt = prompt;
// Deserialize any base64-encoded binary data back to Uint8Array before API call
const incomingConversation = deserializeBinaryFromStorage(options.conversation);
let conversation = updateConversation(incomingConversation, conversePrompt);
const payload = this.preparePayload(conversation, options);
const executor = this.getExecutor();
const res = await executor.converse({
...payload,
});
// Strip reasoningContent from assistant messages before storing in conversation
// (DeepSeek R1 returns reasoning blocks but rejects them in subsequent user turns)
const assistantMsg = res.output?.message ?? { content: [{ text: "" }], role: "assistant" };
if (assistantMsg.content) {
assistantMsg.content = assistantMsg.content.filter((c) => !c.reasoningContent);
}
conversation = updateConversation(conversation, {
messages: [assistantMsg],
modelId: conversePrompt.modelId,
});
// Increment turn counter for deferred stripping
conversation = incrementConversationTurn(conversation);
let tool_use = undefined;
//Get tool requests, we check tool use regardless of finish reason, as you can hit length and still get a valid response.
tool_use = res.output?.message?.content?.reduce((tools, c) => {
if (c.toolUse) {
tools.push({
tool_name: c.toolUse.name ?? "",
tool_input: c.toolUse.input,
id: c.toolUse.toolUseId ?? "",
});
}
return tools;
}, []);
//If no tools were used, set to undefined
if (tool_use && tool_use.length === 0) {
tool_use = undefined;
}
// Strip/serialize binary data based on options.stripImagesAfterTurns
const currentTurn = getConversationMeta(conversation).turnNumber;
const stripOptions = {
keepForTurns: options.stripImagesAfterTurns ?? Infinity,
currentTurn,
textMaxTokens: options.stripTextMaxTokens
};
let processedConversation = stripBinaryFromConversation(conversation, stripOptions);
// Truncate large text content if configured
processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions);
// Strip old heartbeat status messages
processedConversation = stripHeartbeatsFromConversation(processedConversation, {
keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
currentTurn,
});
const completion = {
...this.getExtractedExecution(res, conversePrompt, options),
original_response: options.include_original_response ? res : undefined,
conversation: processedConversation,
tool_use: tool_use,
};
return completion;
}
async requestTwelvelabsPegasusCompletion(prompt, options) {
const executor = this.getExecutor();
const res = await executor.invokeModel({
modelId: options.model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(prompt),
});
const decoder = new TextDecoder();
const body = decoder.decode(res.body);
const result = JSON.parse(body);
// Extract the response according to TwelveLabs Pegasus format
let finishReason;
switch (result.finishReason) {
case "stop":
finishReason = "stop";
break;
case "length":
finishReason = "length";
break;
default:
finishReason = result.finishReason;
}
return {
result: result.message ? [{ type: "text", value: result.message }] : [],
finish_reason: finishReason,
original_response: options.include_original_response ? result : undefined,
};
}
async requestTwelvelabsPegasusCompletionStream(prompt, options) {
const executor = this.getExecutor();
const res = await executor.invokeModelWithResponseStream({
modelId: options.model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(prompt),
});
if (!res.body) {
throw new Error("[Bedrock] Stream not found in response");
}
return transformAsyncIterator(res.body, (chunk) => {
if (chunk.chunk?.bytes) {
const decoder = new TextDecoder();
const body = decoder.decode(chunk.chunk.bytes);
try {
const result = JSON.parse(body);
// Extract streaming response according to TwelveLabs Pegasus format
let finishReason;
if (result.finishReason) {
switch (result.finishReason) {
case "stop":
finishReason = "stop";
break;
case "length":
finishReason = "length";
break;
default:
finishReason = result.finishReason;
}
}
return {
result: result.delta || result.message ? [{ type: "text", value: result.delta || result.message || "" }] : [],
finish_reason: finishReason,
};
}
catch (error) {
// If JSON parsing fails, return empty chunk
return {
result: [],
};
}
}
return {
result: [],
};
});
}
async requestTextCompletionStream(prompt, options) {
// Handle Twelvelabs Pegasus models
if (options.model.includes("twelvelabs.pegasus")) {
return this.requestTwelvelabsPegasusCompletionStream(prompt, options);
}
// Handle other Bedrock models that use Converse API
const conversePrompt = prompt;
// Include conversation history (same as non-streaming)
// Deserialize any base64-encoded binary data back to Uint8Array before API call
const incomingConversation = deserializeBinaryFromStorage(options.conversation);
const conversation = updateConversation(incomingConversation, conversePrompt);
const payload = this.preparePayload(conversation, options);
const executor = this.getExecutor();
return executor.converseStream({
...payload,
}).then((res) => {
const stream = res.stream;
if (!stream) {
throw new Error("[Bedrock] Stream not found in response");
}
const streamingToolBlocks = new Map();
return transformAsyncIterator(stream, (streamSegment) => {
return this.getExtractedStream(streamSegment, conversePrompt, options, streamingToolBlocks);
});
}).catch((err) => {
this.logger.error({ error: err }, "[Bedrock] Failed to stream");
throw err;
});
}
preparePayload(prompt, options) {
const model_options = options.model_options ?? { _option_id: "text-fallback" };
let additionalField = {};
let supportsJSONPrefill = false;
// Resolve thinking, effort, and sampling restrictions using shared Claude helper
const claudeThinking = resolveClaudeThinking(options.model, options.model_options);
const hasSamplingRestriction = claudeThinking.hasSamplingRestriction;
if (options.model.includes("amazon")) {
supportsJSONPrefill = true;
//Titan models also exists but does not support any additional options
if (options.model.includes("nova")) {
additionalField = { inferenceConfig: { topK: model_options.top_k } };
}
}
else if (options.model.includes("claude")) {
const claude_options = model_options;
// Thinking is active when extended (budget set) or adaptive (effort set) thinking is enabled.
// JSON prefill is incompatible with active thinking.
const thinkingActive = claudeThinking.thinking != null && claudeThinking.thinking.type !== "disabled";
supportsJSONPrefill = !thinkingActive;
// Claude 3.7+ supports thinking — use shared helper for reasoning_config
if (claudeThinking.supportsThinking) {
if (claudeThinking.thinking) {
additionalField = {
...additionalField,
reasoning_config: claudeThinking.thinking,
};
}
// For Claude 3.7 with extended thinking + high output, add beta header
if (claudeThinking.thinking?.type === "enabled" &&
options.model.includes("claude-3-7-sonnet") &&
((claude_options.max_tokens ?? 0) > 64000 || (claude_options.thinking_budget_tokens ?? 0) > 64000)) {
additionalField = {
...additionalField,
anthropic_beta: ["output-128k-2025-02-19"]
};
}
}
// Add effort parameter via output_config (Opus 4.5+, Sonnet 4.6+, all 4.7+)
if (claudeThinking.outputConfig) {
additionalField = {
...additionalField,
output_config: claudeThinking.outputConfig
};
}
// Claude 4.6 and later versions don't support JSON prefill
if (isClaudeVersionGTE(options.model, 4, 6)) {
supportsJSONPrefill = false;
}
// Needs max_tokens to be set
if (!model_options.max_tokens) {
model_options.max_tokens = maxTokenFallbackClaude(options);
}
// Only models without sampling restrictions support top_k
if (!hasSamplingRestriction) {
additionalField = { ...additionalField, top_k: model_options.top_k };
}
}
else if (options.model.includes("meta")) {
//LLaMA models support no additional options
}
else if (options.model.includes("mistral")) {
//7B instruct and 8x7B instruct
if (options.model.includes("7b")) {
additionalField = { top_k: model_options.top_k };
//Does not support system messages
if (prompt.system && prompt.system?.length !== 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
else {
//Other models such as Mistral Small,Large and Large 2
//Support no additional fields.
}
}
else if (options.model.includes("ai21")) {
//Jamba models support no additional options
//Jurassic 2 models do.
if (options.model.includes("j2")) {
additionalField = {
presencePenalty: { scale: model_options.presence_penalty },
frequencyPenalty: { scale: model_options.frequency_penalty },
};
//Does not support system messages
if (prompt.system && prompt.system?.length !== 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
}
else if (options.model.includes("cohere.command")) {
// If last message is "```json", remove it.
//Command R and R plus
if (options.model.includes("cohere.command-r")) {
additionalField = {
k: model_options.top_k,
frequency_penalty: model_options.frequency_penalty,
presence_penalty: model_options.presence_penalty,
};
}
else {
// Command non-R
additionalField = { k: model_options.top_k };
//Does not support system messages
if (prompt.system && prompt.system?.length !== 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
}
else if (options.model.includes("palmyra")) {
const palmyraOptions = model_options;
additionalField = {
seed: palmyraOptions?.seed,
presence_penalty: palmyraOptions?.presence_penalty,
frequency_penalty: palmyraOptions?.frequency_penalty,
min_tokens: palmyraOptions?.min_tokens,
};
}
else if (options.model.includes("deepseek")) {
// DeepSeek models: no additional options, no stopSequences, only one of temperature/top_p
model_options.stop_sequence = undefined;
model_options.top_p = undefined;
}
else if (options.model.includes("gpt-oss")) {
const gptOssOptions = model_options;
additionalField = {
reasoning_effort: gptOssOptions?.reasoning_effort,
};
}
//If last message is "```json", add corresponding ``` as a stop sequence.
if (prompt.messages && prompt.messages.length > 0) {
if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
const stopSeq = model_options.stop_sequence;
if (!stopSeq) {
model_options.stop_sequence = ["```"];
}
else if (!stopSeq.includes("```")) {
stopSeq.push("```");
model_options.stop_sequence = stopSeq;
}
}
}
const tool_defs = getToolDefinitions(options.tools);
// Use prefill when there is a schema and tools are not being used
if (supportsJSONPrefill && options.result_schema && !tool_defs) {
prompt.messages = converseJSONprefill(prompt.messages);
}
// Clean undefined values from additionalField since AWS Bedrock requires valid JSON
// and will throw an exception for unrecognized parameters
const cleanedAdditionalFields = removeUndefinedValues(additionalField);
// Models with sampling parameter restrictions don't support temperature/top_p - exclude them from inference config
const cleanedModelOptions = removeUndefinedValues({
maxTokens: model_options.max_tokens,
...(hasSamplingRestriction ? {} : {
temperature: model_options.temperature,
topP: model_options.temperature != null ? undefined : model_options.top_p,
}),
stopSequences: model_options.stop_sequence,
});
//Construct the final request payload
// We only add fields that are defined to avoid AWS errors
const request = {
modelId: options.model,
};
if (prompt.messages) {
request.messages = prompt.messages;
}
if (prompt.system) {
request.system = prompt.system;
}
if (Object.keys(cleanedModelOptions).length > 0) {
request.inferenceConfig = cleanedModelOptions;
}
if (Object.keys(cleanedAdditionalFields).length > 0) {
request.additionalModelRequestFields = cleanedAdditionalFields;
}
if (tool_defs?.length) {
request.toolConfig = {
tools: tool_defs,
};
}
else if (request.messages && messagesContainToolBlocks(request.messages)) {
// Bedrock requires toolConfig when conversation contains toolUse/toolResult blocks.
// When no tools are provided (e.g. checkpoint summary calls), convert tool blocks
// to text representations so the conversation data is preserved while satisfying
// Bedrock's API requirements without making tools callable.
request.messages = convertToolBlocksToText(request.messages);
}
// Prompt caching: use three breakpoints so stable system blocks, tool definitions,
// and the conversation history prefix can all be reused across Claude turns.
if (options.model.includes('claude')) {
// Always strip stale markers from prior turns
if (request.messages) {
request.messages = stripClaudeCachePoints(request.messages);
}
request.system = stripClaudeCachePointsFromSystem(request.system);
if (request.toolConfig?.tools) {
request.toolConfig = {
...request.toolConfig,
tools: stripClaudeCachePointsFromTools(request.toolConfig.tools),
};
}
const claudeOptions = model_options;
const cacheEnabled = claudeOptions?.cache_enabled === true;
if (cacheEnabled) {
const cacheTtl = claudeOptions?.cache_ttl;
const cachePointBlock = { type: 'default', ...(cacheTtl && { ttl: cacheTtl }) };
if (request.system && request.system.length > 0) {
request.system = [...request.system, { cachePoint: cachePointBlock }];
}
if (request.toolConfig?.tools && request.toolConfig.tools.length > 0) {
request.toolConfig.tools = [
...request.toolConfig.tools,
{ cachePoint: cachePointBlock },
];
}
if (request.messages && request.messages.length >= 4) {
const pivotMsg = request.messages[request.messages.length - 2];
if (pivotMsg.content && Array.isArray(pivotMsg.content) && pivotMsg.content.length > 0) {
pivotMsg.content = [...pivotMsg.content, { cachePoint: cachePointBlock }];
}
}
}
}
return request;
}
isImageModel(model) {
return model.includes("titan-image") || model.includes("stable-diffusion") || model.includes("nova-canvas");
}
async requestImageGeneration(prompt, options) {
if (options.model_options?._option_id !== undefined && options.model_options?._option_id !== "bedrock-nova-canvas") {
this.logger.debug({ options: options.model_options }, "Unexpected option id");
}
const model_options = options.model_options;
const executor = this.getExecutor();
const taskType = model_options.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
this.logger.info("Task type: " + taskType);
if (typeof prompt === "string") {
throw new Error("Bad prompt format");
}
const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);
const res = await executor.invokeModel({
modelId: options.model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(payload),
}, {
requestTimeout: 60000 * 5
});
const decoder = new TextDecoder();
const body = decoder.decode(res.body);
const bedrockResult = JSON.parse(body);
return {
error: bedrockResult.error,
result: bedrockResult.images.map((image) => ({
type: "image",
value: image
}))
};
}
async startTraining(dataset, options) {
//convert options.params to Record<string, string>
const params = {};
for (const [key, value] of Object.entries(options.params || {})) {
params[key] = String(value);
}
if (!this.options.training_bucket) {
throw new Error("Training cannot nbe used since the 'training_bucket' property was not specified in driver options");
}
const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
const stream = await dataset.getStream();
const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);
const service = this.getService();
const response = await service.send(new CreateModelCustomizationJobCommand({
jobName: options.name + "-job",
customModelName: options.name,
roleArn: this.options.training_role_arn || undefined,
baseModelIdentifier: options.model,
clientRequestToken: "llumiverse-" + Date.now(),
trainingDataConfig: {
s3Uri: `s3://${upload.Bucket}/${upload.Key}`,
},
outputDataConfig: undefined,
hyperParameters: params,
//TODO not supported?
//customizationType: "FINE_TUNING",
}));
const job = await service.send(new GetModelCustomizationJobCommand({
jobIdentifier: response.jobArn
}));
return jobInfo(job, response.jobArn);
}
async cancelTraining(jobId) {
const service = this.getService();
await service.send(new StopModelCustomizationJobCommand({
jobIdentifier: jobId
}));
const job = await service.send(new GetModelCustomizationJobCommand({
jobIdentifier: jobId
}));
return jobInfo(job, jobId);
}
async getTrainingJob(jobId) {
const service = this.getService();
const job = await service.send(new GetModelCustomizationJobCommand({
jobIdentifier: jobId
}));
return jobInfo(job, jobId);
}
// ===================== management API ==================
async validateConnection() {
const service = this.getService();
this.logger.debug("[Bedrock] validating connection", service.config.credentials.name);
//return true as if the client has been initialized, it means the connection is valid
return true;
}
async listTrainableModels() {
this.logger.debug("[Bedrock] listing trainable models");
return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false);
}
async listModels() {
this.logger.debug("[Bedrock] listing models");
// exclude trainable models since they are not executable
// exclude embedding models, not to be used for typical completions.
const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
return this._listModels(filter);
}
async _listModels(foundationFilter) {
const service = this.getService();
const [foundationModelsList, customModelsList, inferenceProfilesList] = await Promise.all([
service.listFoundationModels({}).catch(() => {
this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
return undefined;
}),
service.listCustomModels({}).catch(() => {
this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
return undefined;
}),
service.listInferenceProfiles({}).catch(() => {
this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
return undefined;
}),
]);
if (!foundationModelsList?.modelSummaries) {
throw new Error("Foundation models not found");
}
let foundationModels = foundationModelsList.modelSummaries || [];
if (foundationFilter) {
foundationModels = foundationModels.filter(foundationFilter);
}
const supportedPublishers = ["amazon", "anthropic", "cohere", "ai21",
"mistral", "meta", "deepseek", "writer",
"openai", "twelvelabs", "qwen"];
const unsupportedModelsByPublisher = {
amazon: ["titan-image-generator", "nova-reel", "nova-sonic", "rerank"],
anthr