@llumiverse/drivers
Version:
LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.
1,132 lines (1,000 loc) • 74.7 kB
text/typescript
import {
Bedrock, CreateModelCustomizationJobCommand, type FoundationModelSummary, GetModelCustomizationJobCommand,
type GetModelCustomizationJobCommandOutput, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand
} from "@aws-sdk/client-bedrock";
import { BedrockRuntime, type ContentBlock, type ConverseRequest, type ConverseResponse, type ConverseStreamOutput, type InferenceConfiguration, type Message, type Tool } from "@aws-sdk/client-bedrock-runtime";
import { S3Client } from "@aws-sdk/client-s3";
import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types";
import {
AbstractDriver, type AIModel,
type BedrockClaudeOptions,
type BedrockGptOssOptions,
type BedrockPalmyraOptions,
type Completion, type CompletionChunkObject,
type CompletionResult,
type DataSource,
deserializeBinaryFromStorage,
type DriverOptions,
type EmbeddingsOptions, type EmbeddingsResult,
type ExecutionOptions, type ExecutionTokenUsage,
getConversationMeta,
getMaxTokensLimitBedrock,
getModelCapabilities,
incrementConversationTurn,
isClaudeVersionGTE,
LlumiverseError, type LlumiverseErrorContext,
modelModalitiesToArray,
type ModelOptions,
type NovaCanvasOptions,
type PromptSegment,
type StatelessExecutionOptions,
stripBinaryFromConversation,
stripHeartbeatsFromConversation,
type TextFallbackOptions, type ToolDefinition, type ToolUse, type TrainingJob, TrainingJobStatus, type TrainingOptions,
truncateLargeTextInConversation
} from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { formatNovaPrompt, type NovaMessagesPrompt } from "@llumiverse/core/formatters";
import { LRUCache } from "mnemonist";
import { resolveClaudeThinking } from "../shared/claude-thinking.js";
import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
import { forceUploadFile } from "./s3.js";
import {
formatTwelvelabsPegasusPrompt,
type TwelvelabsMarengoRequest,
type TwelvelabsMarengoResponse,
type TwelvelabsPegasusRequest
} from "./twelvelabs.js";
/**
 * Module-level cache of "does this model support response streaming?", keyed by the
 * model string from ExecutionOptions.model. Filled by BedrockDriver.canStream so the
 * control-plane lookup is not repeated for the same model string (LRU-bounded to 4096 entries).
 */
const supportStreamingCache = new LRUCache<string, boolean>(4096);
/**
 * Kind of Bedrock model identifier. Selects which control-plane lookup
 * (getFoundationModel / getInferenceProfile / getCustomModel) getCanStream performs;
 * `Unknown` makes getCanStream try each lookup in turn until one succeeds.
 */
enum BedrockModelType {
FoundationModel = "foundation-model",
InferenceProfile = "inference-profile",
CustomModel = "custom-model",
Unknown = "unknown",
};
/**
 * Map a Converse API `stopReason` onto the llumiverse finish-reason vocabulary.
 *
 * Bedrock reports one of: end_turn | tool_use | max_tokens | stop_sequence |
 * guardrail_intervened | content_filtered. Only `end_turn` ("stop") and
 * `max_tokens` ("length") are renamed; any other value is passed through verbatim.
 *
 * @param reason - Raw stop reason; may be missing or empty.
 * @returns The normalized reason, or undefined when no reason was reported.
 */
function converseFinishReason(reason: string | undefined) {
    if (!reason) {
        return undefined;
    }
    // A Map (not a plain object) so keys like "toString" never hit Object.prototype.
    const renamed = new Map<string, string>([
        ["end_turn", "stop"],
        ["max_tokens", "length"],
    ]);
    return renamed.get(reason) ?? reason;
}
/**
 * Streaming capability summary for a Bedrock model.
 * NOTE(review): not referenced in the visible part of this file; presumably consumed elsewhere.
 */
export interface BedrockModelCapabilities {
name: string;
canStream: boolean;
}
/**
 * Options accepted by BedrockDriver.
 */
export interface BedrockDriverOptions extends DriverOptions {
/**
* The AWS region. Required (the driver constructor throws without it). Used for the
* runtime client and as the default region for control-plane lookups.
*/
region: string;
/**
* The name of the S3 bucket used for training.
* It will be created if it does not already exist.
*/
training_bucket?: string;
/**
* The IAM role ARN to be used for training jobs.
*/
training_role_arn?: string;
/**
* The credentials used to access AWS (IAM access key + secret, or a credential provider).
* Passed straight to the AWS SDK clients; when undefined the SDK resolves credentials itself.
*/
credentials?: AwsCredentialIdentity | Provider<AwsCredentialIdentity>;
}
/**
 * Resolve the `max_tokens` value to send to Claude on Bedrock. Claude requires the
 * parameter to be set, so a value is always returned.
 *
 * - A positive `max_tokens` from the model options is returned unchanged.
 * - Otherwise the model's Bedrock output limit from `getMaxTokensLimitBedrock` is used
 *   (8192 if that helper returns nothing).
 * - For Claude 3.7 Sonnet, 64000 is used instead when the thinking budget is below 48000
 *   (or not set).
 *
 * @param option - Execution options; `model` and `model_options` are read.
 * @returns The max_tokens value to use.
 */
function maxTokenFallbackClaude(option: StatelessExecutionOptions): number {
    const modelOptions = option.model_options as BedrockClaudeOptions | undefined;
    const requested = modelOptions?.max_tokens;
    // Only a positive number is a usable max_tokens. The caller invokes this fallback when
    // max_tokens is falsy, so returning a 0 as-is would send an invalid value to the API.
    if (typeof requested === "number" && requested > 0) {
        return requested;
    }
    // Should always return a number for Claude; 8192 only satisfies the type checker.
    let limit = getMaxTokensLimitBedrock(option.model) ?? 8192;
    // Claude 3.7 can go up to 128k with a beta header, but when no max tokens is specified
    // we default to 64k unless the thinking budget is at least 48000.
    if (option.model.includes('claude-3-7-sonnet') && (modelOptions?.thinking_budget_tokens ?? 0) < 48000) {
        limit = 64000;
    }
    return limit;
}
/**
 * Prompt payload produced by BedrockDriver.formatPrompt: a Nova messages prompt (Nova Canvas),
 * a TwelveLabs Pegasus request, or a Converse API request for every other model.
 */
export type BedrockPrompt = NovaMessagesPrompt | ConverseRequest | TwelvelabsPegasusRequest;
// Element type of ConverseRequest.system.
// NOTE(review): not referenced in the visible part of this file; presumably used further down.
type BedrockSystemBlock = NonNullable<ConverseRequest['system']>[number];
// Element type of ConverseRequest.toolConfig.tools.
// NOTE(review): not referenced in the visible part of this file; presumably used further down.
type BedrockToolEntry = NonNullable<NonNullable<ConverseRequest['toolConfig']>['tools']>[number];
export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockPrompt> {
static PROVIDER = "bedrock";
provider = BedrockDriver.PROVIDER;
/** Bedrock runtime (inference) client, created lazily by getExecutor(). */
private _executor?: BedrockRuntime;
/** Bedrock control-plane client, created lazily by getService(). */
private _service?: Bedrock;
/** Region `_service` was created for; a different region triggers a rebuild. */
private _service_region?: string;
/**
 * @param options - Driver options; `region` must be set.
 * @throws Error when no region is configured.
 */
constructor(options: BedrockDriverOptions) {
    super(options);
    const hasRegion = Boolean(options.region);
    if (!hasRegion) {
        throw new Error("No region found. Set the region in the environment's endpoint URL.");
    }
}
/**
 * Return the Bedrock runtime client used for inference calls, creating it on first use.
 * The client is bound to the driver's configured region and credentials.
 */
getExecutor() {
    if (this._executor) {
        return this._executor;
    }
    const runtime = new BedrockRuntime({
        region: this.options.region,
        credentials: this.options.credentials,
    });
    this._executor = runtime;
    return runtime;
}
/**
 * Return a Bedrock control-plane client for `region` (defaults to the driver region).
 * Only one client is cached; requesting a different region replaces it.
 */
getService(region: string = this.options.region) {
    const cached = this._service;
    if (cached && this._service_region === region) {
        return cached;
    }
    const service = new Bedrock({
        region: region,
        credentials: this.options.credentials,
    });
    this._service = service;
    this._service_region = region;
    return service;
}
/**
 * Build the request payload for the target model family:
 * Nova Canvas gets a Nova messages prompt, TwelveLabs Pegasus gets a Pegasus request,
 * and every other model gets a Converse API request.
 */
protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<BedrockPrompt> {
    const { model } = opts;
    if (model.includes("canvas")) {
        return await formatNovaPrompt(segments, opts.result_schema);
    }
    return model.includes("twelvelabs.pegasus")
        ? await formatTwelvelabsPegasusPrompt(segments, opts)
        : await formatConversePrompt(segments, opts);
}
/**
 * Convert an error thrown by the AWS SDK into a LlumiverseError.
 *
 * Errors carrying `$metadata` are treated as AWS SDK service errors. Their message is
 * decorated as `<Name>: [<status>] <message> (Request ID: <id>)`; each decoration is
 * omitted when the corresponding field is missing, and the name is omitted when it is
 * just "Error". Retryability comes from isBedrockErrorRetryable. Any other error is
 * delegated to the base implementation.
 *
 * @param error - The thrown value.
 * @param context - Where the error occurred.
 * @returns A standardized LlumiverseError carrying the HTTP status and AWS error name.
 */
public formatLlumiverseError(
    error: unknown,
    context: LlumiverseErrorContext
): LlumiverseError {
    const awsError = error as any;
    if (awsError?.$metadata === undefined) {
        // Not an AWS SDK error: use the generic handling.
        return super.formatLlumiverseError(error, context);
    }

    const errorName = awsError.name || 'UnknownError';
    const metadata = awsError.$metadata;
    const httpStatusCode = metadata?.httpStatusCode;
    const requestId = metadata?.requestId;
    const fault = awsError.$fault; // "client" | "server"

    // Error instances and plain objects both may carry the message.
    const rawMessage: string = error instanceof Error
        ? error.message
        : (typeof awsError.message === 'string' ? awsError.message : String(error));

    const parts: string[] = [];
    const hasMeaningfulName = errorName && errorName !== 'Error' && errorName !== 'UnknownError';
    if (hasMeaningfulName) {
        parts.push(`${errorName}: `);
    }
    if (httpStatusCode) {
        parts.push(`[${httpStatusCode}] `);
    }
    parts.push(rawMessage);
    if (requestId) {
        // The request id is what AWS support asks for.
        parts.push(` (Request ID: ${requestId})`);
    }
    const userMessage = parts.join('');

    return new LlumiverseError(
        `[${this.provider}] ${userMessage}`,
        this.isBedrockErrorRetryable(errorName, httpStatusCode, fault),
        context,
        error,
        httpStatusCode, // numeric HTTP status, when AWS reported one
        errorName // AWS exception name, preserved for callers
    );
}
/**
 * Classify a Bedrock error as retryable (true), not retryable (false) or unknown (undefined).
 *
 * The checks are applied in order, and the first one that decides wins:
 * 1. Known exception names. Throttling, service-unavailable, internal-server and
 *    service-quota errors are retryable. Validation, access-denied, not-found, conflict,
 *    resource-in-use and too-many-tags errors are not.
 * 2. HTTP status. 408, 429 and any 5xx (including 529) are retryable; other 4xx are not.
 * 3. The SDK `$fault` marker: "server" is retryable and "client" is not.
 *
 * @param errorName - AWS exception name, e.g. "ThrottlingException".
 * @param httpStatusCode - HTTP status code, when available.
 * @param fault - "client" or "server", when available.
 * @returns The retryability verdict, or undefined when nothing above applies.
 */
private isBedrockErrorRetryable(
    errorName: string,
    httpStatusCode: number | undefined,
    fault: string | undefined
): boolean | undefined {
    const transientErrors = [
        'ThrottlingException',
        'ServiceUnavailableException',
        'InternalServerException',
        'ServiceQuotaExceededException',
    ];
    const permanentErrors = [
        'ValidationException',
        'AccessDeniedException',
        'ResourceNotFoundException',
        'ConflictException',
        'ResourceInUseException',
        'TooManyTagsException',
    ];
    if (transientErrors.includes(errorName)) {
        return true;
    }
    if (permanentErrors.includes(errorName)) {
        return false;
    }

    if (httpStatusCode !== undefined) {
        const isServerError = httpStatusCode >= 500 && httpStatusCode < 600; // includes 529 "overloaded"
        if (httpStatusCode === 408 || httpStatusCode === 429 || isServerError) {
            return true;
        }
        if (httpStatusCode >= 400 && httpStatusCode < 500) {
            return false;
        }
    }

    switch (fault) {
        case 'server':
            return true;
        case 'client':
            return false;
        default:
            // Unknown: let the consumer pick a retry strategy.
            return undefined;
    }
}
/**
 * Convert a non-streaming Converse response into a completion chunk.
 *
 * - Text blocks are concatenated into the result text.
 * - Reasoning blocks are placed before the text, followed by a blank line. This happens
 *   only when `include_thoughts` is set or the model is DeepSeek R1, a reasoning-only model.
 *   Redacted reasoning is rendered as a placeholder.
 * - Tool-use blocks are skipped here because the caller extracts them separately.
 * - Any other block type is logged.
 * - Token usage and the normalized finish reason are copied from the response.
 *
 * @param result - The Converse API response.
 * @param _prompt - Unused; kept for interface compatibility.
 * @param options - Execution options, used to decide whether reasoning is surfaced.
 * @returns The completion chunk for this response.
 */
getExtractedExecution(result: ConverseResponse, _prompt?: BedrockPrompt, options?: ExecutionOptions): CompletionChunkObject {
    // These depend only on the options, so compute them once instead of per content block.
    const claudeOptions = options?.model_options as BedrockClaudeOptions | undefined;
    const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1');
    const includeReasoning = Boolean(claudeOptions?.include_thoughts || isReasoningModel);

    let resultText = "";
    let reasoning = "";
    for (const content of result.output?.message?.content ?? []) {
        if (content.text) {
            resultText += content.text;
        } else if (content.reasoningContent) {
            if (!includeReasoning) {
                this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
            } else if (content.reasoningContent.reasoningText) {
                // `text` is optional in the SDK type; never append the string "undefined".
                reasoning += content.reasoningContent.reasoningText.text ?? "";
            } else if (content.reasoningContent.redactedContent) {
                const redactedData = new TextDecoder().decode(content.reasoningContent.redactedContent);
                reasoning += `[Redacted thinking: ${redactedData}]`;
            }
        } else if (content.toolUse) {
            // Tool calls are extracted separately by the caller; they are not unsupported content.
        } else {
            const type = Object.keys(content).find(
                key => key !== '$unknown' && content[key as keyof typeof content] !== undefined
            );
            this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
        }
    }
    // Separate reasoning from the answer text.
    if (reasoning) {
        reasoning += '\n\n';
    }

    const text = reasoning + resultText;
    const usage = result.usage;
    const completionResult: CompletionChunkObject = {
        result: text ? [{ type: "text", value: text }] : [],
        token_usage: {
            // Bedrock's inputTokens already excludes cache-read tokens, so prompt_new is
            // inputTokens directly. prompt is the total including cache reads and writes,
            // for consistency with the Vertex Claude driver.
            prompt_new: usage?.inputTokens,
            prompt: usage ? (usage.inputTokens ?? 0) + (usage.cacheReadInputTokens ?? 0) + (usage.cacheWriteInputTokens ?? 0) : undefined,
            result: usage?.outputTokens,
            total: usage?.totalTokens,
            prompt_cached: usage?.cacheReadInputTokens ?? undefined,
            prompt_cache_write: usage?.cacheWriteInputTokens ?? undefined,
        },
        finish_reason: converseFinishReason(result.stopReason),
    };
    return completionResult;
}
/**
 * Convert a single ConverseStream event into a completion chunk.
 *
 * - contentBlockStart
 *   - A tool-use start registers the block in `streamingToolBlocks` and emits an initial
 *     tool_use entry with empty input, so the accumulator can track the call by id.
 *   - A redacted reasoning start is rendered as a placeholder when thoughts are included.
 * - contentBlockDelta
 *   - Tool input fragments are emitted against the registered tool id.
 *   - Text is emitted as-is.
 *   - Reasoning is emitted only when thoughts are included. A reasoning signature ends a
 *     thinking block and becomes a blank-line separator.
 * - contentBlockStop: the tool-block entry for that index is removed.
 * - messageStop / metadata: finish reason and token usage.
 *
 * Thoughts are included when `include_thoughts` is set, or always for DeepSeek R1.
 *
 * @param result - One event from the ConverseStream response.
 * @param _prompt - Unused; kept for interface compatibility.
 * @param options - Execution options, used to decide whether reasoning is surfaced.
 * @param streamingToolBlocks - Per-stream state mapping a content block index to the tool call (id, name) started there.
 * @returns The completion chunk for this event (possibly empty).
 */
getExtractedStream(result: ConverseStreamOutput, _prompt?: BedrockPrompt, options?: ExecutionOptions, streamingToolBlocks?: Map<number, { id: string; name: string }>): CompletionChunkObject {
    let output: string = "";
    let reasoning: string = "";
    let stop_reason = "";
    let token_usage: ExecutionTokenUsage | undefined;
    let tool_use: ToolUse[] | undefined;
    // Reasoning-only models like DeepSeek R1 always surface their thoughts.
    const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1');
    const shouldIncludeThoughts = isReasoningModel || (options && (options.model_options as BedrockClaudeOptions)?.include_thoughts);

    // Content block start: tool calls and redacted reasoning.
    if (result.contentBlockStart) {
        if (result.contentBlockStart.start && 'toolUse' in result.contentBlockStart.start && result.contentBlockStart.start.toolUse) {
            const toolUseStart = result.contentBlockStart.start.toolUse;
            const blockIndex = result.contentBlockStart.contentBlockIndex ?? -1;
            const id = toolUseStart.toolUseId ?? '';
            const name = toolUseStart.name ?? '';
            streamingToolBlocks?.set(blockIndex, { id, name });
            tool_use = [{ id, tool_name: name, tool_input: '' as any }];
        } else if (result.contentBlockStart.start && 'reasoningContent' in result.contentBlockStart.start && shouldIncludeThoughts) {
            const reasoningStart = result.contentBlockStart.start as any;
            if (reasoningStart.reasoningContent?.redactedContent) {
                const redactedData = new TextDecoder().decode(reasoningStart.reasoningContent.redactedContent);
                reasoning = `[Redacted thinking: ${redactedData}]`;
            }
        }
    }

    // Content block deltas: tool input, text and reasoning.
    if (result.contentBlockDelta) {
        const delta = result.contentBlockDelta.delta;
        if (delta?.toolUse) {
            // Tool input arrives as string fragments that the stream accumulator concatenates.
            const blockIndex = result.contentBlockDelta.contentBlockIndex ?? -1;
            const toolBlock = streamingToolBlocks?.get(blockIndex);
            if (toolBlock && delta.toolUse.input !== undefined) {
                tool_use = [{ id: toolBlock.id, tool_name: '', tool_input: delta.toolUse.input as any }];
            }
        } else if (delta?.text) {
            output = delta.text;
        } else if (delta?.reasoningContent) {
            // Reasoning deltas are a known content type: when thoughts are excluded they are
            // dropped quietly instead of being reported as unsupported on every chunk.
            if (shouldIncludeThoughts) {
                if (delta.reasoningContent.text) {
                    reasoning = delta.reasoningContent.text;
                } else if (delta.reasoningContent.redactedContent) {
                    const redactedData = new TextDecoder().decode(delta.reasoningContent.redactedContent);
                    reasoning = `[Redacted thinking: ${redactedData}]`;
                } else if (delta.reasoningContent.signature) {
                    // A signature marks the end of a thinking block.
                    reasoning = "\n\n";
                }
            } else if (delta.reasoningContent.signature) {
                // The signature arrives once per thinking block, so this logs once rather than per delta.
                this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
            }
        } else if (delta) {
            const type = Object.keys(delta).find(
                key => key !== '$unknown' && (delta as any)[key] !== undefined
            );
            this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
        }
    }

    // Content block stop: release tool tracking for this block.
    if (result.contentBlockStop) {
        const blockIndex = result.contentBlockStop.contentBlockIndex ?? -1;
        streamingToolBlocks?.delete(blockIndex);
        // NOTE(review): `reasoning` is only set by start/delta events. This only applies if an
        // event carries both, which the SDK's union type suggests never happens.
        if (reasoning && !reasoning.endsWith('\n\n') && shouldIncludeThoughts) {
            reasoning += '\n\n';
        }
    }

    if (result.messageStop) {
        stop_reason = result.messageStop.stopReason ?? "";
    }

    if (result.metadata) {
        const usage = result.metadata.usage;
        token_usage = {
            prompt_new: usage?.inputTokens,
            prompt: usage ? (usage.inputTokens ?? 0) + (usage.cacheReadInputTokens ?? 0) + (usage.cacheWriteInputTokens ?? 0) : undefined,
            result: usage?.outputTokens,
            total: usage?.totalTokens,
            prompt_cached: usage?.cacheReadInputTokens ?? undefined,
            prompt_cache_write: usage?.cacheWriteInputTokens ?? undefined,
        };
    }

    const text = reasoning + output;
    const completionResult: CompletionChunkObject = {
        result: text ? [{ type: "text", value: text }] : [],
        token_usage: token_usage,
        finish_reason: converseFinishReason(stop_reason),
        tool_use,
    };
    return completionResult;
}
/**
 * Best-effort extraction of the AWS region referenced by a model identifier.
 *
 * 1. If `modelString` contains a Bedrock ARN, the ARN's region component is returned.
 * 2. Otherwise the first substring that looks like an AWS region code is returned
 *    (e.g. `us-east-1`, `ap-southeast-2`, `us-gov-west-1`, `il-central-1`).
 * 3. Otherwise `defaultRegion` is returned.
 *
 * @param modelString - Model id, inference profile id or ARN.
 * @param defaultRegion - Region to fall back to.
 * @returns The detected region, or `defaultRegion`.
 */
extractRegion(modelString: string, defaultRegion: string): string {
    // Region component of a full ARN (covers aws, aws-cn, aws-us-gov partitions).
    const arnMatch = /arn:aws[^:]*:bedrock:([^:]+):/.exec(modelString);
    if (arnMatch) {
        return arnMatch[1];
    }
    // Bare region code: <area>[-gov]-<direction>-<n>. Besides the commercial areas this also
    // accepts GovCloud (us-gov-*), Israel (il-*), Mexico (mx-*) and China (cn-*) region codes,
    // which the previous pattern could not match.
    const regionMatch = /(?:us|eu|ap|sa|ca|me|af|il|mx|cn)-(?:gov-)?(?:east|west|central|south|north|southeast|southwest|northeast|northwest)-[1-9]/.exec(modelString);
    return regionMatch ? regionMatch[0] : defaultRegion;
}
/**
 * Query the Bedrock control plane to find out whether `model` supports response streaming.
 *
 * The lookup region comes from the model identifier when it contains one (see
 * extractRegion); otherwise the driver region is used.
 * - Foundation models are checked directly via GetFoundationModel.
 * - Inference profiles are resolved to their first underlying model, which is then checked.
 * - Custom models are resolved to their base model, which is then checked.
 *
 * With `BedrockModelType.Unknown`, the three lookups are tried in that order until one
 * succeeds. This method does not throw: failures are logged and reported as `false`.
 *
 * @param model - Model id or ARN.
 * @param type - The identifier kind, or Unknown to try every lookup.
 * @returns Whether streaming is supported; false when it cannot be determined.
 */
private async getCanStream(model: string, type: BedrockModelType): Promise<boolean> {
    const region = this.extractRegion(model, this.options.region);
    const shouldTry = (kind: BedrockModelType) => type === kind || type === BedrockModelType.Unknown;
    let lastError: unknown = null;

    if (shouldTry(BedrockModelType.FoundationModel)) {
        try {
            const response = await this.getService(region).getFoundationModel({
                modelIdentifier: model
            });
            return response.modelDetails?.responseStreamingSupported ?? false;
        } catch (e) {
            lastError = e;
        }
    }

    if (shouldTry(BedrockModelType.InferenceProfile)) {
        try {
            const response = await this.getService(region).getInferenceProfile({
                inferenceProfileIdentifier: model
            });
            // `models` may be missing or empty. Guard the element access (it used to throw a
            // TypeError) and skip the lookup when there is no ARN to check.
            const baseModelArn = response.models?.[0]?.modelArn;
            if (!baseModelArn) {
                this.logger.warn({ model, region }, "[Bedrock] Inference profile has no underlying model; assuming streaming is not supported");
                return false;
            }
            return await this.getCanStream(baseModelArn, BedrockModelType.FoundationModel);
        } catch (e) {
            lastError = e;
        }
    }

    if (shouldTry(BedrockModelType.CustomModel)) {
        try {
            const response = await this.getService(region).getCustomModel({
                modelIdentifier: model
            });
            const baseModelArn = response.baseModelArn;
            if (!baseModelArn) {
                this.logger.warn({ model, region }, "[Bedrock] Custom model has no base model; assuming streaming is not supported");
                return false;
            }
            return await this.getCanStream(baseModelArn, BedrockModelType.FoundationModel);
        } catch (e) {
            lastError = e;
        }
    }

    if (lastError) {
        // Use the driver logger like the rest of the class, not console.warn.
        this.logger.warn({ error: lastError, model, region }, "[Bedrock] Error on canStream check");
    }
    return false;
}
/**
 * Whether `options.model` supports streaming, memoized per model string in the
 * module-level LRU cache. The identifier kind is guessed from the model string so that
 * only the relevant control-plane lookup is attempted; otherwise every lookup is tried.
 * TwelveLabs Pegasus is not special-cased and goes through the same lookup.
 */
protected async canStream(options: ExecutionOptions): Promise<boolean> {
    const model = options.model;
    const cached = supportStreamingCache.get(model);
    if (cached != null) {
        return cached;
    }
    // The enum values double as the markers to look for in the model string, checked in this order.
    const markers = [
        BedrockModelType.FoundationModel,
        BedrockModelType.InferenceProfile,
        BedrockModelType.CustomModel,
    ];
    const type = markers.find(marker => model.includes(marker)) ?? BedrockModelType.Unknown;
    const supported = await this.getCanStream(model, type);
    supportStreamingCache.set(model, supported);
    return supported;
}
/**
 * Rebuild the conversation state after a streamed Converse completion.
 *
 * Steps:
 * 1. Flatten the accumulated stream results into one assistant text block. Images are
 *    skipped; JSON values are stringified.
 * 2. Append any accumulated tool calls as `toolUse` blocks.
 * 3. Append that assistant turn to the stored conversation plus the current prompt.
 * 4. Increment the turn counter.
 * 5. Apply the same binary / large-text / heartbeat stripping as the non-streaming path.
 *
 * NOTE(review): reasoning text emitted during the stream is part of the text results, so it
 * ends up in the stored assistant text. The non-streaming path drops reasoning blocks from
 * history instead.
 *
 * @returns The updated conversation, or undefined for Nova Canvas and TwelveLabs Pegasus,
 *   whose prompts are not Converse requests.
 */
buildStreamingConversation(
    prompt: BedrockPrompt,
    result: unknown[],
    toolUse: unknown[] | undefined,
    options: ExecutionOptions
): ConverseRequest | undefined {
    const { model } = options;
    if (model.includes("canvas") || model.includes("twelvelabs.pegasus")) {
        return undefined;
    }
    const conversePrompt = prompt as ConverseRequest;

    const textContent = (result as CompletionResult[])
        .map(entry => {
            if (entry.type === 'text') {
                return entry.value;
            }
            if (entry.type === 'json') {
                return typeof entry.value === 'string' ? entry.value : JSON.stringify(entry.value);
            }
            if (entry.type === 'image') {
                // Images stay in the completion result; they are not replayed as conversation text.
                return '';
            }
            return String((entry as any).value || '');
        })
        .join('');

    // Stored conversations keep binary data base64-encoded; restore Uint8Array before merging.
    const history = deserializeBinaryFromStorage(options.conversation) as ConverseRequest;
    const withPrompt = updateConversation(history, conversePrompt);

    const toolBlocks = ((toolUse ?? []) as ToolUse[]).map(tool => ({
        toolUse: {
            toolUseId: tool.id,
            name: tool.tool_name,
            input: tool.tool_input,
        }
    }));
    const blocks: any[] = textContent ? [{ text: textContent }, ...toolBlocks] : [...toolBlocks];

    const assistantTurn: ConverseRequest = {
        messages: [{
            content: blocks.length > 0 ? blocks : [{ text: '' }],
            role: "assistant"
        }],
        modelId: conversePrompt.modelId,
    };
    const withAnswer = updateConversation(withPrompt, assistantTurn);
    const conversation = incrementConversationTurn(withAnswer) as ConverseRequest;

    const currentTurn = getConversationMeta(conversation).turnNumber;
    const stripOptions = {
        keepForTurns: options.stripImagesAfterTurns ?? Infinity,
        currentTurn,
        textMaxTokens: options.stripTextMaxTokens
    };
    const withoutBinary = stripBinaryFromConversation(conversation, stripOptions);
    const truncated = truncateLargeTextInConversation(withoutBinary, stripOptions);
    const withoutHeartbeats = stripHeartbeatsFromConversation(truncated, {
        keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
        currentTurn,
    });
    return withoutHeartbeats as ConverseRequest;
}
/**
 * Run a non-streaming completion.
 *
 * TwelveLabs Pegasus models are delegated to `invokeModel`. All other models use the
 * Converse API:
 * - The stored conversation (binary data deserialized) is extended with the prompt and sent.
 * - The assistant reply is appended, and the turn counter is incremented.
 * - Binary, large-text and heartbeat stripping are applied to the result.
 * - Tool calls are extracted from the reply regardless of the finish reason.
 *
 * Reasoning blocks are removed from the assistant message stored in the conversation,
 * because DeepSeek R1 rejects them in later turns. They remain in the response so that
 * getExtractedExecution and `original_response` can still see them.
 */
async requestTextCompletion(prompt: BedrockPrompt, options: ExecutionOptions): Promise<Completion> {
    if (options.model.includes("twelvelabs.pegasus")) {
        return this.requestTwelvelabsPegasusCompletion(prompt as TwelvelabsPegasusRequest, options);
    }

    const conversePrompt = prompt as ConverseRequest;
    // Deserialize any base64-encoded binary data back to Uint8Array before the API call.
    const incomingConversation = deserializeBinaryFromStorage(options.conversation) as ConverseRequest;
    let conversation = updateConversation(incomingConversation, conversePrompt);

    const payload = this.preparePayload(conversation, options);
    const executor = this.getExecutor();
    const res = await executor.converse({
        ...payload,
    });

    // Build a filtered COPY for the history. Filtering res.output.message.content in place
    // removed the reasoning blocks before getExtractedExecution ran, so include_thoughts and
    // DeepSeek R1 reasoning were silently lost.
    const replyMessage = res.output?.message;
    const assistantMsg: Message = replyMessage
        ? { ...replyMessage, content: replyMessage.content?.filter(c => !c.reasoningContent) }
        : { content: [{ text: "" }], role: "assistant" };

    conversation = updateConversation(conversation, {
        messages: [assistantMsg],
        modelId: conversePrompt.modelId,
    });
    // Increment turn counter for deferred stripping.
    conversation = incrementConversationTurn(conversation) as ConverseRequest;

    // Tool calls are collected regardless of finish reason: a response cut off by `length`
    // can still carry valid tool calls.
    const toolCalls: ToolUse[] = [];
    for (const block of res.output?.message?.content ?? []) {
        if (block.toolUse) {
            toolCalls.push({
                tool_name: block.toolUse.name ?? "",
                tool_input: block.toolUse.input as any,
                id: block.toolUse.toolUseId ?? "",
            } satisfies ToolUse);
        }
    }
    const tool_use = toolCalls.length > 0 ? toolCalls : undefined;

    // Strip/serialize binary data, truncate large text and drop old heartbeats.
    const currentTurn = getConversationMeta(conversation).turnNumber;
    const stripOptions = {
        keepForTurns: options.stripImagesAfterTurns ?? Infinity,
        currentTurn,
        textMaxTokens: options.stripTextMaxTokens
    };
    let processedConversation = stripBinaryFromConversation(conversation, stripOptions);
    processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions);
    processedConversation = stripHeartbeatsFromConversation(processedConversation, {
        keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
        currentTurn,
    });

    const completion = {
        ...this.getExtractedExecution(res, conversePrompt, options),
        original_response: options.include_original_response ? res : undefined,
        conversation: processedConversation,
        tool_use: tool_use,
    };
    return completion;
}
/**
 * Non-streaming TwelveLabs Pegasus completion via `invokeModel`.
 * The JSON response's `message` becomes the text result, and its `finishReason` is
 * forwarded unchanged.
 */
private async requestTwelvelabsPegasusCompletion(prompt: TwelvelabsPegasusRequest, options: ExecutionOptions): Promise<Completion> {
    const response = await this.getExecutor().invokeModel({
        modelId: options.model,
        contentType: "application/json",
        accept: "application/json",
        body: JSON.stringify(prompt),
    });
    const payload = JSON.parse(new TextDecoder().decode(response.body));
    // "stop" and "length" already match the llumiverse vocabulary, and other values are
    // passed through, so the reason needs no mapping.
    const finishReason: string | undefined = payload.finishReason;
    return {
        result: payload.message ? [{ type: "text" as const, value: payload.message }] : [],
        finish_reason: finishReason,
        original_response: options.include_original_response ? payload : undefined,
    };
}
/**
 * Streaming TwelveLabs Pegasus completion via `invokeModelWithResponseStream`.
 *
 * Each stream chunk is decoded as JSON. Its `delta` (or else `message`) becomes the text
 * result, and a non-empty `finishReason` is forwarded unchanged. Chunks without bytes, and
 * chunks that are not valid JSON, produce an empty result. Invalid chunks are also logged
 * so the stream continues without hiding the problem.
 *
 * @throws Error when the response carries no stream body.
 */
private async requestTwelvelabsPegasusCompletionStream(prompt: TwelvelabsPegasusRequest, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
    const executor = this.getExecutor();
    const res = await executor.invokeModelWithResponseStream({
        modelId: options.model,
        contentType: "application/json",
        accept: "application/json",
        body: JSON.stringify(prompt),
    });
    if (!res.body) {
        throw new Error("[Bedrock] Stream not found in response");
    }
    // One decoder for the whole stream (non-streaming decode() keeps no state between calls).
    const decoder = new TextDecoder();
    return transformAsyncIterator(res.body, (chunk: any): CompletionChunkObject => {
        const bytes = chunk.chunk?.bytes;
        if (!bytes) {
            return { result: [] };
        }
        let parsed: any;
        try {
            parsed = JSON.parse(decoder.decode(bytes));
        } catch (error) {
            // Keep the stream going, but don't drop the chunk silently.
            this.logger.warn({ error }, "[Bedrock] Failed to parse TwelveLabs Pegasus stream chunk");
            return { result: [] };
        }
        // Optional chaining: a literal `null` payload must not crash the stream.
        const text = parsed?.delta || parsed?.message;
        return {
            result: text ? [{ type: "text", value: text }] : [],
            // "stop"/"length" already match llumiverse; other values pass through unchanged.
            finish_reason: parsed?.finishReason || undefined,
        };
    });
}
/**
 * Start a streaming completion.
 *
 * TwelveLabs Pegasus uses invokeModelWithResponseStream. Every other model uses ConverseStream
 * with the stored conversation plus the prompt, as in the non-streaming path. Each stream event
 * is mapped through getExtractedStream, sharing one tool-block map per stream. Failures while
 * opening the stream are logged and rethrown.
 */
async requestTextCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
    if (options.model.includes("twelvelabs.pegasus")) {
        return this.requestTwelvelabsPegasusCompletionStream(prompt as TwelvelabsPegasusRequest, options);
    }

    const conversePrompt = prompt as ConverseRequest;
    // Restore base64-encoded binary data before sending the history back to Bedrock.
    const history = deserializeBinaryFromStorage(options.conversation) as ConverseRequest;
    const payload = this.preparePayload(updateConversation(history, conversePrompt), options);
    const executor = this.getExecutor();

    const pending = executor.converseStream({
        ...payload,
    });
    try {
        const { stream } = await pending;
        if (!stream) {
            throw new Error("[Bedrock] Stream not found in response");
        }
        const toolBlocksByIndex = new Map<number, { id: string; name: string }>();
        return transformAsyncIterator(stream, (event: ConverseStreamOutput) =>
            this.getExtractedStream(event, conversePrompt, options, toolBlocksByIndex)
        );
    } catch (err) {
        this.logger.error({ error: err }, "[Bedrock] Failed to stream");
        throw err;
    }
}
preparePayload(prompt: ConverseRequest, options: ExecutionOptions) {
const model_options: TextFallbackOptions = options.model_options as TextFallbackOptions ?? { _option_id: "text-fallback" };
let additionalField = {};
let supportsJSONPrefill = false;
// Resolve thinking, effort, and sampling restrictions using shared Claude helper
const claudeThinking = resolveClaudeThinking(options.model, options.model_options as BedrockClaudeOptions | undefined);
const hasSamplingRestriction = claudeThinking.hasSamplingRestriction;
if (options.model.includes("amazon")) {
supportsJSONPrefill = true;
//Titan models also exists but does not support any additional options
if (options.model.includes("nova")) {
additionalField = { inferenceConfig: { topK: model_options.top_k } };
}
} else if (options.model.includes("claude")) {
const claude_options = model_options as ModelOptions as BedrockClaudeOptions;
// Thinking is active when extended (budget set) or adaptive (effort set) thinking is enabled.
// JSON prefill is incompatible with active thinking.
const thinkingActive = claudeThinking.thinking != null && claudeThinking.thinking.type !== "disabled";
supportsJSONPrefill = !thinkingActive
// Claude 3.7+ supports thinking — use shared helper for reasoning_config
if (claudeThinking.supportsThinking) {
if (claudeThinking.thinking) {
additionalField = {
...additionalField,
reasoning_config: claudeThinking.thinking,
};
}
// For Claude 3.7 with extended thinking + high output, add beta header
if (claudeThinking.thinking?.type === "enabled" &&
options.model.includes("claude-3-7-sonnet") &&
((claude_options.max_tokens ?? 0) > 64000 || (claude_options.thinking_budget_tokens ?? 0) > 64000)) {
additionalField = {
...additionalField,
anthropic_beta: ["output-128k-2025-02-19"]
};
}
}
// Add effort parameter via output_config (Opus 4.5+, Sonnet 4.6+, all 4.7+)
if (claudeThinking.outputConfig) {
additionalField = {
...additionalField,
output_config: claudeThinking.outputConfig
};
}
// Claude 4.6 and later versions don't support JSON prefill
if (isClaudeVersionGTE(options.model, 4, 6)) {
supportsJSONPrefill = false;
}
// Needs max_tokens to be set
if (!model_options.max_tokens) {
model_options.max_tokens = maxTokenFallbackClaude(options);
}
// Only models without sampling restrictions support top_k
if (!hasSamplingRestriction) {
additionalField = { ...additionalField, top_k: model_options.top_k };
}
} else if (options.model.includes("meta")) {
//LLaMA models support no additional options
} else if (options.model.includes("mistral")) {
//7B instruct and 8x7B instruct
if (options.model.includes("7b")) {
additionalField = { top_k: model_options.top_k };
//Does not support system messages
if (prompt.system && prompt.system?.length !== 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
} else {
//Other models such as Mistral Small,Large and Large 2
//Support no additional fields.
}
} else if (options.model.includes("ai21")) {
//Jamba models support no additional options
//Jurassic 2 models do.
if (options.model.includes("j2")) {
additionalField = {
presencePenalty: { scale: model_options.presence_penalty },
frequencyPenalty: { scale: model_options.frequency_penalty },
};
//Does not support system messages
if (prompt.system && prompt.system?.length !== 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
} else if (options.model.includes("cohere.command")) {
// If last message is "```json", remove it.
//Command R and R plus
if (options.model.includes("cohere.command-r")) {
additionalField = {
k: model_options.top_k,
frequency_penalty: model_options.frequency_penalty,
presence_penalty: model_options.presence_penalty,
};
} else {
// Command non-R
additionalField = { k: model_options.top_k };
//Does not support system messages
if (prompt.system && prompt.system?.length !== 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
} else if (options.model.includes("palmyra")) {
const palmyraOptions = model_options as ModelOptions as BedrockPalmyraOptions;
additionalField = {
seed: palmyraOptions?.seed,
presence_penalty: palmyraOptions?.presence_penalty,
frequency_penalty: palmyraOptions?.frequency_penalty,
min_tokens: palmyraOptions?.min_tokens,
}
} else if (options.model.includes("deepseek")) {
// DeepSeek models: no additional options, no stopSequences, only one of temperature/top_p
model_options.stop_sequence = undefined;
model_options.top_p = undefined;
} else if (options.model.includes("gpt-oss")) {
const gptOssOptions = model_options as ModelOptions as BedrockGptOssOptions;
additionalField = {
reasoning_effort: gptOssOptions?.reasoning_effort,
};
}
//If last message is "```json", add corresponding ``` as a stop sequence.
if (prompt.messages && prompt.messages.length > 0) {
if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
const stopSeq = model_options.stop_sequence;
if (!stopSeq) {
model_options.stop_sequence = ["```"];
} else if (!stopSeq.includes("```")) {
stopSeq.push("```");
model_options.stop_sequence = stopSeq;
}
}
}
const tool_defs = getToolDefinitions(options.tools);
// Use prefill when there is a schema and tools are not being used
if (supportsJSONPrefill && options.result_schema && !tool_defs) {
prompt.messages = converseJSONprefill(prompt.messages);
}
// Clean undefined values from additionalField since AWS Bedrock requires valid JSON
// and will throw an exception for unrecognized parameters
const cleanedAdditionalFields = removeUndefinedValues(additionalField);
// Models with sampling parameter restrictions don't support temperature/top_p - exclude them from inference config
const cleanedModelOptions = removeUndefinedValues({
maxTokens: model_options.max_tokens,
...(hasSamplingRestriction ? {} : {
temperature: model_options.temperature,
topP: model_options.temperature != null ? undefined : model_options.top_p,
}),
stopSequences: model_options.stop_sequence,
} satisfies InferenceConfiguration);
//Construct the final request payload
// We only add fields that are defined to avoid AWS errors
const request: ConverseRequest = {
modelId: options.model,
};
if (prompt.messages) {
request.messages = prompt.messages;
}
if (prompt.system) {
request.system = prompt.system;
}
if (Object.keys(cleanedModelOptions).length > 0) {
request.inferenceConfig = cleanedModelOptions
}
if (Object.keys(cleanedAdditionalFields).length > 0) {
request.additionalModelRequestFields = cleanedAdditionalFields;
}
if (tool_defs?.length) {
request.toolConfig = {
tools: tool_defs,
}
} else if (request.messages && messagesContainToolBlocks(request.messages)) {
// Bedrock requires toolConfig when conversation contains toolUse/toolResult blocks.
// When no tools are provided (e.g. checkpoint summary calls), convert tool blocks
// to text representations so the conversation data is preserved while satisfying
// Bedrock's API requirements without making tools callable.
request.messages = convertToolBlocksToText(request.messages);
}
// Prompt caching: use three breakpoints so stable system blocks, tool definitions,
// and the conversation history prefix can all be reused across Claude turns.
if (options.model.includes('claude')) {
// Always strip stale markers from prior turns
if (request.messages) {
request.messages = stripClaudeCachePoints(request.messages);
}
request.system = stripClaudeCachePointsFromSystem(request.system);
if (request.toolConfig?.tools) {
request.toolConfig = {
...request.toolConfig,
tools: stripClaudeCachePointsFromTools(request.toolConfig.tools),
};
}
const claudeOptions = model_options as unknown as BedrockClaudeOptions;
const cacheEnabled = claudeOptions?.cache_enabled === true;
if (cacheEnabled) {
const cacheTtl = claudeOptions?.cache_ttl;
const cachePointBlock = { type: 'default' as const, ...(cacheTtl && { ttl: cacheTtl }) };
if (request.system && request.system.length > 0) {
request.system = [...request.system, { cachePoint: cachePointBlock } satisfies BedrockSystemBlock];
}
if (request.toolConfig?.tools && request.toolConfig.tools.length > 0) {
request.toolConfig.tools = [
...request.toolConfig.tools,
{ cachePoint: cachePointBlock } satisfies BedrockToolEntry,
];
}
if (request.messages && request.messages.length >= 4) {
const pivotMsg = request.messages[request.messages.length - 2];
if (pivotMsg.content && Array.isArray(pivotMsg.content) && pivotMsg.content.length > 0) {
pivotMsg.content = [...pivotMsg.content, { cachePoint: cachePointBlock }];
}
}
}
}
return request;
}
/**
 * Tells whether a model id refers to an image-generation model.
 *
 * @param model - Bedrock model id (or ARN) to inspect.
 * @returns `true` when the id contains "titan-image", "stable-diffusion" or "nova-canvas".
 */
protected isImageModel(model: string): boolean {
    const imageModelMarkers = ["titan-image", "stable-diffusion", "nova-canvas"];
    return imageModelMarkers.some((marker) => model.includes(marker));
}
/**
 * Generates images by calling `invokeModel` with a Nova Canvas style JSON payload.
 *
 * @param prompt - Nova messages prompt; a plain string prompt is rejected.
 * @param options - Execution options. `model_options` is read as NovaCanvasOptions and may be
 *   omitted, in which case the task type defaults to `NovaImageGenerationTaskType.TEXT_IMAGE`.
 * @returns A completion with one `image` result per entry of the response's `images` array,
 *   plus the response's `error` field (if any).
 * @throws Error("Bad prompt format") when `prompt` is a string.
 */
async requestImageGeneration(prompt: NovaMessagesPrompt, options: ExecutionOptions): Promise<Completion> {
    // Reject an unusable prompt before logging or creating the executor.
    if (typeof prompt === "string") {
        throw new Error("Bad prompt format");
    }
    if (options.model_options?._option_id !== undefined && options.model_options?._option_id !== "bedrock-nova-canvas") {
        this.logger.debug({ options: options.model_options }, "Unexpected option id");
    }
    // model_options is optional (see the `?.` above), so guard before reading taskType.
    const model_options = options.model_options as NovaCanvasOptions | undefined;
    const taskType = model_options?.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
    this.logger.info("Task type: " + taskType);

    const executor = this.getExecutor();
    const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);
    const res = await executor.invokeModel({
        modelId: options.model,
        contentType: "application/json",
        accept: "application/json",
        body: JSON.stringify(payload),
    },
    {
        // Allow up to 5 minutes for the image generation request.
        requestTimeout: 60000 * 5
    });

    const bedrockResult = JSON.parse(new TextDecoder().decode(res.body));
    // If the response carries an `error`, `images` may be missing; default to an empty list
    // so the error is returned to the caller instead of being masked by a TypeError on `.map`.
    const images: string[] = bedrockResult.images ?? [];
    return {
        error: bedrockResult.error,
        result: images.map((image) => ({
            type: "image" as const,
            value: image
        }))
    };
}
async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
//convert options.params to Record<string, string>
const params: Record<string, string> = {};
for (const [key, value] of Object.entries(options.params || {})) {
params[key] = String(value);
}
if (!this.options.training_bucket) {
throw new Error("Training cannot nbe used