UNPKG

genkitx-aws-bedrock

Version:
531 lines 18.9 kB
/** * Copyright 2026 Xavier Portilla Edo * Copyright 2026 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import { UserFacingError } from "genkit"; import { getCallableJSON, getHttpStatus, } from "genkit/context"; /** * Builds CORS headers based on options */ function buildCorsHeaders(corsOptions, requestOrigin) { if (corsOptions === false) { return {}; } const opts = corsOptions === true || corsOptions === undefined ? {} : corsOptions; const headers = { "Content-Type": "application/json", }; // Handle origin const origin = opts.origin ?? "*"; if (Array.isArray(origin)) { // Check if request origin is in allowed list if (requestOrigin && origin.includes(requestOrigin)) { headers["Access-Control-Allow-Origin"] = requestOrigin; } // If request origin is not in the allowlist, don't set the header } else { headers["Access-Control-Allow-Origin"] = origin; } // Handle methods const methods = opts.methods ?? ["POST", "OPTIONS"]; headers["Access-Control-Allow-Methods"] = methods.join(", "); // Handle allowed headers const allowedHeaders = opts.allowedHeaders ?? [ "Content-Type", "Authorization", ]; headers["Access-Control-Allow-Headers"] = allowedHeaders.join(", "); // Handle exposed headers if (opts.exposedHeaders && opts.exposedHeaders.length > 0) { headers["Access-Control-Expose-Headers"] = opts.exposedHeaders.join(", "); } // Handle credentials if (opts.credentials) { headers["Access-Control-Allow-Credentials"] = "true"; } // Handle max age const maxAge = opts.maxAge ?? 86400; headers["Access-Control-Max-Age"] = String(maxAge); return headers; } /** * Parses the request body from an API Gateway event. * Supports the Genkit callable protocol format where input is wrapped in { data: ... } * as well as direct input format for convenience. */ function parseRequestBody(event) { if (!event.body) { return {}; } const body = event.isBase64Encoded ? Buffer.from(event.body, "base64").toString("utf-8") : event.body; try { const parsed = JSON.parse(body); // Support callable protocol: { data: <input> } if (parsed && typeof parsed === "object" && "data" in parsed) { return parsed.data; } return parsed; } catch { throw new UserFacingError("INVALID_ARGUMENT", "Invalid JSON in request body"); } } /** * Gets the request origin from headers */ function getRequestOrigin(event) { const headers = event.headers || {}; return headers["origin"] || headers["Origin"]; } /** * Converts Lambda event headers to lowercase record (as required by RequestData) */ function normalizeHeaders(headers) { const result = {}; if (!headers) return result; for (const [key, value] of Object.entries(headers)) { if (value !== undefined) { result[key.toLowerCase()] = value; } } return result; } /** * Converts Lambda event to Genkit RequestData format */ function toRequestData(event, input) { // Determine HTTP method const method = "httpMethod" in event ? event.httpMethod : event.requestContext?.http?.method || "POST"; return { method: method, headers: normalizeHeaders(event.headers), input, }; } /** * Implementation of onCallGenkit */ export function onCallGenkit(optsOrFlow, flowArg) { let opts; let flow; if (arguments.length === 1) { opts = {}; flow = optsOrFlow; } else { opts = optsOrFlow; flow = flowArg; } const flowName = flow.__action?.name || "unknown"; const handler = async (event, lambdaContext) => { const requestOrigin = getRequestOrigin(event); const corsHeaders = buildCorsHeaders(opts.cors, requestOrigin); // Handle OPTIONS preflight request if (event.httpMethod === "OPTIONS") { return { statusCode: 204, headers: corsHeaders, body: "", }; } // Debug logging if (opts.debug) { console.log(`[${flowName}] Event:`, JSON.stringify(event, null, 2)); console.log(`[${flowName}] Context:`, JSON.stringify(lambdaContext, null, 2)); } try { // Parse request body const input = parseRequestBody(event); // Build Lambda-specific context const lambdaActionContext = { lambda: { event: { requestContext: event.requestContext, headers: event.headers, queryStringParameters: event.queryStringParameters, pathParameters: event.pathParameters, }, context: { functionName: lambdaContext.functionName, functionVersion: lambdaContext.functionVersion, invokedFunctionArn: lambdaContext.invokedFunctionArn, memoryLimitInMB: lambdaContext.memoryLimitInMB, awsRequestId: lambdaContext.awsRequestId, }, }, }; // Run context provider if provided let actionContext = lambdaActionContext; if (opts.contextProvider) { const requestData = toRequestData(event, input); const providerContext = await opts.contextProvider(requestData); // Merge provider context with Lambda context actionContext = { ...lambdaActionContext, ...providerContext, }; } if (opts.debug) { console.log(`[${flowName}] Running flow with input:`, input); } // Execute the flow with context const runResult = await flow.run(input, { context: actionContext }); const result = runResult.result; if (opts.debug) { console.log(`[${flowName}] Flow completed successfully`); } // Return success response (callable protocol) return { statusCode: 200, headers: corsHeaders, body: JSON.stringify({ result, }), }; } catch (error) { console.error(`[${flowName}] Error:`, error); // Use custom error handler if provided if (opts.onError) { const customError = await opts.onError(error instanceof Error ? error : new Error(String(error))); return { statusCode: customError.statusCode, headers: corsHeaders, body: JSON.stringify({ error: { status: "INTERNAL", message: customError.message, }, }), }; } // Use Genkit's callable error format (same as express handler) return { statusCode: getHttpStatus(error), headers: corsHeaders, body: JSON.stringify(getCallableJSON(error)), }; } }; // Attach additional properties to the handler const callableFunction = handler; callableFunction.flow = flow; callableFunction.flowName = flowName; // If streaming mode, return the streaming handler directly if (opts.streaming) { return awslambda.streamifyResponse(async (event, responseStream, lambdaCtx) => { const requestOrigin = getRequestOrigin(event); const corsHeaders = buildCorsHeaders(opts.cors, requestOrigin); // Handle OPTIONS preflight const method = event.requestContext?.http?.method || "POST"; if (method === "OPTIONS") { const httpStream = awslambda.HttpResponseStream.from(responseStream, { statusCode: 204, headers: corsHeaders, }); httpStream.end(); return; } if (opts.debug) { console.log(`[${flowName}] Stream event:`, JSON.stringify(event, null, 2)); } try { const input = parseRequestBody(event); // Build context const lambdaActionContext = { lambda: { event: { requestContext: event.requestContext, headers: event.headers, queryStringParameters: event.queryStringParameters, pathParameters: event.pathParameters, }, context: { functionName: lambdaCtx.functionName, functionVersion: lambdaCtx.functionVersion, invokedFunctionArn: lambdaCtx.invokedFunctionArn, memoryLimitInMB: lambdaCtx.memoryLimitInMB, awsRequestId: lambdaCtx.awsRequestId, }, }, }; let actionContext = lambdaActionContext; if (opts.contextProvider) { const requestData = toRequestData(event, input); const providerContext = await opts.contextProvider(requestData); actionContext = { ...lambdaActionContext, ...providerContext }; } // Check if client wants SSE streaming const acceptHeader = event.headers?.["accept"] || event.headers?.["Accept"] || ""; const clientWantsStreaming = acceptHeader.includes("text/event-stream"); if (clientWantsStreaming) { // Real streaming: write SSE events incrementally const httpStream = awslambda.HttpResponseStream.from(responseStream, { statusCode: 200, headers: { ...corsHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", }, }); const { stream, output } = flow.stream(input, { context: actionContext, }); for await (const chunk of stream) { httpStream.write(`data: ${JSON.stringify({ message: chunk })}\n\n`); } const result = (await output); httpStream.write(`data: ${JSON.stringify({ result })}\n\n`); httpStream.end(); if (opts.debug) { console.log(`[${flowName}] Streaming flow completed successfully`); } } else { // Non-streaming: buffered JSON response const runResult = await flow.run(input, { context: actionContext, }); const result = runResult.result; const httpStream = awslambda.HttpResponseStream.from(responseStream, { statusCode: 200, headers: corsHeaders, }); httpStream.write(JSON.stringify({ result })); httpStream.end(); } } catch (error) { console.error(`[${flowName}] Stream error:`, error); let statusCode = getHttpStatus(error); let body; if (opts.onError) { const customError = await opts.onError(error instanceof Error ? error : new Error(String(error))); statusCode = customError.statusCode; body = JSON.stringify({ error: { status: "INTERNAL", message: customError.message, }, }); } else { body = JSON.stringify(getCallableJSON(error)); } const httpStream = awslambda.HttpResponseStream.from(responseStream, { statusCode, headers: corsHeaders, }); httpStream.write(body); httpStream.end(); } }); } return callableFunction; } /** * Creates a context provider that requires an API key in a specific header. * * @example * ```typescript * // Require API key to match a specific value * export const handler = onCallGenkit( * { contextProvider: requireApiKey('X-API-Key', process.env.API_KEY!) }, * myFlow * ); * * // Or with a custom validation function * export const handler = onCallGenkit( * { * contextProvider: requireApiKey('X-API-Key', async (key) => { * const valid = await validateApiKey(key); * if (!valid) { * throw new UserFacingError('PERMISSION_DENIED', 'Invalid API key'); * } * }) * }, * myFlow * ); * ``` */ export function requireApiKey(headerName, expectedValueOrValidator) { const lowerHeaderName = headerName.toLowerCase(); return async (request) => { const apiKey = request.headers[lowerHeaderName]; if (!apiKey) { throw new UserFacingError("UNAUTHENTICATED", `Missing required header: ${headerName}`); } if (typeof expectedValueOrValidator === "string") { if (apiKey !== expectedValueOrValidator) { throw new UserFacingError("PERMISSION_DENIED", "Invalid API key"); } } else { await expectedValueOrValidator(apiKey); } return { auth: { apiKey }, }; }; } /** * Creates a context provider that requires Bearer token authentication. * * @example * ```typescript * // With custom token validation * export const handler = onCallGenkit( * { * contextProvider: requireBearerToken(async (token) => { * const user = await verifyJWT(token); * return { auth: { user } }; * }) * }, * myFlow * ); * ``` */ export function requireBearerToken(validateToken) { return async (request) => { const authHeader = request.headers["authorization"]; if (!authHeader) { throw new UserFacingError("UNAUTHENTICATED", "Missing Authorization header"); } const match = authHeader.match(/^Bearer\s+(.+)$/i); if (!match) { throw new UserFacingError("UNAUTHENTICATED", "Invalid Authorization header format. Expected: Bearer <token>"); } const token = match[1]; return await validateToken(token); }; } /** * Creates a context provider that requires a specific header to be present. * * @example * ```typescript * // Require header to exist * export const handler = onCallGenkit( * { contextProvider: requireHeader('X-Request-ID') }, * myFlow * ); * * // Require header to have specific value * export const handler = onCallGenkit( * { contextProvider: requireHeader('X-API-Version', '2.0') }, * myFlow * ); * ``` */ export function requireHeader(headerName, expectedValue) { const lowerHeaderName = headerName.toLowerCase(); return async (request) => { const value = request.headers[lowerHeaderName]; if (!value) { throw new UserFacingError("UNAUTHENTICATED", `Missing required header: ${headerName}`); } if (expectedValue !== undefined && value !== expectedValue) { throw new UserFacingError("PERMISSION_DENIED", `Invalid value for header: ${headerName}`); } return {}; }; } /** * Creates a context provider that always allows requests (no authentication). * Useful for public endpoints. * * @example * ```typescript * export const handler = onCallGenkit( * { contextProvider: allowAll() }, * myPublicFlow * ); * ``` */ export function allowAll() { return async () => ({}); } /** * Combines multiple context providers. All providers must succeed. * The returned context is a merge of all provider contexts. * * @example * ```typescript * export const handler = onCallGenkit( * { * contextProvider: allOf( * requireHeader('X-Request-ID'), * requireApiKey('X-API-Key', process.env.API_KEY!) * ) * }, * myFlow * ); * ``` */ export function allOf(...providers) { return async (request) => { let mergedContext = {}; for (const provider of providers) { const context = await provider(request); mergedContext = { ...mergedContext, ...context }; } return mergedContext; }; } /** * Tries context providers in order, returning the first one that succeeds. * If all providers fail, throws the error from the last provider. * * @example * ```typescript * // Accept either API key or Bearer token * export const handler = onCallGenkit( * { * contextProvider: anyOf( * requireApiKey('X-API-Key', process.env.API_KEY!), * requireBearerToken(async (token) => { * const user = await verifyJWT(token); * return { auth: { user } }; * }) * ) * }, * myFlow * ); * ``` */ export function anyOf(...providers) { return async (request) => { let lastError; for (const provider of providers) { try { const context = await provider(request); return context; } catch (error) { lastError = error instanceof Error ? error : new Error(String(error)); } } throw lastError || new UserFacingError("UNAUTHENTICATED", "Unauthorized"); }; } export default onCallGenkit; //# sourceMappingURL=aws_lambda.js.map