UNPKG

ai-functions

Version:

Core AI primitives for building intelligent applications

403 lines 16.7 kB
/**
 * AWS Bedrock Batch Inference Adapter
 *
 * Bedrock has a true batch inference API (S3-driven) and a runtime invoke API.
 * The "batch" adapter here uses concurrent runtime invocations as a fallback
 * (no S3 setup required); `createBedrockBatchJob` is exported separately for
 * callers who want to drive the real S3-based batch flow directly.
 *
 * @see https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html
 *
 * @packageDocumentation
 */
import { getLogger } from '../logger.js';
import {
  LocalJobStore,
  processConcurrently,
  registerBatchAdapter,
  registerFlexAdapter,
  tryParseJson,
} from './provider.js';

// ============================================================================
// AWS configuration
// ============================================================================

// Explicit overrides set through configureAWSBedrock(). Unset values fall back
// to environment variables when getConfig() resolves the effective settings.
let awsRegion;
let awsAccessKeyId;
let awsSecretAccessKey;
let awsSessionToken;
let s3Bucket;
let roleArn;
let gatewayUrl;
let gatewayToken;

/**
 * Configure AWS credentials and settings.
 *
 * Only truthy fields are applied, so a partial object leaves the other
 * settings untouched.
 *
 * @param options - optional `region`, `accessKeyId`, `secretAccessKey`,
 *   `sessionToken`, `s3Bucket`, `roleArn`, `gatewayUrl`, `gatewayToken`
 */
export function configureAWSBedrock(options) {
  if (options.region) awsRegion = options.region;
  if (options.accessKeyId) awsAccessKeyId = options.accessKeyId;
  if (options.secretAccessKey) awsSecretAccessKey = options.secretAccessKey;
  if (options.sessionToken) awsSessionToken = options.sessionToken;
  if (options.s3Bucket) s3Bucket = options.s3Bucket;
  if (options.roleArn) roleArn = options.roleArn;
  if (options.gatewayUrl) gatewayUrl = options.gatewayUrl;
  if (options.gatewayToken) gatewayToken = options.gatewayToken;
}

/**
 * Resolve the effective configuration: explicit settings first, then env vars.
 *
 * When both an AI Gateway URL and token are available, the gateway config is
 * returned without validating credentials or the bucket (missing values
 * become ''). Otherwise AWS credentials are mandatory.
 *
 * @param [opts.requireBucket=true] - throw when no S3 bucket is configured.
 *   Pass `false` from code paths that only call the runtime/control-plane APIs
 *   and never touch S3.
 * @returns region, credentials, bucket ('' when absent and not required),
 *   role and gateway settings
 * @throws Error when AWS credentials are missing (non-gateway mode), or when
 *   the bucket is missing and `requireBucket` is true
 */
function getConfig({ requireBucket = true } = {}) {
  const region = awsRegion || process.env['AWS_REGION'] || process.env['AWS_DEFAULT_REGION'] || 'us-east-1';
  const accessKeyId = awsAccessKeyId || process.env['AWS_ACCESS_KEY_ID'];
  const secretAccessKey = awsSecretAccessKey || process.env['AWS_SECRET_ACCESS_KEY'];
  const sessionToken = awsSessionToken || process.env['AWS_SESSION_TOKEN'];
  const bucket = s3Bucket || process.env['BEDROCK_BATCH_S3_BUCKET'];
  const role = roleArn || process.env['BEDROCK_BATCH_ROLE_ARN'];
  const gwUrl = gatewayUrl || process.env['AI_GATEWAY_URL'];
  const gwToken = gatewayToken || process.env['AI_GATEWAY_TOKEN'];
  if (gwUrl && gwToken) {
    return {
      region,
      accessKeyId: accessKeyId || '',
      secretAccessKey: secretAccessKey || '',
      sessionToken,
      bucket: bucket || '',
      role,
      gatewayUrl: gwUrl,
      gatewayToken: gwToken,
    };
  }
  if (!accessKeyId || !secretAccessKey) {
    throw new Error('AWS credentials not configured. Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY, or use AI_GATEWAY_URL and AI_GATEWAY_TOKEN');
  }
  if (requireBucket && !bucket) {
    throw new Error('S3 bucket for Bedrock batch not configured. Set BEDROCK_BATCH_S3_BUCKET');
  }
  return {
    region,
    accessKeyId,
    secretAccessKey,
    sessionToken,
    bucket: bucket || '',
    role,
    gatewayUrl: undefined,
    gatewayToken: undefined,
  };
}

// ============================================================================
// AWS SigV4 (delegated to optional @smithy/signature-v4 if available)
// ============================================================================

/**
 * Build SigV4-signed headers for a request.
 *
 * Falls back to unsigned headers (with a warning) only when the optional
 * signing packages cannot be imported. Errors raised while actually signing
 * propagate to the caller.
 *
 * @param method - HTTP method
 * @param url - absolute URL whose host/path AWS will verify
 * @param body - request body string that is hashed into the signature
 * @param config - resolved config from getConfig()
 * @param service - SigV4 service name, e.g. 'bedrock' or 's3'
 * @returns Headers to pass to fetch (no explicit `host`; fetch derives it from the URL)
 */
async function signRequest(method, url, body, config, service) {
  const { hostname, pathname } = new URL(url);
  const headers = new Headers({
    'Content-Type': 'application/json',
    'X-Amz-Date': new Date().toISOString().replace(/[:-]|\.\d{3}/g, ''),
  });
  if (config.sessionToken) {
    headers.set('X-Amz-Security-Token', config.sessionToken);
  }
  let signatureV4Module;
  let sha256Module;
  try {
    // Optional dependency — present in production, absent in dev/test.
    // @ts-expect-error - Optional dependency
    signatureV4Module = await import('@smithy/signature-v4');
    // @ts-expect-error - Optional dependency
    sha256Module = await import('@aws-crypto/sha256-js');
  } catch {
    getLogger().warn('AWS SDK not available for request signing. Install @smithy/signature-v4 and @aws-crypto/sha256-js');
    return headers;
  }
  const signer = new signatureV4Module.SignatureV4({
    service,
    region: config.region,
    credentials: {
      accessKeyId: config.accessKeyId,
      secretAccessKey: config.secretAccessKey,
      sessionToken: config.sessionToken,
    },
    sha256: sha256Module.Sha256,
    // S3 canonical paths are URI-encoded once; all other services encode the
    // (already encoded) path a second time.
    uriEscapePath: service !== 's3',
  });
  const signedRequest = await signer.sign({
    method,
    protocol: 'https:',
    hostname,
    path: pathname,
    // SigV4 requires `host` to be a signed header; SignatureV4 only signs the
    // headers it is given, so supply it explicitly.
    headers: { ...Object.fromEntries(headers.entries()), host: hostname },
    body,
  });
  const signed = new Headers(signedRequest.headers);
  // fetch sets Host from the request URL itself, so don't send it explicitly.
  signed.delete('host');
  return signed;
}

// ============================================================================
// Local job tracking
// ============================================================================
const jobs = new LocalJobStore('bedrock_batch');

// ============================================================================
// Bedrock batch adapter (BatchProvider port)
// ============================================================================
const bedrockAdapter = {
  /**
   * Start a local "batch" job that invokes the runtime API concurrently.
   *
   * @returns `{ job, completion }` — the job descriptor and a promise of the
   *   per-item results (it rejects if processConcurrently rejects)
   */
  async submit(items, options) {
    // The runtime-invoke fallback never touches S3, so don't insist on a bucket.
    const config = getConfig({ requireBucket: false });
    const model = options.model || 'anthropic.claude-3-sonnet-20240229-v1:0';
    const { id, state } = jobs.create(items, options);
    // Drive the job state machine in the background. The IIFE runs
    // synchronously up to its first await, so state.status is already
    // 'in_progress' by the time submit() returns.
    const completion = (async () => {
      state.status = 'in_progress';
      try {
        const results = await processConcurrently(items, (item) => processBedrockItem(item, config, model), {
          concurrency: 5, // Bedrock has stricter rate limits.
          delayBetweenWaves: 1000,
          onWaveComplete: (partial) => {
            state.results = partial;
          },
        });
        state.results = results;
        // NOTE(review): this overwrites a 'cancelled' status set by cancel()
        // while items were still in flight. Kept because it is not visible
        // here whether LocalJobStore.waitForCompletion treats 'cancelled' as terminal.
        state.status = results.every((r) => r.status === 'completed') ? 'completed' : 'failed';
        return results;
      } catch (error) {
        // Otherwise the job would be left 'in_progress' with no completedAt.
        state.status = 'failed';
        throw error;
      } finally {
        state.completedAt = new Date();
      }
    })();
    const job = {
      id,
      provider: 'bedrock',
      status: 'pending',
      totalItems: items.length,
      completedItems: 0,
      failedItems: 0,
      createdAt: state.createdAt,
      ...(options.webhookUrl !== undefined && { webhookUrl: options.webhookUrl }),
    };
    return { job, completion };
  },

  async getStatus(batchId) {
    return jobs.snapshot(batchId, 'bedrock');
  },

  /**
   * Mark a job cancelled and, if it carries a Bedrock `jobArn`, ask Bedrock
   * to stop it. Stopping is best-effort: failures are logged, not thrown.
   */
  async cancel(batchId) {
    if (!jobs.has(batchId)) return;
    const state = jobs.get(batchId);
    state.status = 'cancelled';
    const jobArn = state.meta?.['jobArn'];
    if (!jobArn) return;
    // Stopping a job only calls the Bedrock control-plane API; no S3 needed.
    const config = getConfig({ requireBucket: false });
    const url = `https://bedrock.${config.region}.amazonaws.com/model-invocation-job/${encodeURIComponent(jobArn)}/stop`;
    try {
      const response = await fetch(url, {
        method: 'POST',
        headers: await signRequest('POST', url, '', config, 'bedrock'),
      });
      // fetch does not reject on HTTP errors, so check the status explicitly.
      if (!response.ok) {
        getLogger().warn('Failed to cancel Bedrock job:', `${response.status} ${await response.text()}`);
      }
    } catch (error) {
      getLogger().warn('Failed to cancel Bedrock job:', error);
    }
  },

  async getResults(batchId) {
    return jobs.get(batchId).results;
  },

  async waitForCompletion(batchId, pollInterval = 5000) {
    return jobs.waitForCompletion(batchId, pollInterval);
  },
};

// ============================================================================
// Per-item processing
// ============================================================================

/**
 * Invoke one item against the Bedrock runtime API (or via AI Gateway when
 * configured) and parse the response.
 *
 * @throws Error with status and body text on a non-2xx response
 */
async function processBedrockItem(item, config, model) {
  if (config.gatewayUrl && config.gatewayToken) {
    return processBedrockItemViaGateway(item, config, model);
  }
  const url = `https://bedrock-runtime.${config.region}.amazonaws.com/model/${encodeURIComponent(model)}/invoke`;
  const body = buildBedrockRequestBody(item, model);
  const bodyStr = JSON.stringify(body);
  const headers = await signRequest('POST', url, bodyStr, config, 'bedrock');
  const response = await fetch(url, { method: 'POST', headers, body: bodyStr });
  if (!response.ok) {
    const error = await response.text();
    throw new Error(`Bedrock API error: ${response.status} ${error}`);
  }
  return parseBedrockResponse(item, await response.json());
}
/**
 * Process a Bedrock item via Cloudflare AI Gateway.
 *
 * AI Gateway routes the request but does not handle authentication, so
 * Bedrock still verifies an AWS SigV4 signature. That signature must be
 * computed for the native `bedrock-runtime.<region>.amazonaws.com` URL (the
 * host and path AWS checks); the signed request is then sent to the gateway URL.
 *
 * @param item - batch item (`id`, `prompt`, optional `options`/`schema`)
 * @param config - resolved config with `gatewayUrl`, `gatewayToken`, region and AWS credentials
 * @param model - Bedrock model ID; the body uses that model family's request shape
 * @returns the parsed batch result
 * @throws Error when AWS credentials are missing or the gateway returns a non-2xx status
 * @see https://developers.cloudflare.com/ai-gateway/usage/providers/bedrock/
 */
async function processBedrockItemViaGateway(item, config, model) {
  if (!config.accessKeyId || !config.secretAccessKey) {
    throw new Error('Bedrock via AI Gateway still requires AWS credentials for SigV4 signing. ' +
      'Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.');
  }
  const modelPath = `model/${encodeURIComponent(model)}/invoke`;
  // Sign against the real AWS endpoint...
  const awsUrl = `https://bedrock-runtime.${config.region}.amazonaws.com/${modelPath}`;
  // ...but send the request through the gateway.
  const url = `${config.gatewayUrl}/aws-bedrock/bedrock-runtime/${config.region}/${modelPath}`;
  // Same per-family body as direct runtime invocations.
  const bodyStr = JSON.stringify(buildBedrockRequestBody(item, model));
  const headers = await signRequest('POST', awsUrl, bodyStr, config, 'bedrock');
  // The request must carry the gateway's Host, which fetch derives from `url`.
  headers.delete('host');
  headers.set('cf-aig-authorization', `Bearer ${config.gatewayToken}`);
  const response = await fetch(url, { method: 'POST', headers, body: bodyStr });
  if (!response.ok) {
    const error = await response.text();
    throw new Error(`Bedrock via Gateway error: ${response.status} ${error}`);
  }
  return parseBedrockResponse(item, await response.json());
}
/**
 * Build the Bedrock invoke body for the model family.
 *
 * The family is chosen by substring match on the model ID, in this order:
 * `anthropic` (Claude Messages API), `amazon` (Titan text), `meta` (Llama),
 * `mistral`. Any other model gets the Claude-style Messages body.
 * A missing `maxTokens` defaults to 4096. For the Titan/Llama/Mistral shapes,
 * a missing temperature defaults to 0.7. Claude bodies include `temperature`
 * and `system` only when they are set.
 *
 * @param item - batch item with `prompt` and optional `options`
 *   (`maxTokens`, `temperature`, `system`)
 * @param model - Bedrock model ID
 * @returns a plain object ready for `JSON.stringify`
 */
function buildBedrockRequestBody(item, model) {
  const options = item.options ?? {};
  const maxTokens = options.maxTokens || 4096;
  // `??` rather than `||`: an explicit temperature of 0 is a valid setting.
  const temperature = options.temperature ?? 0.7;
  if (model.includes('anthropic')) {
    return buildClaudeMessagesBody(item.prompt, options, maxTokens);
  }
  if (model.includes('amazon')) {
    return {
      inputText: item.prompt,
      textGenerationConfig: { maxTokenCount: maxTokens, temperature },
    };
  }
  if (model.includes('meta')) {
    return { prompt: item.prompt, max_gen_len: maxTokens, temperature };
  }
  if (model.includes('mistral')) {
    return { prompt: `<s>[INST] ${item.prompt} [/INST]`, max_tokens: maxTokens, temperature };
  }
  // Default: Claude-style, including the system prompt just like the explicit
  // Anthropic branch.
  return buildClaudeMessagesBody(item.prompt, options, maxTokens);
}

/** Claude Messages API body shared by the Anthropic branch and the default fallback. */
function buildClaudeMessagesBody(prompt, options, maxTokens) {
  return {
    anthropic_version: 'bedrock-2023-05-31',
    max_tokens: maxTokens,
    messages: [{ role: 'user', content: prompt }],
    ...(options.system !== undefined && { system: options.system }),
    ...(options.temperature !== undefined && { temperature: options.temperature }),
  };
}
/**
 * Parse a Bedrock invoke response across model families.
 *
 * Recognised body shapes, checked in order:
 * - `content[]` text blocks with `usage.input_tokens` / `usage.output_tokens` (Claude)
 * - `results[0].outputText` with `results[0].tokenCount` and a top-level
 *   `inputTextTokenCount` (Titan)
 * - `generation` with `prompt_token_count` / `generation_token_count` (Llama)
 * - `outputs[0].text` (Mistral; no usage is extracted)
 *
 * @param item - the originating batch item; `item.schema` makes the text be parsed as JSON
 * @param raw - the decoded JSON response body
 * @returns a completed batch result; `usage` is present only when token counts were found
 */
function parseBedrockResponse(item, raw) {
  const data = raw ?? {};
  let content;
  let usage;
  if (Array.isArray(data.content)) {
    content = data.content.find((block) => block.type === 'text')?.text;
    if (data.usage) {
      const promptTokens = data.usage.input_tokens;
      const completionTokens = data.usage.output_tokens;
      usage = { promptTokens, completionTokens, totalTokens: promptTokens + completionTokens };
    }
  } else if (data.results?.[0]) {
    const [first] = data.results;
    content = first.outputText;
    const promptTokens = data.inputTextTokenCount || 0;
    const completionTokens = first.tokenCount || 0;
    usage = { promptTokens, completionTokens, totalTokens: promptTokens + completionTokens };
  } else if (typeof data.generation === 'string') {
    // typeof check so an empty generation is still recognised as this shape.
    content = data.generation;
    if (data.generation_token_count !== undefined) {
      const promptTokens = data.prompt_token_count || 0;
      usage = {
        promptTokens,
        completionTokens: data.generation_token_count,
        totalTokens: promptTokens + data.generation_token_count,
      };
    }
  } else if (data.outputs?.[0]) {
    content = data.outputs[0].text;
  }
  return {
    id: item.id,
    customId: item.id,
    status: 'completed',
    result: tryParseJson(content, !!item.schema),
    ...(usage && { usage }),
  };
}
// ============================================================================
// True S3-based batch inference (separate from the BatchProvider adapter)
// ============================================================================

/**
 * Create and submit a true Bedrock batch inference job.
 * Requires S3 bucket access and proper IAM setup.
 *
 * Uploads one JSONL record per item (`{ recordId, modelInput }`, where
 * `modelInput` uses the request shape for `model`'s family) to
 * `s3://<bucket>/<s3InputPrefix>/<jobName>.jsonl`, then creates the
 * model-invocation job with output under `<s3OutputPrefix>/<jobName>/`.
 *
 * @param items - batch items (`id`, `prompt`, optional `options`)
 * @param model - Bedrock model ID
 * @param options - `jobName` (required); optional `roleArn` (falls back to the
 *   role from configureAWSBedrock() / BEDROCK_BATCH_ROLE_ARN), `s3InputPrefix`,
 *   `s3OutputPrefix`
 * @returns the decoded JSON body of the create-job response
 * @throws Error on a missing jobName, bucket or role, or a non-2xx S3/Bedrock response
 */
export async function createBedrockBatchJob(items, model, options) {
  if (!options?.jobName) {
    throw new Error('createBedrockBatchJob requires options.jobName');
  }
  const config = getConfig();
  // getConfig() may return '' for the bucket (e.g. in AI Gateway mode) instead of throwing.
  if (!config.bucket) {
    throw new Error('S3 bucket for Bedrock batch not configured. Set BEDROCK_BATCH_S3_BUCKET');
  }
  const serviceRoleArn = options.roleArn || config.role;
  if (!serviceRoleArn) {
    throw new Error('IAM role for Bedrock batch not configured. Pass options.roleArn or set BEDROCK_BATCH_ROLE_ARN');
  }
  const inputKey = `${options.s3InputPrefix || 'bedrock-batch/input'}/${options.jobName}.jsonl`;
  const outputPrefix = `${options.s3OutputPrefix || 'bedrock-batch/output'}/${options.jobName}/`;
  // One JSON record per line; modelInput matches what runtime invocations send.
  const content = items
    .map((item) => JSON.stringify({ recordId: item.id, modelInput: buildBedrockRequestBody(item, model) }))
    .join('\n');
  const s3Url = `https://${config.bucket}.s3.${config.region}.amazonaws.com/${inputKey}`;
  const s3Response = await fetch(s3Url, {
    method: 'PUT',
    headers: await signRequest('PUT', s3Url, content, config, 's3'),
    body: content,
  });
  if (!s3Response.ok) {
    const error = await s3Response.text();
    throw new Error(`Failed to upload to S3: ${s3Response.status} ${error}`);
  }
  const jobUrl = `https://bedrock.${config.region}.amazonaws.com/model-invocation-job`;
  const jobBody = JSON.stringify({
    jobName: options.jobName,
    modelId: model,
    roleArn: serviceRoleArn,
    inputDataConfig: {
      s3InputDataConfig: { s3Uri: `s3://${config.bucket}/${inputKey}` },
    },
    outputDataConfig: {
      s3OutputDataConfig: { s3Uri: `s3://${config.bucket}/${outputPrefix}` },
    },
  });
  const jobResponse = await fetch(jobUrl, {
    method: 'POST',
    headers: await signRequest('POST', jobUrl, jobBody, config, 'bedrock'),
    body: jobBody,
  });
  if (!jobResponse.ok) {
    const error = await jobResponse.text();
    throw new Error(`Failed to create Bedrock batch job: ${jobResponse.status} ${error}`);
  }
  return jobResponse.json();
}
============================================================================ const bedrockFlexAdapter = { async submitFlex(items, options) { const config = getConfig(); const model = options.model || 'anthropic.claude-3-sonnet-20240229-v1:0'; return processConcurrently(items, (item) => processBedrockItem(item, config, model), { concurrency: 8, delayBetweenWaves: 500, }); }, }; // ============================================================================ // Register adapters // ============================================================================ registerBatchAdapter('bedrock', bedrockAdapter); registerFlexAdapter('bedrock', bedrockFlexAdapter); export { bedrockAdapter, bedrockFlexAdapter }; //# sourceMappingURL=bedrock.js.map