UNPKG

ai-functions

Version:

Core AI primitives for building intelligent applications

551 lines (483 loc) 17.8 kB
/**
 * AWS Bedrock Batch Inference Adapter
 *
 * Bedrock exposes two ways of running many prompts:
 *   - a true batch inference API, driven by JSONL files in S3, and
 *   - the synchronous runtime invoke API.
 *
 * The "batch" adapter registered by this module uses concurrent runtime
 * invocations as a fallback (no S3 setup required). `createBedrockBatchJob` is
 * exported separately for callers who want to drive the real S3-based batch
 * flow directly.
 *
 * @see https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html
 *
 * @packageDocumentation
 */

import { getLogger } from '../logger.js'
import {
  type BatchAdapter,
  type BatchItem,
  type BatchJob,
  type BatchQueueOptions,
  type BatchResult,
  type BatchSubmitResult,
  type FlexAdapter,
  LocalJobStore,
  processConcurrently,
  registerBatchAdapter,
  registerFlexAdapter,
  tryParseJson,
} from './provider.js'

// ============================================================================
// Provider-specific types
// ============================================================================

/** One JSONL record of a Bedrock batch input file (Anthropic Messages body). */
interface BedrockBatchRequest {
  /** Record identifier; `createBedrockBatchJob` fills it with the item id. */
  recordId: string
  /** Request body sent to the model for this record. */
  modelInput: {
    anthropic_version?: string
    max_tokens: number
    messages: { role: string; content: string }[]
    system?: string
    temperature?: number
  }
}

// ============================================================================
// AWS configuration
// ============================================================================

// Module-level settings written by configureAWSBedrock(). Any value left
// undefined here is resolved from environment variables by getConfig().
let awsRegion: string | undefined
let awsAccessKeyId: string | undefined
let awsSecretAccessKey: string | undefined
let awsSessionToken: string | undefined
let s3Bucket: string | undefined
let roleArn: string | undefined
let gatewayUrl: string | undefined
let gatewayToken: string | undefined
/**
 * Configure AWS credentials and Bedrock settings for this module.
 *
 * Only truthy option values are applied, so omitted (or empty) options keep
 * their previous value. Anything still unset is resolved from environment
 * variables when `getConfig()` runs.
 */
export function configureAWSBedrock(options: {
  region?: string
  accessKeyId?: string
  secretAccessKey?: string
  sessionToken?: string
  s3Bucket?: string
  roleArn?: string
  /** Optional: Cloudflare AI Gateway URL for routing requests */
  gatewayUrl?: string
  /** Optional: Cloudflare AI Gateway token */
  gatewayToken?: string
}): void {
  if (options.region) awsRegion = options.region
  if (options.accessKeyId) awsAccessKeyId = options.accessKeyId
  if (options.secretAccessKey) awsSecretAccessKey = options.secretAccessKey
  if (options.sessionToken) awsSessionToken = options.sessionToken
  if (options.s3Bucket) s3Bucket = options.s3Bucket
  if (options.roleArn) roleArn = options.roleArn
  if (options.gatewayUrl) gatewayUrl = options.gatewayUrl
  if (options.gatewayToken) gatewayToken = options.gatewayToken
}

/** Resolved settings shared by every Bedrock call in this module. */
interface BedrockConfig {
  region: string
  /** Empty string when running through AI Gateway without credentials. */
  accessKeyId: string
  /** Empty string when running through AI Gateway without credentials. */
  secretAccessKey: string
  sessionToken?: string | undefined
  /** S3 bucket for true batch jobs; empty string when not configured. */
  bucket: string
  role: string | undefined
  gatewayUrl: string | undefined
  gatewayToken: string | undefined
}

/**
 * Resolve settings, preferring values from `configureAWSBedrock()` over
 * environment variables.
 *
 * When both an AI Gateway URL and token are present, credentials may be empty
 * (the code paths that need them check for themselves). Otherwise AWS
 * credentials are mandatory.
 *
 * The S3 bucket is deliberately not required here, because the runtime-invocation
 * paths (batch fallback and flex) never touch S3. Code that uploads to S3 must
 * check that `bucket` is non-empty.
 *
 * @throws Error when neither gateway settings nor AWS credentials are configured.
 */
function getConfig(): BedrockConfig {
  const env = process.env
  const region = awsRegion || env['AWS_REGION'] || env['AWS_DEFAULT_REGION'] || 'us-east-1'
  const accessKeyId = awsAccessKeyId || env['AWS_ACCESS_KEY_ID']
  const secretAccessKey = awsSecretAccessKey || env['AWS_SECRET_ACCESS_KEY']
  const sessionToken = awsSessionToken || env['AWS_SESSION_TOKEN']
  const bucket = s3Bucket || env['BEDROCK_BATCH_S3_BUCKET'] || ''
  const role = roleArn || env['BEDROCK_BATCH_ROLE_ARN']
  const gwUrl = gatewayUrl || env['AI_GATEWAY_URL']
  const gwToken = gatewayToken || env['AI_GATEWAY_TOKEN']

  if (gwUrl && gwToken) {
    return {
      region,
      accessKeyId: accessKeyId || '',
      secretAccessKey: secretAccessKey || '',
      sessionToken,
      bucket,
      role,
      gatewayUrl: gwUrl,
      gatewayToken: gwToken,
    }
  }

  if (!accessKeyId || !secretAccessKey) {
    throw new Error(
      'AWS credentials not configured. Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY, or use AI_GATEWAY_URL and AI_GATEWAY_TOKEN'
    )
  }

  return {
    region,
    accessKeyId,
    secretAccessKey,
    sessionToken,
    bucket,
    role,
    gatewayUrl: undefined,
    gatewayToken: undefined,
  }
}

// ============================================================================
// AWS SigV4 (delegated to optional @smithy/signature-v4 if available)
// ============================================================================

/**
 * Load the optional SigV4 dependencies.
 *
 * @returns The signer and hash constructors, or `null` when either package
 *   cannot be imported.
 */
async function loadSigV4Dependencies() {
  try {
    // Optional dependency — present in production, absent in dev/test.
    // @ts-expect-error - Optional dependency
    const signatureV4Module = await import('@smithy/signature-v4')
    // @ts-expect-error - Optional dependency
    const sha256Module = await import('@aws-crypto/sha256-js')
    return { SignatureV4: signatureV4Module.SignatureV4, Sha256: sha256Module.Sha256 }
  } catch {
    return null
  }
}

/**
 * Produce SigV4-signed headers for an AWS request.
 *
 * `host` is included in the signature (SigV4 requires it) but is removed from
 * the returned headers, because fetch() derives Host from the URL it is called
 * with. This lets a caller sign for the native AWS endpoint and then send the
 * request through a proxy such as Cloudflare AI Gateway.
 *
 * If the optional signing packages cannot be loaded, a warning is logged and
 * UNSIGNED headers are returned. AWS is expected to reject such requests.
 * Errors raised by the signer itself propagate to the caller.
 *
 * @param method - HTTP method.
 * @param url - Absolute URL the signature is computed for.
 * @param body - Exact request body that will be sent.
 * @param config - Region and credentials.
 * @param service - SigV4 service name (e.g. `bedrock`, `s3`).
 */
async function signRequest(
  method: string,
  url: string,
  body: string,
  config: BedrockConfig,
  service: string
): Promise<Headers> {
  const headers = new Headers({
    'Content-Type': 'application/json',
    'X-Amz-Date': new Date().toISOString().replace(/[:-]|\.\d{3}/g, ''),
  })
  if (config.sessionToken) {
    headers.set('X-Amz-Security-Token', config.sessionToken)
  }

  const sigV4 = await loadSigV4Dependencies()
  if (!sigV4) {
    getLogger().warn(
      'AWS SDK not available for request signing. Install @smithy/signature-v4 and @aws-crypto/sha256-js'
    )
    return headers
  }

  const target = new URL(url)
  const signer = new sigV4.SignatureV4({
    service,
    region: config.region,
    credentials: {
      accessKeyId: config.accessKeyId,
      secretAccessKey: config.secretAccessKey,
      sessionToken: config.sessionToken,
    },
    sha256: sigV4.Sha256,
    // S3 canonical paths are URI-encoded once; every other service encodes twice.
    uriEscapePath: service !== 's3',
  })

  const signedRequest = await signer.sign({
    method,
    protocol: target.protocol,
    hostname: target.hostname,
    path: target.pathname,
    headers: { ...Object.fromEntries(headers.entries()), host: target.host },
    body,
  })

  const signedHeaders = new Headers(signedRequest.headers as Record<string, string>)
  // fetch() sets Host from the request URL; never send the signed value explicitly.
  signedHeaders.delete('host')
  return signedHeaders
}

// ============================================================================
// Local job tracking
// ============================================================================

/** In-memory state for jobs created by `bedrockAdapter`. */
const jobs = new LocalJobStore('bedrock_batch')

// ============================================================================
// Bedrock batch adapter (BatchProvider port)
// ============================================================================

/**
 * BatchProvider adapter that emulates a batch by running Bedrock runtime
 * invocations in concurrent waves (5 per wave, 1 s between waves). No S3 is used.
 */
const bedrockAdapter: BatchAdapter = {
  async submit(items: BatchItem[], options: BatchQueueOptions): Promise<BatchSubmitResult> {
    // Resolve config up front so that misconfiguration rejects submit() itself.
    const config = getConfig()
    const model = options.model || 'anthropic.claude-3-sonnet-20240229-v1:0'
    const { id, state } = jobs.create(items, options)
    // Read through the store (not the narrowed `state.status`) so that a
    // cancel() issued while the run is in flight is observed.
    const isCancelled = (): boolean => jobs.get(id).status === 'cancelled'

    // Drive the job state machine in the background.
    const completion = (async () => {
      state.status = 'in_progress'
      try {
        const results = await processConcurrently(
          items,
          (item) => processBedrockItem(item, config, model),
          {
            concurrency: 5, // Bedrock has stricter rate limits.
            delayBetweenWaves: 1000,
            onWaveComplete: (partial) => {
              state.results = partial
            },
          }
        )
        state.results = results
        // A cancellation recorded mid-run wins over the computed outcome.
        if (!isCancelled()) {
          state.status = results.every((r) => r.status === 'completed') ? 'completed' : 'failed'
        }
        state.completedAt = new Date()
        return results
      } catch (error) {
        // Without this the job would report `in_progress` forever.
        if (!isCancelled()) state.status = 'failed'
        state.completedAt = new Date()
        throw error
      }
    })()

    const job: BatchJob = {
      id,
      provider: 'bedrock',
      status: 'pending',
      totalItems: items.length,
      completedItems: 0,
      failedItems: 0,
      createdAt: state.createdAt,
      ...(options.webhookUrl !== undefined && { webhookUrl: options.webhookUrl }),
    }

    return { job, completion }
  },

  async getStatus(batchId: string): Promise<BatchJob> {
    return jobs.snapshot(batchId, 'bedrock')
  },

  /**
   * Mark a job cancelled (best effort). Runtime invocations that are already
   * in flight are not aborted. If the job carries a `jobArn` in its metadata,
   * a stop request is also sent to Bedrock. Failures are logged, not thrown.
   */
  async cancel(batchId: string): Promise<void> {
    if (!jobs.has(batchId)) return
    const state = jobs.get(batchId)
    state.status = 'cancelled'

    const jobArn = state.meta?.['jobArn'] as string | undefined
    if (!jobArn) return

    try {
      const config = getConfig()
      const url = `https://bedrock.${
        config.region
      }.amazonaws.com/model-invocation-job/${encodeURIComponent(jobArn)}/stop`
      const response = await fetch(url, {
        method: 'POST',
        headers: await signRequest('POST', url, '', config, 'bedrock'),
      })
      // fetch() resolves on HTTP errors, so the status must be checked explicitly.
      if (!response.ok) {
        getLogger().warn(
          'Failed to cancel Bedrock job:',
          `${response.status} ${await response.text()}`
        )
      }
    } catch (error) {
      getLogger().warn('Failed to cancel Bedrock job:', error)
    }
  },

  async getResults(batchId: string): Promise<BatchResult[]> {
    return jobs.get(batchId).results
  },

  async waitForCompletion(batchId: string, pollInterval = 5000): Promise<BatchResult[]> {
    return jobs.waitForCompletion(batchId, pollInterval)
  },
}

// ============================================================================
// Per-item processing
// ============================================================================

/** Native Bedrock runtime `InvokeModel` endpoint for `model` in `region`. */
function bedrockRuntimeInvokeUrl(region: string, model: string): string {
  return `https://bedrock-runtime.${region}.amazonaws.com/model/${encodeURIComponent(model)}/invoke`
}

/**
 * Invoke a single item against the Bedrock runtime and normalise the response.
 * Routes through AI Gateway when both a gateway URL and a token are configured.
 *
 * @throws Error on a non-2xx response (status and body included).
 */
async function processBedrockItem(
  item: BatchItem,
  config: BedrockConfig,
  model: string
): Promise<BatchResult> {
  if (config.gatewayUrl && config.gatewayToken) {
    return processBedrockItemViaGateway(item, config, model)
  }

  const url = bedrockRuntimeInvokeUrl(config.region, model)
  const bodyStr = JSON.stringify(buildBedrockRequestBody(item, model))
  const headers = await signRequest('POST', url, bodyStr, config, 'bedrock')

  const response = await fetch(url, { method: 'POST', headers, body: bodyStr })
  if (!response.ok) {
    const error = await response.text()
    throw new Error(`Bedrock API error: ${response.status} ${error}`)
  }
  return parseBedrockResponse(item, await response.json())
}

/**
 * Process a Bedrock item via Cloudflare AI Gateway.
 *
 * AI Gateway routes the request but does not authenticate it, so Bedrock
 * still requires SigV4 signing. The signature is computed for the native
 * Bedrock runtime URL, which is the request Bedrock ultimately verifies. The
 * signed request is then sent to the gateway URL.
 *
 * @see https://developers.cloudflare.com/ai-gateway/usage/providers/bedrock/
 * @throws Error if AWS credentials are missing or the gateway returns a non-2xx response.
 */
async function processBedrockItemViaGateway(
  item: BatchItem,
  config: BedrockConfig,
  model: string
): Promise<BatchResult> {
  if (!config.accessKeyId || !config.secretAccessKey) {
    throw new Error(
      'Bedrock via AI Gateway still requires AWS credentials for SigV4 signing. ' +
        'Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.'
    )
  }

  // Tolerate a trailing slash on the configured gateway URL.
  const gatewayBase = (config.gatewayUrl ?? '').replace(/\/+$/, '')
  const routedUrl = `${gatewayBase}/aws-bedrock/bedrock-runtime/${
    config.region
  }/model/${encodeURIComponent(model)}/invoke`

  // Same model-family-aware body as the direct path.
  const bodyStr = JSON.stringify(buildBedrockRequestBody(item, model))

  const headers = await signRequest(
    'POST',
    bedrockRuntimeInvokeUrl(config.region, model),
    bodyStr,
    config,
    'bedrock'
  )
  headers.set('cf-aig-authorization', `Bearer ${config.gatewayToken}`)

  const response = await fetch(routedUrl, { method: 'POST', headers, body: bodyStr })
  if (!response.ok) {
    const error = await response.text()
    throw new Error(`Bedrock via Gateway error: ${response.status} ${error}`)
  }
  return parseBedrockResponse(item, await response.json())
}
/**
 * Build the `InvokeModel` request body for the model family implied by `model`.
 *
 * The family is detected by substring match on the model id, so ids carrying
 * a prefix (e.g. `us.anthropic.claude-...`) are matched too. Unrecognised ids
 * fall back to the Anthropic Messages shape.
 *
 * @param item - Prompt plus optional `maxTokens`, `temperature` and `system`.
 * @param model - Bedrock model id.
 * @returns A JSON-serialisable request body.
 */
function buildBedrockRequestBody(item: BatchItem, model: string): Record<string, unknown> {
  const maxTokens = item.options?.maxTokens || 4096
  // `??` rather than `||`, so that an explicit temperature of 0 is honoured.
  const temperature = item.options?.temperature ?? 0.7

  if (model.includes('anthropic')) {
    return buildAnthropicBody(item, maxTokens)
  }
  if (model.includes('amazon')) {
    // NOTE(review): Titan text schema. Amazon Nova ids also contain 'amazon'
    // but appear to use a messages-based schema — confirm before relying on Nova.
    return {
      inputText: item.prompt,
      textGenerationConfig: { maxTokenCount: maxTokens, temperature },
    }
  }
  if (model.includes('meta')) {
    return { prompt: item.prompt, max_gen_len: maxTokens, temperature }
  }
  if (model.includes('mistral')) {
    return { prompt: `<s>[INST] ${item.prompt} [/INST]`, max_tokens: maxTokens, temperature }
  }
  // Default: Claude-style (including the system prompt, like the anthropic branch).
  return buildAnthropicBody(item, maxTokens)
}

/**
 * Anthropic Messages body shared by the explicit and the fallback branch.
 * `system` and `temperature` are included only when the item sets them.
 */
function buildAnthropicBody(item: BatchItem, maxTokens: number): Record<string, unknown> {
  return {
    anthropic_version: 'bedrock-2023-05-31',
    max_tokens: maxTokens,
    messages: [{ role: 'user', content: item.prompt }],
    ...(item.options?.system !== undefined && { system: item.options.system }),
    ...(item.options?.temperature !== undefined && { temperature: item.options.temperature }),
  }
}
/** Response fields read by `parseBedrockResponse`, across the supported model families. */
interface BedrockInvokeResponse {
  content?: Array<{ type: string; text?: string }>
  usage?: { input_tokens: number; output_tokens: number }
  results?: Array<{ outputText: string; tokenCount?: number }>
  inputTextTokenCount?: number
  generation?: string
  generation_token_count?: number
  prompt_token_count?: number
  outputs?: Array<{ text?: string; stop_reason?: string }>
}

/**
 * Parse a Bedrock invoke response across model families into a `BatchResult`.
 *
 * Shapes are checked in this order:
 *   - Anthropic Messages: `content[]` text blocks plus `usage`.
 *   - Amazon Titan text: `results[0]` plus top-level `inputTextTokenCount`.
 *   - Meta Llama: `generation` plus `prompt_token_count` / `generation_token_count`.
 *   - Mistral: `outputs[0].text` (no token counts in this shape).
 *
 * An unrecognised (or non-object) body still yields a `completed` result, with
 * `undefined` passed to `tryParseJson` and no `usage`.
 *
 * @param item - Originating item. Its id becomes `id`/`customId`, and a truthy
 *   `item.schema` is forwarded to `tryParseJson`.
 * @param raw - Parsed JSON response body.
 */
function parseBedrockResponse(item: BatchItem, raw: unknown): BatchResult {
  // Guard against `null`/primitive bodies so the property reads below cannot throw.
  const data: BedrockInvokeResponse =
    typeof raw === 'object' && raw !== null ? (raw as BedrockInvokeResponse) : {}
  const titanResult = data.results?.[0]
  const mistralOutput = data.outputs?.[0]

  let content: string | undefined
  let usage: { promptTokens: number; completionTokens: number; totalTokens: number } | undefined

  if (Array.isArray(data.content)) {
    // Anthropic: concatenate every text block; non-text blocks are skipped.
    const texts = data.content.flatMap((block) =>
      block.type === 'text' && typeof block.text === 'string' ? [block.text] : []
    )
    content = texts.length > 0 ? texts.join('') : undefined
    if (data.usage) {
      usage = {
        promptTokens: data.usage.input_tokens,
        completionTokens: data.usage.output_tokens,
        totalTokens: data.usage.input_tokens + data.usage.output_tokens,
      }
    }
  } else if (titanResult) {
    // Titan: `tokenCount` is the output count; the input count is top-level.
    content = titanResult.outputText
    const promptTokens = data.inputTextTokenCount ?? 0
    const completionTokens = titanResult.tokenCount || 0
    usage = { promptTokens, completionTokens, totalTokens: promptTokens + completionTokens }
  } else if (typeof data.generation === 'string') {
    // Meta Llama.
    content = data.generation
    if (data.generation_token_count !== undefined) {
      const promptTokens = data.prompt_token_count || 0
      usage = {
        promptTokens,
        completionTokens: data.generation_token_count,
        totalTokens: promptTokens + data.generation_token_count,
      }
    }
  } else if (mistralOutput) {
    // Mistral text completion (matches the `[INST]` body built for mistral ids).
    content = mistralOutput.text
  }

  return {
    id: item.id,
    customId: item.id,
    status: 'completed',
    result: tryParseJson(content, !!item.schema),
    ...(usage && { usage }),
  }
}
// ============================================================================
// True S3-based batch inference (separate from the BatchProvider adapter)
// ============================================================================

/**
 * Create and submit a true Bedrock batch inference job.
 * Requires S3 bucket access and proper IAM setup.
 *
 * The function works in two steps:
 *   1. It uploads the items as one JSONL file to
 *      `<s3InputPrefix>/<jobName>.jsonl` in the configured bucket.
 *   2. It creates a model invocation job that reads that file and writes
 *      output under `<s3OutputPrefix>/<jobName>/`.
 *
 * It does not wait for the job or download its results. Every record uses the
 * Anthropic Messages body, so `model` should be an Anthropic model id.
 *
 * @param items - Prompts to submit; each item's id becomes the record id.
 * @param model - Bedrock model id sent as `modelId`.
 * @param options.jobName - Job name; also used to derive the S3 key and output prefix.
 * @param options.s3InputPrefix - Key prefix for the input file (default `bedrock-batch/input`).
 * @param options.s3OutputPrefix - Key prefix for job output (default `bedrock-batch/output`).
 * @param options.roleArn - IAM service role for the job. When omitted or empty,
 *   the role from `configureAWSBedrock({ roleArn })` / `BEDROCK_BATCH_ROLE_ARN` is used.
 * @returns The ARN of the created job.
 * @throws Error in any of these cases:
 *   - `items` is empty;
 *   - credentials, the bucket or the role are not configured;
 *   - the S3 upload or the job creation fails;
 *   - the response lacks a `jobArn`.
 */
export async function createBedrockBatchJob(
  items: BatchItem[],
  model: string,
  options: {
    jobName: string
    s3InputPrefix?: string
    s3OutputPrefix?: string
    roleArn?: string
  }
): Promise<{ jobArn: string }> {
  if (items.length === 0) {
    throw new Error('Cannot create a Bedrock batch job with no items')
  }

  const config = getConfig()
  // In AI Gateway mode the config may carry empty credentials; S3 and the
  // control-plane API are called directly, so real credentials are required.
  if (!config.accessKeyId || !config.secretAccessKey) {
    throw new Error(
      'Bedrock batch jobs require AWS credentials. Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.'
    )
  }
  if (!config.bucket) {
    throw new Error('S3 bucket for Bedrock batch not configured. Set BEDROCK_BATCH_S3_BUCKET')
  }
  const jobRoleArn = options.roleArn || config.role
  if (!jobRoleArn) {
    throw new Error(
      'IAM role for Bedrock batch not configured. Pass options.roleArn or set BEDROCK_BATCH_ROLE_ARN'
    )
  }

  const jsonlLines = items.map((item) => {
    const request: BedrockBatchRequest = {
      recordId: item.id,
      modelInput: {
        anthropic_version: 'bedrock-2023-05-31',
        max_tokens: item.options?.maxTokens || 4096,
        messages: [{ role: 'user', content: item.prompt }],
        ...(item.options?.system !== undefined && { system: item.options.system }),
        ...(item.options?.temperature !== undefined && {
          temperature: item.options.temperature,
        }),
      },
    }
    return JSON.stringify(request)
  })

  const inputKey = `${options.s3InputPrefix || 'bedrock-batch/input'}/${options.jobName}.jsonl`
  const outputPrefix = `${options.s3OutputPrefix || 'bedrock-batch/output'}/${options.jobName}/`

  // Step 1: upload the JSONL input file.
  const s3Url = `https://${config.bucket}.s3.${config.region}.amazonaws.com/${inputKey}`
  const content = jsonlLines.join('\n')
  const s3Response = await fetch(s3Url, {
    method: 'PUT',
    headers: await signRequest('PUT', s3Url, content, config, 's3'),
    body: content,
  })
  if (!s3Response.ok) {
    const error = await s3Response.text()
    throw new Error(`Failed to upload to S3: ${s3Response.status} ${error}`)
  }

  // Step 2: create the invocation job that reads the uploaded file.
  const jobUrl = `https://bedrock.${config.region}.amazonaws.com/model-invocation-job`
  const jobBody = JSON.stringify({
    jobName: options.jobName,
    modelId: model,
    roleArn: jobRoleArn,
    inputDataConfig: {
      s3InputDataConfig: { s3Uri: `s3://${config.bucket}/${inputKey}` },
    },
    outputDataConfig: {
      s3OutputDataConfig: { s3Uri: `s3://${config.bucket}/${outputPrefix}` },
    },
  })

  const jobResponse = await fetch(jobUrl, {
    method: 'POST',
    headers: await signRequest('POST', jobUrl, jobBody, config, 'bedrock'),
    body: jobBody,
  })
  if (!jobResponse.ok) {
    const error = await jobResponse.text()
    throw new Error(`Failed to create Bedrock batch job: ${jobResponse.status} ${error}`)
  }

  const jobData = (await jobResponse.json()) as { jobArn?: unknown } | null
  if (typeof jobData?.jobArn !== 'string') {
    throw new Error('Bedrock batch job response did not include a jobArn')
  }
  return { jobArn: jobData.jobArn }
}
// ============================================================================
// Bedrock flex adapter (FlexAdapter port)
// ============================================================================

/**
 * FlexAdapter for Bedrock. Every item is sent straight away as a runtime
 * invocation, 8 per wave with a 500 ms pause between waves. The call resolves
 * with the results collected by `processConcurrently`.
 */
const bedrockFlexAdapter: FlexAdapter = {
  async submitFlex(items: BatchItem[], options: { model?: string }): Promise<BatchResult[]> {
    const resolvedConfig = getConfig()
    const modelId = options.model || 'anthropic.claude-3-sonnet-20240229-v1:0'
    const waveSettings = { concurrency: 8, delayBetweenWaves: 500 }

    return processConcurrently(
      items,
      (entry) => processBedrockItem(entry, resolvedConfig, modelId),
      waveSettings
    )
  },
}

// ============================================================================
// Register adapters
// ============================================================================

// Module side effect: expose both adapters under the 'bedrock' provider key.
registerBatchAdapter('bedrock', bedrockAdapter)
registerFlexAdapter('bedrock', bedrockFlexAdapter)

export { bedrockAdapter, bedrockFlexAdapter }