UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

1,132 lines (1,000 loc) 74.7 kB
import { Bedrock, CreateModelCustomizationJobCommand, type FoundationModelSummary, GetModelCustomizationJobCommand, type GetModelCustomizationJobCommandOutput, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
import { BedrockRuntime, type ContentBlock, type ConverseRequest, type ConverseResponse, type ConverseStreamOutput, type InferenceConfiguration, type Message, type Tool } from "@aws-sdk/client-bedrock-runtime";
import { S3Client } from "@aws-sdk/client-s3";
import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types";
import { AbstractDriver, type AIModel, type BedrockClaudeOptions, type BedrockGptOssOptions, type BedrockPalmyraOptions, type Completion, type CompletionChunkObject, type CompletionResult, type DataSource, deserializeBinaryFromStorage, type DriverOptions, type EmbeddingsOptions, type EmbeddingsResult, type ExecutionOptions, type ExecutionTokenUsage, getConversationMeta, getMaxTokensLimitBedrock, getModelCapabilities, incrementConversationTurn, isClaudeVersionGTE, LlumiverseError, type LlumiverseErrorContext, modelModalitiesToArray, type ModelOptions, type NovaCanvasOptions, type PromptSegment, type StatelessExecutionOptions, stripBinaryFromConversation, stripHeartbeatsFromConversation, type TextFallbackOptions, type ToolDefinition, type ToolUse, type TrainingJob, TrainingJobStatus, type TrainingOptions, truncateLargeTextInConversation } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { formatNovaPrompt, type NovaMessagesPrompt } from "@llumiverse/core/formatters";
import { LRUCache } from "mnemonist";
import { resolveClaudeThinking } from "../shared/claude-thinking.js";
import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
import { forceUploadFile } from "./s3.js";
import { formatTwelvelabsPegasusPrompt, type TwelvelabsMarengoRequest, type TwelvelabsMarengoResponse, type TwelvelabsPegasusRequest } from "./twelvelabs.js";

/**
 * Memoized "does this model support response streaming?" answers, keyed by the model
 * identifier passed in the execution options. Bounded LRU with capacity 4096; read and
 * written by `BedrockDriver.canStream`.
 */
const supportStreamingCache = new LRUCache<string, boolean>(4096);

/**
 * Kind of model identifier being probed in `getCanStream`. Each kind selects which Bedrock
 * metadata API is queried (foundation model, inference profile, custom model); `Unknown`
 * tries all of them in that order.
 */
enum BedrockModelType {
    FoundationModel = "foundation-model",
    InferenceProfile = "inference-profile",
    CustomModel = "custom-model",
    Unknown = "unknown",
};

/**
 * Normalize a Converse API `stopReason` to the finish reason reported by llumiverse.
 * `end_turn` maps to "stop" and `max_tokens` to "length"; every other reason is returned unchanged.
 *
 * @param reason - The raw stop reason (may be undefined or empty).
 * @returns The normalized reason, or undefined when `reason` is falsy.
 */
function converseFinishReason(reason: string | undefined) {
    //Possible values:
    //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
    if (!reason) return undefined;
    switch (reason) {
        case 'end_turn': return "stop";
        case 'max_tokens': return "length";
        default: return reason;
    }
}

/** A model name paired with whether it supports streaming responses. */
export interface BedrockModelCapabilities {
    name: string;
    canStream: boolean;
}

/** Options accepted by the Bedrock driver, on top of the generic driver options. */
export interface BedrockDriverOptions extends DriverOptions {
    /**
     * The AWS region. Used for the BedrockRuntime client and as the default region of the
     * Bedrock service client; required (the driver constructor throws without it).
     */
    region: string;
    /**
     * The bucket name to be used for training.
     * It will be created if it does not already exist.
     */
    training_bucket?: string;
    /**
     * The role ARN to be used for training
     */
    training_role_arn?: string;
    /**
     * The credentials to use to access AWS (IAM access key + secret)
     */
    credentials?: AwsCredentialIdentity | Provider<AwsCredentialIdentity>;
}

//Used to get a max_token value when not specified in the model options. Claude requires it to be set.
/**
 * Pick the `max_tokens` value to send for a Claude model.
 *
 * @param option - Execution options; `model_options.max_tokens` and
 *   `model_options.thinking_budget_tokens` are read.
 * @returns The explicit `max_tokens` when it is a number. Otherwise the model's Bedrock
 *   max-token limit (8192 if none is known), except for Claude 3.7 Sonnet with a thinking
 *   budget below 48000 (or none), which gets 64000.
 */
function maxTokenFallbackClaude(option: StatelessExecutionOptions): number {
    const modelOptions = option.model_options as BedrockClaudeOptions | undefined;
    if (modelOptions && typeof modelOptions.max_tokens === "number") {
        return modelOptions.max_tokens;
    } else {
        let maxSupportedTokens = getMaxTokensLimitBedrock(option.model) ?? 8192; // Should always return a number for claude, 8192 is to satisfy the TypeScript type checker;
        // Fallback to the default max tokens limit for the model
        if (option.model.includes('claude-3-7-sonnet') && (modelOptions?.thinking_budget_tokens ?? 0) < 48000) {
            maxSupportedTokens = 64000; // Claude 3.7 can go up to 128k with a beta header, but when no max tokens is specified, we default to 64k.
        }
        return maxSupportedTokens;
    }
}

/** Prompt produced by `formatPrompt`: a Nova messages prompt, a Converse request, or a TwelveLabs Pegasus request. */
export type BedrockPrompt = NovaMessagesPrompt | ConverseRequest | TwelvelabsPegasusRequest;

/** Element type of a Converse request's `system` array. */
type BedrockSystemBlock = NonNullable<ConverseRequest['system']>[number];
/** Element type of a Converse request's `toolConfig.tools` array. */
type BedrockToolEntry = NonNullable<NonNullable<ConverseRequest['toolConfig']>['tools']>[number];

/**
 * Llumiverse driver for Amazon Bedrock. Text completions go through the Converse /
 * ConverseStream APIs, except TwelveLabs Pegasus models, which are called through
 * InvokeModel / InvokeModelWithResponseStream.
 */
export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockPrompt> {
    static PROVIDER = "bedrock";
    provider = BedrockDriver.PROVIDER;
    // BedrockRuntime client, created lazily by getExecutor()
    private _executor?: BedrockRuntime;
    // Bedrock service client and the region it was built for, managed by getService()
    private _service?: Bedrock;
    private _service_region?: string;

    /**
     * @param options - Driver options; `region` is mandatory.
     * @throws Error when `options.region` is missing or empty.
     */
    constructor(options: BedrockDriverOptions) {
        super(options);
        if (!options.region) {
            throw new Error("No region found. Set the region in the environment's endpoint URL.");
        }
    }

    /**
     * Return the BedrockRuntime client, creating it on first use from the driver's region and credentials.
     */
    getExecutor() {
        if (!this._executor) {
            this._executor = new BedrockRuntime({
                region: this.options.region,
                credentials: this.options.credentials,
            });
        }
        return this._executor;
    }

    /**
     * Return a `Bedrock` service client (used here for model metadata lookups) for `region`.
     * Only one client is cached: asking for a different region than the cached one replaces it.
     *
     * @param region - Target region; defaults to the driver's configured region.
     */
    getService(region: string = this.options.region) {
        if (!this._service || this._service_region !== region) {
            this._service = new Bedrock({
                region: region,
                credentials: this.options.credentials,
            });
            this._service_region = region;
        }
        return this._service;
    }

    /**
     * Build the provider prompt from prompt segments: a Nova prompt for model ids containing
     * "canvas", a TwelveLabs Pegasus request for "twelvelabs.pegasus", and a Converse request otherwise.
     */
    protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<BedrockPrompt> {
        if (opts.model.includes("canvas")) {
            return await formatNovaPrompt(segments, opts.result_schema);
        }
        if (opts.model.includes("twelvelabs.pegasus")) {
            return await formatTwelvelabsPegasusPrompt(segments, opts);
        }
        return await formatConversePrompt(segments, opts);
    }

    /**
     * Format AWS Bedrock errors into LlumiverseError with proper status codes and retryability.
*
     * AWS SDK errors provide:
     * - error.name: The exception type (e.g., "ThrottlingException")
     * - error.$metadata.httpStatusCode: The HTTP status code
     * - error.$metadata.requestId: The AWS request ID for tracking
     * - error.$fault: "client" or "server" indicating error category
     *
     * Errors without `$metadata` are not treated as AWS SDK errors and are handed to the base
     * implementation. For AWS errors the message is assembled as
     * `[<provider>] <ErrorName>: [<status>] <message> (Request ID: <id>)`. The error-name prefix
     * is omitted for "Error" / "UnknownError", and the status and request-id parts are omitted
     * when those values are falsy.
     *
     * @param error - The AWS SDK error
     * @param context - Context about where the error occurred
     * @returns A standardized LlumiverseError
     */
    public formatLlumiverseError(
        error: unknown,
        context: LlumiverseErrorContext
    ): LlumiverseError {
        // Check if it's an AWS SDK error with $metadata
        const awsError = error as any;
        const hasMetadata = awsError?.$metadata !== undefined;

        if (!hasMetadata) {
            // Not an AWS SDK error, use default handling
            return super.formatLlumiverseError(error, context);
        }

        // Extract AWS-specific fields
        const errorName = awsError.name || 'UnknownError';
        const httpStatusCode = awsError.$metadata?.httpStatusCode;
        const requestId = awsError.$metadata?.requestId;
        const fault = awsError.$fault; // "client" or "server"

        // Extract error message - handle both Error instances and plain objects
        let message: string;
        if (error instanceof Error) {
            message = error.message;
        } else if (typeof awsError.message === 'string') {
            message = awsError.message;
        } else {
            message = String(error);
        }

        // Build user-facing message with error name and status code
        let userMessage = message;

        // Include status code in message if available (for end-user visibility)
        if (httpStatusCode) {
            userMessage = `[${httpStatusCode}] ${userMessage}`;
        }

        // Prefix with error name if it's meaningful (not just "Error")
        if (errorName && errorName !== 'Error' && errorName !== 'UnknownError') {
            userMessage = `${errorName}: ${userMessage}`;
        }

        // Add request ID if available (useful for AWS support)
        if (requestId) {
            userMessage += ` (Request ID: ${requestId})`;
        }

        // Determine retryability based on AWS error types
        const retryable = this.isBedrockErrorRetryable(errorName, httpStatusCode, fault);

        return new LlumiverseError(
            `[${this.provider}] ${userMessage}`,
            retryable,
            context,
            error,
            httpStatusCode, // Only set code if we have numeric status code
            errorName // Preserve AWS error name
        );
    }

    /**
     * Determine if a Bedrock error is retryable based on error type and status.
     *
     * Checks are applied in order: known error names first, then the HTTP status code,
     * then the `$fault` classification.
     *
     * Retryable errors:
     * - ThrottlingException: Rate limit exceeded, retry with backoff
     * - ServiceUnavailableException: Service temporarily down
     * - InternalServerException: Server-side error
     * - ServiceQuotaExceededException: Quota exhausted, may recover
     * - 5xx status codes: Server errors
     * - 429, 408 status codes: Rate limit, timeout
     * - 529 status code: Overloaded
     * - fault "server" when nothing above applies
     *
     * Non-retryable errors:
     * - ValidationException: Invalid request parameters
     * - AccessDeniedException: Authentication/authorization failure
     * - ResourceNotFoundException: Resource doesn't exist
     * - ConflictException: Resource state conflict
     * - ResourceInUseException: Resource locked by another operation
     * - TooManyTagsException
     * - 4xx status codes (except 429, 408): Client errors
     * - fault "client" when nothing above applies
     *
     * @param errorName - The AWS error name (e.g., "ThrottlingException")
     * @param httpStatusCode - The HTTP status code if available
     * @param fault - The fault type ("client" or "server")
     * @returns True if retryable, false if not retryable, undefined if unknown
     */
    private isBedrockErrorRetryable(
        errorName: string,
        httpStatusCode: number | undefined,
        fault: string | undefined
    ): boolean | undefined {
        // Check specific AWS error types first
        switch (errorName) {
            // Retryable errors
            case 'ThrottlingException':
            case 'ServiceUnavailableException':
            case 'InternalServerException':
            case 'ServiceQuotaExceededException':
                return true;

            // Non-retryable errors
            case 'ValidationException':
            case 'AccessDeniedException':
            case 'ResourceNotFoundException':
            case 'ConflictException':
            case 'ResourceInUseException':
            case 'TooManyTagsException':
                return false;
        }

        // If we have HTTP status code, use it
        if (httpStatusCode !== undefined) {
            if (httpStatusCode === 429 || httpStatusCode === 408) return true; // Rate limit, timeout
            if (httpStatusCode === 529)
return true; // Overloaded
            if (httpStatusCode >= 500 && httpStatusCode < 600) return true; // Server errors
            if (httpStatusCode >= 400 && httpStatusCode < 500) return false; // Client errors
        }

        // Fall back to fault type
        if (fault === 'server') return true;
        if (fault === 'client') return false;

        // Unknown error type - let consumer decide retry strategy
        return undefined;
    }

    /**
     * Map Bedrock token usage (from a Converse response or a ConverseStream metadata event)
     * to llumiverse token usage.
     *
     * Bedrock's inputTokens already excludes cache-read tokens, so prompt_new is inputTokens
     * directly (no subtraction needed). prompt is the total including cached + cache_write for
     * consistency with the Vertex Claude driver.
     *
     * @param usage - Bedrock usage block; undefined yields an object whose fields are all undefined.
     */
    private toExecutionTokenUsage(usage: ConverseResponse['usage']): ExecutionTokenUsage {
        return {
            prompt_new: usage?.inputTokens,
            prompt: usage
                ? (usage.inputTokens ?? 0) + (usage.cacheReadInputTokens ?? 0) + (usage.cacheWriteInputTokens ?? 0)
                : undefined,
            result: usage?.outputTokens,
            total: usage?.totalTokens,
            prompt_cached: usage?.cacheReadInputTokens ?? undefined,
            prompt_cache_write: usage?.cacheWriteInputTokens ?? undefined,
        };
    }

    /**
     * Convert a non-streaming Converse response into a completion chunk.
     *
     * Text blocks are concatenated. Reasoning blocks are only emitted when `include_thoughts`
     * is set or the model id contains both "deepseek" and "r1". Emitted reasoning is placed
     * before the text and followed by a blank line. Redacted reasoning is rendered as
     * `[Redacted thinking: ...]`.
     *
     * @param result - The Converse API response.
     * @param _prompt - Unused.
     * @param options - Execution options; the model id and `include_thoughts` are read.
     * @returns Text result, token usage and normalized finish reason.
     */
    getExtractedExecution(result: ConverseResponse, _prompt?: BedrockPrompt, options?: ExecutionOptions): CompletionChunkObject {
        let resultText = "";
        let reasoning = "";

        // Extract reasoning content if include_thoughts is true, or if it's a
        // reasoning-only model (e.g. DeepSeek R1) that returns no text blocks
        const claudeOptions = options?.model_options as BedrockClaudeOptions;
        const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1');
        const includeReasoning = claudeOptions?.include_thoughts || isReasoningModel;

        if (result.output?.message?.content) {
            for (const content of result.output.message.content) {
                if (content.text) {
                    // Get text output
                    resultText += content.text;
                } else if (content.reasoningContent) {
                    if (includeReasoning) {
                        if (content.reasoningContent.reasoningText) {
                            // reasoningText.text may be undefined per the SDK typings; never append "undefined"
                            reasoning += content.reasoningContent.reasoningText.text ?? "";
                        } else if (content.reasoningContent.redactedContent) {
                            // Handle redacted thinking content
                            const redactedData = new TextDecoder().decode(content.reasoningContent.redactedContent);
                            reasoning += `[Redacted thinking: ${redactedData}]`;
                        }
                    } else {
                        this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
                    }
                } else {
                    // Get content block type
                    const type = Object.keys(content).find(
                        key => key !== '$unknown' && content[key as keyof typeof content] !== undefined
                    );
                    this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
                }
            }

            // Add spacing if we have reasoning content
            if (reasoning) {
                reasoning += '\n\n';
            }
        }

        const completionResult: CompletionChunkObject = {
            result: reasoning + resultText ? [{ type: "text", value: reasoning + resultText }] : [],
            token_usage: this.toExecutionTokenUsage(result.usage),
            finish_reason: converseFinishReason(result.stopReason),
        };

        return completionResult;
    }

    /**
     * Convert one ConverseStream event into a completion chunk.
     *
     * - contentBlockStart with toolUse: registers the block index in `streamingToolBlocks` and
     *   emits a tool_use entry carrying the id/name with empty input.
     * - contentBlockDelta: emits text, a tool-input fragment tagged with the registered tool id,
     *   or reasoning. Reasoning is emitted only when thoughts are included (model id contains
     *   "deepseek" and "r1", or `include_thoughts` is set).
     * - contentBlockStop: drops the tool block registered at that index.
     * - messageStop / metadata: finish reason and token usage.
     *
     * @param streamingToolBlocks - Map from content block index to tool call id/name, shared
     *   across all events of one stream.
     */
    getExtractedStream(result: ConverseStreamOutput, _prompt?: BedrockPrompt, options?: ExecutionOptions, streamingToolBlocks?: Map<number, { id: string; name: string }>): CompletionChunkObject {
        let output: string = "";
        let reasoning: string = "";
        let stop_reason = "";
        let token_usage: ExecutionTokenUsage | undefined;
        let tool_use: ToolUse[] | undefined;

        // Check if we should include thoughts (always true for reasoning-only models like DeepSeek R1)
        const isReasoningModel = options?.model?.includes('deepseek') && options?.model?.includes('r1');
        const shouldIncludeThoughts = isReasoningModel || (options && (options.model_options as BedrockClaudeOptions)?.include_thoughts);

        // Handle content block start events (for reasoning blocks and tool use)
        if (result.contentBlockStart) {
            if (result.contentBlockStart.start && 'toolUse' in result.contentBlockStart.start && result.contentBlockStart.start.toolUse) {
                // Register new tool call block and emit an initial chunk so the accumulator can track it by id
                const toolUseStart = result.contentBlockStart.start.toolUse;
                const blockIndex = result.contentBlockStart.contentBlockIndex ?? -1;
                const id = toolUseStart.toolUseId ?? '';
                const name = toolUseStart.name ?? '';
                streamingToolBlocks?.set(blockIndex, { id, name });
                tool_use = [{ id, tool_name: name, tool_input: '' as any }];
            } else if (result.contentBlockStart.start && 'reasoningContent' in result.contentBlockStart.start && shouldIncludeThoughts) {
                // Handle redacted content at block start
                const reasoningStart = result.contentBlockStart.start as any;
                if (reasoningStart.reasoningContent?.redactedContent) {
                    const redactedData = new TextDecoder().decode(reasoningStart.reasoningContent.redactedContent);
                    reasoning = `[Redacted thinking: ${redactedData}]`;
                }
            }
        }

        // Handle content block deltas (text, reasoning, and tool use)
        if (result.contentBlockDelta) {
            const delta = result.contentBlockDelta.delta;
            if (delta?.toolUse) {
                // Emit tool input chunk; the accumulator in DefaultCompletionStream concatenates these strings
                const blockIndex = result.contentBlockDelta.contentBlockIndex ?? -1;
                const toolBlock = streamingToolBlocks?.get(blockIndex);
                if (toolBlock && delta.toolUse.input !== undefined) {
                    tool_use = [{ id: toolBlock.id, tool_name: '', tool_input: delta.toolUse.input as any }];
                }
            } else if (delta?.text) {
                output = delta.text;
            } else if (delta?.reasoningContent) {
                const reasoningDelta = delta.reasoningContent;
                if (!shouldIncludeThoughts) {
                    // Thoughts are suppressed, so nothing is emitted. The signature delta marks the end
                    // of a thinking block; logging only there reports once per block, not once per delta.
                    if (reasoningDelta.signature) {
                        this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
                    }
                } else if (reasoningDelta.text) {
                    reasoning = reasoningDelta.text;
                } else if (reasoningDelta.redactedContent) {
                    const redactedData = new TextDecoder().decode(reasoningDelta.redactedContent);
                    reasoning = `[Redacted thinking: ${redactedData}]`;
                } else if (reasoningDelta.signature) {
                    // Signature update - end of thinking; separate the reasoning from the text that follows
                    reasoning = "\n\n";
                }
            } else if (delta) {
                // Get content block type
                const type = Object.keys(delta).find(
                    key => key !== '$unknown' && (delta as any)[key] !== undefined
                );
                this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
            }
        }

        // Handle content block stop events
        if (result.contentBlockStop) {
            // Clean up tool block tracking entry
            const blockIndex = result.contentBlockStop.contentBlockIndex ?? -1;
            streamingToolBlocks?.delete(blockIndex);

            // Add minimal spacing for reasoning blocks if not already present
            if (reasoning && !reasoning.endsWith('\n\n') && shouldIncludeThoughts) {
                reasoning += '\n\n';
            }
        }

        if (result.messageStop) {
            stop_reason = result.messageStop.stopReason ?? "";
        }

        if (result.metadata) {
            token_usage = this.toExecutionTokenUsage(result.metadata.usage);
        }

        const completionResult: CompletionChunkObject = {
            result: reasoning + output ? [{ type: "text", value: reasoning + output }] : [],
            token_usage: token_usage,
            finish_reason: converseFinishReason(stop_reason),
            tool_use,
        };

        return completionResult;
    }

    /**
     * Find the AWS region a model identifier points at.
     *
     * First looks for the region field of a Bedrock ARN (`arn:aws…:bedrock:<region>:`), then for
     * an AWS region code anywhere in the string (e.g. `eu-west-3`, `us-gov-west-1`, `il-central-1`).
     *
     * @param modelString - Model id, ARN or other model identifier.
     * @param defaultRegion - Returned when no region can be found.
     */
    extractRegion(modelString: string, defaultRegion: string): string {
        // Match region in full ARN pattern
        const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
        if (arnMatch) {
            return arnMatch[1];
        }

        // Match AWS region codes directly in the string, including GovCloud (us-gov-*),
        // Israel (il-*) and Mexico (mx-*) regions
        const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af|il|mx)(?:-gov)?-(?:east|west|central|south|north|southeast|southwest|northeast|northwest)-[1-9]/);
        if (regionMatch) {
            return regionMatch[0];
        }

        return defaultRegion;
    }

    /**
     * Ask Bedrock whether `model` supports response streaming.
     *
     * Foundation models are looked up directly. Inference profiles are resolved to their first
     * model ARN and custom models to their base model ARN, and those are then looked up as
     * foundation models. With `Unknown`, the lookups are tried in that order until one succeeds.
     * If every attempted lookup throws, the last error is logged and `false` is returned.
     *
     * @param model - Model identifier; its region (if any) selects the service client.
     * @param type - Which lookup(s) to attempt.
     */
    private async getCanStream(model: string, type: BedrockModelType): Promise<boolean> {
        let canStream: boolean = false;
        let error: unknown = null;
        const region = this.extractRegion(model, this.options.region);

        if (type === BedrockModelType.FoundationModel || type === BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getFoundationModel({
                    modelIdentifier: model
                });
                canStream = response.modelDetails?.responseStreamingSupported ?? false;
                return canStream;
            } catch (e) {
                error = e;
            }
        }
        if (type === BedrockModelType.InferenceProfile || type === BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getInferenceProfile({
                    inferenceProfileIdentifier: model
                });
                // `?.[0]?.` so an empty models list is not reported as a TypeError
                canStream = await this.getCanStream(response.models?.[0]?.modelArn ?? "", BedrockModelType.FoundationModel);
                return canStream;
            } catch (e) {
                error = e;
            }
        }
        if (type === BedrockModelType.CustomModel || type === BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getCustomModel({
                    modelIdentifier: model
                });
                canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
                return canStream;
            } catch (e) {
                error = e;
            }
        }

        if (error) {
            console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
        }

        return canStream;
    }

    /**
     * Whether `options.model` supports response streaming, memoized in `supportStreamingCache`.
     * The lookup kind is inferred from the identifier text ("foundation-model",
     * "inference-profile", "custom-model"); otherwise every lookup kind is tried.
     * NOTE(review): a failed lookup yields false, and that false is cached as well.
     */
    protected async canStream(options: ExecutionOptions): Promise<boolean> {
        // // TwelveLabs Pegasus supports streaming according to the documentation
        // if (options.model.includes("twelvelabs.pegasus")) {
        //     return true;
        // }

        let canStream = supportStreamingCache.get(options.model);
        if (canStream == null) {
            let type = BedrockModelType.Unknown;
            if (options.model.includes("foundation-model")) {
                type = BedrockModelType.FoundationModel;
            } else if (options.model.includes("inference-profile")) {
                type = BedrockModelType.InferenceProfile;
            } else if (options.model.includes("custom-model")) {
                type = BedrockModelType.CustomModel;
            }
            canStream = await this.getCanStream(options.model, type);
            supportStreamingCache.set(options.model, canStream);
        }
        return canStream;
    }

    /**
     * Build conversation context after streaming completion.
     * Reconstructs the assistant message from accumulated results and applies stripping.
     */
    buildStreamingConversation(
        prompt: BedrockPrompt,
        result: unknown[],
        toolUse: unknown[] | undefined,
        options: ExecutionOptions
    ): ConverseRequest | undefined {
        // Only handle ConverseRequest prompts (not NovaMessagesPrompt or TwelvelabsPegasusRequest)
        if (options.model.includes("canvas") || options.model.includes("twelvelabs.pegasus")) {
            return undefined;
        }

        const conversePrompt = prompt as ConverseRequest;
        const completionResults = result as CompletionResult[];

        // Convert accumulated results to text content for assistant message
        const textContent = completionResults
            .map(r => {
                switch (r.type) {
                    case 'text':
                        return r.value;
                    case 'json':
                        return typeof r.value === 'string' ?
r.value : JSON.stringify(r.value);
                    case 'image':
                        // Skip images in conversation - they're in the result
                        return '';
                    default:
                        return String((r as any).value || '');
                }
            })
            .join('');

        // Deserialize any base64-encoded binary data back to Uint8Array
        const incomingConversation = deserializeBinaryFromStorage(options.conversation) as ConverseRequest;

        // Start with the conversation from options combined with the prompt
        let conversation = updateConversation(incomingConversation, conversePrompt);

        // Build assistant message content
        const messageContent: any[] = [];
        if (textContent) {
            messageContent.push({ text: textContent });
        }

        // Add tool use blocks if present
        if (toolUse && toolUse.length > 0) {
            for (const tool of toolUse as ToolUse[]) {
                messageContent.push({
                    toolUse: {
                        toolUseId: tool.id,
                        name: tool.tool_name,
                        input: tool.tool_input,
                    }
                });
            }
        }

        // Add assistant message
        const assistantMessage: ConverseRequest = {
            messages: [{
                content: messageContent.length > 0 ? messageContent : [{ text: '' }],
                role: "assistant"
            }],
            modelId: conversePrompt.modelId,
        };

        conversation = updateConversation(conversation, assistantMessage);

        // Increment turn counter
        conversation = incrementConversationTurn(conversation) as ConverseRequest;

        // Apply stripping based on options
        const currentTurn = getConversationMeta(conversation).turnNumber;
        const stripOptions = {
            keepForTurns: options.stripImagesAfterTurns ?? Infinity,
            currentTurn,
            textMaxTokens: options.stripTextMaxTokens
        };
        let processedConversation = stripBinaryFromConversation(conversation, stripOptions);
        processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions);
        processedConversation = stripHeartbeatsFromConversation(processedConversation, {
            keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
            currentTurn,
        });

        return processedConversation as ConverseRequest;
    }

    /**
     * Run a non-streaming completion.
     *
     * TwelveLabs Pegasus models are delegated to InvokeModel. Every other model goes through the
     * Converse API with the stored conversation (`options.conversation`) merged with `prompt`.
     * The returned completion carries the updated, stripped conversation and any tool calls.
     */
    async requestTextCompletion(prompt: BedrockPrompt, options: ExecutionOptions): Promise<Completion> {
        // Handle Twelvelabs Pegasus models
        if (options.model.includes("twelvelabs.pegasus")) {
            return this.requestTwelvelabsPegasusCompletion(prompt as TwelvelabsPegasusRequest, options);
        }

        // Handle other Bedrock models that use Converse API
        const conversePrompt = prompt as ConverseRequest;

        // Deserialize any base64-encoded binary data back to Uint8Array before API call
        const incomingConversation = deserializeBinaryFromStorage(options.conversation) as ConverseRequest;
        let conversation = updateConversation(incomingConversation, conversePrompt);

        const payload = this.preparePayload(conversation, options);
        const executor = this.getExecutor();

        const res = await executor.converse({
            ...payload,
        });

        // Strip reasoningContent from the assistant message before storing it in the conversation
        // (DeepSeek R1 returns reasoning blocks but rejects them in subsequent user turns).
        // A copy is built on purpose: filtering res.output.message in place would also remove the
        // reasoning from the response handed to getExtractedExecution below and from original_response.
        const responseMessage = res.output?.message;
        const assistantMsg: Message = responseMessage
            ? { ...responseMessage, content: responseMessage.content?.filter((block) => !block.reasoningContent) }
            : { content: [{ text: "" }], role: "assistant" };

        conversation = updateConversation(conversation, {
            messages: [assistantMsg],
            modelId: conversePrompt.modelId,
        });

        // Increment turn counter for deferred stripping
        conversation = incrementConversationTurn(conversation) as ConverseRequest;

        let tool_use: ToolUse[] | undefined = undefined;
        //Get tool requests, we check tool use regardless of finish reason, as you can hit length and still get a valid response.
        tool_use = res.output?.message?.content?.reduce((tools: ToolUse[], c) => {
            if (c.toolUse) {
                tools.push({
                    tool_name: c.toolUse.name ?? "",
                    tool_input: c.toolUse.input as any,
                    id: c.toolUse.toolUseId ?? "",
                } satisfies ToolUse);
            }
            return tools;
        }, []);
        //If no tools were used, set to undefined
        if (tool_use && tool_use.length === 0) {
            tool_use = undefined;
        }

        // Strip/serialize binary data based on options.stripImagesAfterTurns
        const currentTurn = getConversationMeta(conversation).turnNumber;
        const stripOptions = {
            keepForTurns: options.stripImagesAfterTurns ?? Infinity,
            currentTurn,
            textMaxTokens: options.stripTextMaxTokens
        };
        let processedConversation = stripBinaryFromConversation(conversation, stripOptions);

        // Truncate large text content if configured
        processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions);

        // Strip old heartbeat status messages
        processedConversation = stripHeartbeatsFromConversation(processedConversation, {
            keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
            currentTurn,
        });

        const completion = {
            ...this.getExtractedExecution(res, conversePrompt, options),
            original_response: options.include_original_response ? res : undefined,
            conversation: processedConversation,
            tool_use: tool_use,
        };

        return completion;
    }

    /**
     * Non-streaming call for TwelveLabs Pegasus models, sent through InvokeModel with a JSON body.
     * The response's `message` becomes the text result; `finishReason` is forwarded unchanged.
     */
    private async requestTwelvelabsPegasusCompletion(prompt: TwelvelabsPegasusRequest, options: ExecutionOptions): Promise<Completion> {
        const executor = this.getExecutor();

        const res = await executor.invokeModel({
            modelId: options.model,
            contentType: "application/json",
            accept: "application/json",
            body: JSON.stringify(prompt),
        });

        const decoder = new TextDecoder();
        const body = decoder.decode(res.body);
        const result = JSON.parse(body);

        // Pegasus finish reasons (e.g. "stop", "length") already match llumiverse's, so they pass through.
        const finishReason: string | undefined = result.finishReason;

        return {
            result: result.message ? [{ type: "text" as const, value: result.message }] : [],
            finish_reason: finishReason,
            original_response: options.include_original_response ? result : undefined,
        };
    }

    /**
     * Streaming call for TwelveLabs Pegasus models via InvokeModelWithResponseStream.
     *
     * Each event's bytes are parsed as JSON. `delta` (falling back to `message`) becomes the text
     * chunk, and a truthy `finishReason` is forwarded. Events without bytes, or whose bytes are not
     * valid JSON, produce an empty chunk.
     *
     * @throws Error when the response carries no body stream.
     */
    private async requestTwelvelabsPegasusCompletionStream(prompt: TwelvelabsPegasusRequest, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        const executor = this.getExecutor();

        const res = await executor.invokeModelWithResponseStream({
            modelId: options.model,
            contentType: "application/json",
            accept: "application/json",
            body: JSON.stringify(prompt),
        });

        if (!res.body) {
            throw new Error("[Bedrock] Stream not found in response");
        }

        // decode() without { stream: true } keeps no state between calls, so one decoder serves every event
        const decoder = new TextDecoder();

        return transformAsyncIterator(res.body, (chunk: any) => {
            if (chunk.chunk?.bytes) {
                const body = decoder.decode(chunk.chunk.bytes);
                try {
                    const result = JSON.parse(body);
                    // Extract streaming response according to TwelveLabs Pegasus format
                    const finishReason: string | undefined = result.finishReason || undefined;

                    return {
                        result: result.delta || result.message ? [{ type: "text" as const, value: result.delta || result.message || "" }] : [],
                        finish_reason: finishReason,
                    } satisfies CompletionChunkObject;
                } catch {
                    // If JSON parsing fails, return empty chunk
                    return {
                        result: [],
                    } satisfies CompletionChunkObject;
                }
            }
            return {
                result: [],
            } satisfies CompletionChunkObject;
        });
    }

    /**
     * Run a streaming completion.
     *
     * TwelveLabs Pegasus models are delegated to InvokeModelWithResponseStream. Every other model
     * goes through ConverseStream with the stored conversation merged with `prompt`. Each stream
     * event is converted by getExtractedStream. Failures to open the stream are logged and rethrown.
     */
    async requestTextCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        // Handle Twelvelabs Pegasus models
        if (options.model.includes("twelvelabs.pegasus")) {
            return this.requestTwelvelabsPegasusCompletionStream(prompt as TwelvelabsPegasusRequest, options);
        }

        // Handle other Bedrock models that use Converse API
        const conversePrompt = prompt as ConverseRequest;

        // Include conversation history (same as non-streaming)
        // Deserialize any base64-encoded binary data back to Uint8Array before API call
        const incomingConversation = deserializeBinaryFromStorage(options.conversation) as ConverseRequest;
        const conversation = updateConversation(incomingConversation, conversePrompt);

        const payload = this.preparePayload(conversation, options);
        const executor = this.getExecutor();

        try {
            const res = await executor.converseStream({
                ...payload,
            });

            const stream = res.stream;
            if (!stream) {
                throw new Error("[Bedrock] Stream not found in response");
            }

            // Tool call blocks seen in this stream, keyed by content block index
            const streamingToolBlocks = new Map<number, { id: string; name: string }>();
            return transformAsyncIterator(stream, (streamSegment: ConverseStreamOutput) => {
                return this.getExtractedStream(streamSegment, conversePrompt, options, streamingToolBlocks);
            });
        } catch (err) {
            this.logger.error({ error: err }, "[Bedrock] Failed to stream");
            throw err;
        }
    }

    /**
     * Build the Converse request for `prompt`. This covers model-family-specific inference
     * settings and `additionalModelRequestFields`, optional JSON prefill, and tool configuration.
     * NOTE: for some models this mutates `prompt` (e.g. folding system messages into messages)
     * and `options.model_options` (e.g. filling in `max_tokens` for Claude) in place.
     */
    preparePayload(prompt: ConverseRequest, options: ExecutionOptions) {
        const model_options: TextFallbackOptions = options.model_options as TextFallbackOptions ??
{ _option_id: "text-fallback" }; let additionalField = {}; let supportsJSONPrefill = false; // Resolve thinking, effort, and sampling restrictions using shared Claude helper const claudeThinking = resolveClaudeThinking(options.model, options.model_options as BedrockClaudeOptions | undefined); const hasSamplingRestriction = claudeThinking.hasSamplingRestriction; if (options.model.includes("amazon")) { supportsJSONPrefill = true; //Titan models also exists but does not support any additional options if (options.model.includes("nova")) { additionalField = { inferenceConfig: { topK: model_options.top_k } }; } } else if (options.model.includes("claude")) { const claude_options = model_options as ModelOptions as BedrockClaudeOptions; // Thinking is active when extended (budget set) or adaptive (effort set) thinking is enabled. // JSON prefill is incompatible with active thinking. const thinkingActive = claudeThinking.thinking != null && claudeThinking.thinking.type !== "disabled"; supportsJSONPrefill = !thinkingActive // Claude 3.7+ supports thinking — use shared helper for reasoning_config if (claudeThinking.supportsThinking) { if (claudeThinking.thinking) { additionalField = { ...additionalField, reasoning_config: claudeThinking.thinking, }; } // For Claude 3.7 with extended thinking + high output, add beta header if (claudeThinking.thinking?.type === "enabled" && options.model.includes("claude-3-7-sonnet") && ((claude_options.max_tokens ?? 0) > 64000 || (claude_options.thinking_budget_tokens ?? 
0) > 64000)) { additionalField = { ...additionalField, anthropic_beta: ["output-128k-2025-02-19"] }; } } // Add effort parameter via output_config (Opus 4.5+, Sonnet 4.6+, all 4.7+) if (claudeThinking.outputConfig) { additionalField = { ...additionalField, output_config: claudeThinking.outputConfig }; } // Claude 4.6 and later versions don't support JSON prefill if (isClaudeVersionGTE(options.model, 4, 6)) { supportsJSONPrefill = false; } // Needs max_tokens to be set if (!model_options.max_tokens) { model_options.max_tokens = maxTokenFallbackClaude(options); } // Only models without sampling restrictions support top_k if (!hasSamplingRestriction) { additionalField = { ...additionalField, top_k: model_options.top_k }; } } else if (options.model.includes("meta")) { //LLaMA models support no additional options } else if (options.model.includes("mistral")) { //7B instruct and 8x7B instruct if (options.model.includes("7b")) { additionalField = { top_k: model_options.top_k }; //Does not support system messages if (prompt.system && prompt.system?.length !== 0) { prompt.messages?.push(converseSystemToMessages(prompt.system)); prompt.system = undefined; prompt.messages = converseConcatMessages(prompt.messages); } } else { //Other models such as Mistral Small,Large and Large 2 //Support no additional fields. } } else if (options.model.includes("ai21")) { //Jamba models support no additional options //Jurassic 2 models do. if (options.model.includes("j2")) { additionalField = { presencePenalty: { scale: model_options.presence_penalty }, frequencyPenalty: { scale: model_options.frequency_penalty }, }; //Does not support system messages if (prompt.system && prompt.system?.length !== 0) { prompt.messages?.push(converseSystemToMessages(prompt.system)); prompt.system = undefined; prompt.messages = converseConcatMessages(prompt.messages); } } } else if (options.model.includes("cohere.command")) { // If last message is "```json", remove it. 
//Command R and R plus if (options.model.includes("cohere.command-r")) { additionalField = { k: model_options.top_k, frequency_penalty: model_options.frequency_penalty, presence_penalty: model_options.presence_penalty, }; } else { // Command non-R additionalField = { k: model_options.top_k }; //Does not support system messages if (prompt.system && prompt.system?.length !== 0) { prompt.messages?.push(converseSystemToMessages(prompt.system)); prompt.system = undefined; prompt.messages = converseConcatMessages(prompt.messages); } } } else if (options.model.includes("palmyra")) { const palmyraOptions = model_options as ModelOptions as BedrockPalmyraOptions; additionalField = { seed: palmyraOptions?.seed, presence_penalty: palmyraOptions?.presence_penalty, frequency_penalty: palmyraOptions?.frequency_penalty, min_tokens: palmyraOptions?.min_tokens, } } else if (options.model.includes("deepseek")) { // DeepSeek models: no additional options, no stopSequences, only one of temperature/top_p model_options.stop_sequence = undefined; model_options.top_p = undefined; } else if (options.model.includes("gpt-oss")) { const gptOssOptions = model_options as ModelOptions as BedrockGptOssOptions; additionalField = { reasoning_effort: gptOssOptions?.reasoning_effort, }; } //If last message is "```json", add corresponding ``` as a stop sequence. 
if (prompt.messages && prompt.messages.length > 0) { if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") { const stopSeq = model_options.stop_sequence; if (!stopSeq) { model_options.stop_sequence = ["```"]; } else if (!stopSeq.includes("```")) { stopSeq.push("```"); model_options.stop_sequence = stopSeq; } } } const tool_defs = getToolDefinitions(options.tools); // Use prefill when there is a schema and tools are not being used if (supportsJSONPrefill && options.result_schema && !tool_defs) { prompt.messages = converseJSONprefill(prompt.messages); } // Clean undefined values from additionalField since AWS Bedrock requires valid JSON // and will throw an exception for unrecognized parameters const cleanedAdditionalFields = removeUndefinedValues(additionalField); // Models with sampling parameter restrictions don't support temperature/top_p - exclude them from inference config const cleanedModelOptions = removeUndefinedValues({ maxTokens: model_options.max_tokens, ...(hasSamplingRestriction ? {} : { temperature: model_options.temperature, topP: model_options.temperature != null ? undefined : model_options.top_p, }), stopSequences: model_options.stop_sequence, } satisfies InferenceConfiguration); //Construct the final request payload // We only add fields that are defined to avoid AWS errors const request: ConverseRequest = { modelId: options.model, }; if (prompt.messages) { request.messages = prompt.messages; } if (prompt.system) { request.system = prompt.system; } if (Object.keys(cleanedModelOptions).length > 0) { request.inferenceConfig = cleanedModelOptions } if (Object.keys(cleanedAdditionalFields).length > 0) { request.additionalModelRequestFields = cleanedAdditionalFields; } if (tool_defs?.length) { request.toolConfig = { tools: tool_defs, } } else if (request.messages && messagesContainToolBlocks(request.messages)) { // Bedrock requires toolConfig when conversation contains toolUse/toolResult blocks. 
// When no tools are provided (e.g. checkpoint summary calls), convert tool blocks
            // to text representations so the conversation data is preserved while satisfying
            // Bedrock's API requirements without making tools callable.
            request.messages = convertToolBlocksToText(request.messages);
        }

        // Prompt caching: use three breakpoints so stable system blocks, tool definitions,
        // and the conversation history prefix can all be reused across Claude turns.
        if (options.model.includes('claude')) {
            // Always strip stale markers from prior turns
            if (request.messages) {
                request.messages = stripClaudeCachePoints(request.messages);
            }
            request.system = stripClaudeCachePointsFromSystem(request.system);
            if (request.toolConfig?.tools) {
                request.toolConfig = {
                    ...request.toolConfig,
                    tools: stripClaudeCachePointsFromTools(request.toolConfig.tools),
                };
            }

            const claudeOptions = model_options as unknown as BedrockClaudeOptions;
            const cacheEnabled = claudeOptions?.cache_enabled === true;
            if (cacheEnabled) {
                const cacheTtl = claudeOptions?.cache_ttl;
                const cachePointBlock = { type: 'default' as const, ...(cacheTtl && { ttl: cacheTtl }) };

                if (request.system && request.system.length > 0) {
                    request.system = [...request.system, { cachePoint: cachePointBlock } satisfies BedrockSystemBlock];
                }

                if (request.toolConfig?.tools && request.toolConfig.tools.length > 0) {
                    request.toolConfig.tools = [
                        ...request.toolConfig.tools,
                        { cachePoint: cachePointBlock } satisfies BedrockToolEntry,
                    ];
                }

                if (request.messages && request.messages.length >= 4) {
                    const pivotMsg = request.messages[request.messages.length - 2];
                    if (pivotMsg.content && Array.isArray(pivotMsg.content) && pivotMsg.content.length > 0) {
                        pivotMsg.content = [...pivotMsg.content, { cachePoint: cachePointBlock }];
                    }
                }
            }
        }

        return request;
    }

    /**
     * Returns true when the model id contains one of the substrings
     * "titan-image", "stable-diffusion" or "nova-canvas".
     *
     * @param model - Bedrock model id to test.
     */
    protected isImageModel(model: string): boolean {
        return model.includes("titan-image") || model.includes("stable-diffusion") || model.includes("nova-canvas");
    }

    /**
     * Generates images with a Nova Canvas model through Bedrock `InvokeModel`.
     *
     * The JSON request body is built by `formatNovaImageGenerationPayload` for the task type
     * read from `model_options.taskType`, falling back to `NovaImageGenerationTaskType.TEXT_IMAGE`.
     * A `model_options` whose `_option_id` is set to something other than
     * `"bedrock-nova-canvas"` is only logged at debug level; it is not rejected.
     *
     * @param prompt - Nova messages prompt; a plain string prompt is rejected.
     * @param options - execution options; `model_options` may be absent.
     * @returns a completion whose `result` holds one `image` part per entry of the response's
     *   `images` array (empty when the response has no `images`), and whose `error` is the
     *   response's `error` field.
     * @throws Error "Bad prompt format" when `prompt` is a string.
     * @throws Error when the response contains neither an `images` array nor an `error`.
     */
    async requestImageGeneration(prompt: NovaMessagesPrompt, options: ExecutionOptions): Promise<Completion> {
        const optionId = options.model_options?._option_id;
        if (optionId !== undefined && optionId !== "bedrock-nova-canvas") {
            this.logger.debug({ options: options.model_options }, "Unexpected option id");
        }
        // The option-id check above already treats `options.model_options` as possibly
        // undefined; read it the same way here instead of failing with a TypeError.
        const model_options = options.model_options as NovaCanvasOptions | undefined;
        const executor = this.getExecutor();
        const taskType = model_options?.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
        this.logger.info("Task type: " + taskType);
        if (typeof prompt === "string") {
            throw new Error("Bad prompt format");
        }

        const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);
        const res = await executor.invokeModel({
            modelId: options.model,
            contentType: "application/json",
            accept: "application/json",
            body: JSON.stringify(payload),
        }, { requestTimeout: 60000 * 5 }); // 60000 * 5 ms = 5 minute request timeout

        const body = new TextDecoder().decode(res.body);
        const bedrockResult = JSON.parse(body);

        // A failed generation may report `error` without any `images`. Calling `.map` on the
        // missing array would throw a TypeError and hide that error, so treat it as empty
        // and surface the error through the completion instead.
        const hasImages = Array.isArray(bedrockResult.images);
        if (!hasImages && !bedrockResult.error) {
            throw new Error("Image generation response contains neither images nor an error");
        }
        const images = hasImages ? bedrockResult.images : [];

        return {
            error: bedrockResult.error,
            result: images.map((image: any) => ({ type: "image" as const, value: image })),
        };
    }

    async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
        //convert options.params to Record<string, string>
        const params: Record<string, string> = {};
        for (const [key, value] of Object.entries(options.params || {})) {
            params[key] = String(value);
        }

        if (!this.options.training_bucket) {
            throw new Error("Training cannot nbe used