genkitx-aws-bedrock
Version:
Genkit AI framework plugin for AWS Bedrock APIs.
531 lines • 18.9 kB
JavaScript
/**
* Copyright 2026 Xavier Portilla Edo
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { createHash, timingSafeEqual } from "node:crypto";
import { UserFacingError } from "genkit";
import { getCallableJSON, getHttpStatus, } from "genkit/context";
/**
 * Builds the complete header set for a response: the JSON content type plus
 * the CORS headers derived from `corsOptions`.
 *
 * The result is used as-is for every response (success, error and preflight),
 * so `Content-Type: application/json` is always present, even when CORS is
 * disabled.
 *
 * @param {boolean|object|null|undefined} corsOptions - `false` disables the
 *   Access-Control-* headers; `true`/`undefined`/`null` use the defaults; an
 *   object may override `origin` (string or allow-list array), `methods`,
 *   `allowedHeaders`, `exposedHeaders`, `credentials` and `maxAge`.
 * @param {string|undefined} requestOrigin - The request's Origin header value.
 * @returns {Record<string, string>} Response headers.
 */
function buildCorsHeaders(corsOptions, requestOrigin) {
  const headers = {
    "Content-Type": "application/json",
  };
  // Disabling CORS drops only the Access-Control-* headers; responses are
  // still JSON and must keep their content type.
  if (corsOptions === false) {
    return headers;
  }
  const opts = corsOptions === true || corsOptions == null ? {} : corsOptions;
  // Handle origin
  const origin = opts.origin ?? "*";
  if (Array.isArray(origin)) {
    // Reflect the request origin only when it is allow-listed; otherwise omit
    // the header so the browser blocks the cross-origin response.
    if (requestOrigin && origin.includes(requestOrigin)) {
      headers["Access-Control-Allow-Origin"] = requestOrigin;
    }
    // The response now depends on the Origin request header, so caches must
    // key on it or they may serve one origin's response to another.
    headers["Vary"] = "Origin";
  } else {
    // NOTE(review): browsers reject a literal "*" combined with
    // `Access-Control-Allow-Credentials: true`; use an explicit allow-list
    // when enabling credentials.
    headers["Access-Control-Allow-Origin"] = origin;
  }
  // Handle methods
  const methods = opts.methods ?? ["POST", "OPTIONS"];
  headers["Access-Control-Allow-Methods"] = methods.join(", ");
  // Handle allowed headers
  const allowedHeaders = opts.allowedHeaders ?? ["Content-Type", "Authorization"];
  headers["Access-Control-Allow-Headers"] = allowedHeaders.join(", ");
  // Handle exposed headers
  if (opts.exposedHeaders && opts.exposedHeaders.length > 0) {
    headers["Access-Control-Expose-Headers"] = opts.exposedHeaders.join(", ");
  }
  // Handle credentials
  if (opts.credentials) {
    headers["Access-Control-Allow-Credentials"] = "true";
  }
  // Handle max age (seconds a preflight result may be cached; default 24h)
  const maxAge = opts.maxAge ?? 86400;
  headers["Access-Control-Max-Age"] = String(maxAge);
  return headers;
}
/**
 * Decodes and parses the JSON body of a Lambda HTTP event.
 *
 * Two request shapes are accepted: the Genkit callable envelope
 * (`{ data: <input> }`), which is unwrapped to `<input>`, and a bare JSON
 * payload, which is returned unchanged.
 *
 * @param {object} event - Lambda HTTP event; reads `body` and `isBase64Encoded`.
 * @returns {unknown} The flow input, or `{}` when the event has no body.
 * @throws {UserFacingError} INVALID_ARGUMENT when the body is not valid JSON.
 */
function parseRequestBody(event) {
  const { body: rawBody, isBase64Encoded } = event;
  if (!rawBody) return {};

  const text = isBase64Encoded ? Buffer.from(rawBody, "base64").toString("utf-8") : rawBody;

  let payload;
  try {
    payload = JSON.parse(text);
  } catch {
    throw new UserFacingError("INVALID_ARGUMENT", "Invalid JSON in request body");
  }

  // Only a non-null object carrying a `data` key is treated as an envelope.
  const isCallableEnvelope = payload !== null && typeof payload === "object" && "data" in payload;
  return isCallableEnvelope ? payload.data : payload;
}
/**
 * Returns the value of the request's Origin header, if any.
 *
 * HTTP header names are case-insensitive, and API Gateway REST (v1) events
 * keep whatever casing the client sent. The common spellings are therefore
 * checked first, with a case-insensitive scan as a fallback.
 *
 * @param {object} event - Lambda HTTP event; reads `headers`.
 * @returns {string|undefined} The Origin header value, or undefined if absent.
 */
function getRequestOrigin(event) {
  const headers = event.headers || {};
  const direct = headers["origin"] || headers["Origin"];
  if (direct) {
    return direct;
  }
  for (const [name, value] of Object.entries(headers)) {
    if (value && name.toLowerCase() === "origin") {
      return value;
    }
  }
  return undefined;
}
/**
 * Produces the header record expected by Genkit's `RequestData`: every header
 * name is lower-cased and entries whose value is `undefined` are dropped.
 * If two names differ only by case, the one enumerated last wins.
 *
 * @param {Record<string, string|undefined>|null|undefined} headers - Raw event headers.
 * @returns {Record<string, string>} Lower-cased header record (empty when no headers).
 */
function normalizeHeaders(headers) {
  if (!headers) {
    return {};
  }
  return Object.entries(headers).reduce((acc, [name, value]) => {
    if (value !== undefined) {
      acc[name.toLowerCase()] = value;
    }
    return acc;
  }, {});
}
/**
 * Converts a Lambda HTTP event into the Genkit `RequestData` object handed to
 * a context provider.
 *
 * The method is taken from `httpMethod` when the event has that key
 * (API Gateway REST / v1); otherwise from `requestContext.http.method`
 * (HTTP API v2 / function URLs), falling back to "POST".
 *
 * @param {object} event - Lambda HTTP event.
 * @param {unknown} input - Already-parsed flow input.
 * @returns {{ method: string, headers: Record<string, string>, input: unknown }}
 */
function toRequestData(event, input) {
  let method;
  if ("httpMethod" in event) {
    method = event.httpMethod;
  } else {
    method = event.requestContext?.http?.method || "POST";
  }
  const headers = normalizeHeaders(event.headers);
  return { method, headers, input };
}
/**
 * Resolves the HTTP method of a Lambda HTTP event.
 *
 * API Gateway REST (v1) events carry `httpMethod`; HTTP API (v2) and Lambda
 * function URL events carry `requestContext.http.method`.
 *
 * @param {object} event - Lambda HTTP event.
 * @returns {string|undefined} The HTTP method, if the event exposes one.
 */
function resolveHttpMethod(event) {
  return event.httpMethod ?? event.requestContext?.http?.method;
}

/**
 * Extracts the Lambda-specific request/invocation fields that flows receive
 * under `context.lambda`.
 *
 * @param {object} event - Lambda HTTP event.
 * @param {object|undefined} lambdaCtx - Lambda invocation context.
 * @returns {{ lambda: { event: object, context: object } }}
 */
function buildLambdaActionContext(event, lambdaCtx) {
  return {
    lambda: {
      event: {
        requestContext: event.requestContext,
        headers: event.headers,
        queryStringParameters: event.queryStringParameters,
        pathParameters: event.pathParameters,
      },
      context: {
        functionName: lambdaCtx?.functionName,
        functionVersion: lambdaCtx?.functionVersion,
        invokedFunctionArn: lambdaCtx?.invokedFunctionArn,
        memoryLimitInMB: lambdaCtx?.memoryLimitInMB,
        awsRequestId: lambdaCtx?.awsRequestId,
      },
    },
  };
}

/**
 * Builds the action context for a flow run: the Lambda context, overlaid
 * with the result of `opts.contextProvider` when one is configured (provider
 * keys win on conflict).
 *
 * @throws Whatever the context provider throws (e.g. an auth rejection).
 */
async function resolveActionContext(opts, event, lambdaCtx, input) {
  const lambdaActionContext = buildLambdaActionContext(event, lambdaCtx);
  if (!opts.contextProvider) {
    return lambdaActionContext;
  }
  const providerContext = await opts.contextProvider(toRequestData(event, input));
  return { ...lambdaActionContext, ...providerContext };
}

/**
 * Maps a thrown value to an error response.
 *
 * With `opts.onError`, its `statusCode`/`message` are used and the body is
 * `{ error: { status: "INTERNAL", message } }`. Otherwise Genkit's callable
 * error helpers (`getHttpStatus` / `getCallableJSON`) decide status and body.
 *
 * @returns {Promise<{ statusCode: number, wireError: object, body: string }>}
 *   `wireError` is the bare error object, used for in-band SSE error events.
 */
async function buildErrorResponse(opts, error) {
  if (opts.onError) {
    const customError = await opts.onError(error instanceof Error ? error : new Error(String(error)));
    const wireError = { status: "INTERNAL", message: customError.message };
    return {
      statusCode: customError.statusCode,
      wireError,
      body: JSON.stringify({ error: wireError }),
    };
  }
  // Use Genkit's callable error format (same as express handler)
  const wireError = getCallableJSON(error);
  return {
    statusCode: getHttpStatus(error),
    wireError,
    body: JSON.stringify(wireError),
  };
}

/**
 * Wraps a Genkit flow as an AWS Lambda HTTP handler speaking the callable
 * protocol (`{ data }` in, `{ result }` out).
 *
 * Overloads: `onCallGenkit(flow)` and `onCallGenkit(opts, flow)`.
 * With `opts.streaming`, the handler is wrapped with
 * `awslambda.streamifyResponse`, and clients sending
 * `Accept: text/event-stream` receive SSE `data:` chunks followed by the result.
 *
 * @param {object|Function} optsOrFlow - Options, or the flow when called with one argument.
 * @param {Function} [flowArg] - The flow, when options are given.
 * @returns {Function} Lambda handler with `flow` and `flowName` attached.
 */
export function onCallGenkit(optsOrFlow, flowArg) {
  let opts;
  let flow;
  if (arguments.length === 1) {
    opts = {};
    flow = optsOrFlow;
  } else {
    opts = optsOrFlow;
    flow = flowArg;
  }
  const flowName = flow.__action?.name || "unknown";

  if (opts.streaming) {
    const streamingHandler = awslambda.streamifyResponse(async (event, responseStream, lambdaCtx) => {
      const corsHeaders = buildCorsHeaders(opts.cors, getRequestOrigin(event));
      // Answer CORS preflight for both v1 (REST) and v2 (HTTP API / URL) events.
      if (resolveHttpMethod(event) === "OPTIONS") {
        const httpStream = awslambda.HttpResponseStream.from(responseStream, {
          statusCode: 204,
          headers: corsHeaders,
        });
        httpStream.end();
        return;
      }
      if (opts.debug) {
        // NOTE: logs the raw event, including headers such as Authorization.
        console.log(`[${flowName}] Stream event:`, JSON.stringify(event, null, 2));
      }
      // Set once the SSE status line and headers have been written. After that
      // the status can no longer change, so failures must be reported in-band.
      let sseStream;
      try {
        const input = parseRequestBody(event);
        const actionContext = await resolveActionContext(opts, event, lambdaCtx, input);
        // Check if client wants SSE streaming
        const acceptHeader = event.headers?.["accept"] || event.headers?.["Accept"] || "";
        if (acceptHeader.includes("text/event-stream")) {
          sseStream = awslambda.HttpResponseStream.from(responseStream, {
            statusCode: 200,
            headers: {
              ...corsHeaders,
              "Content-Type": "text/event-stream",
              "Cache-Control": "no-cache",
              Connection: "keep-alive",
            },
          });
          const { stream, output } = flow.stream(input, {
            context: actionContext,
          });
          // If the stream iterator throws first, `output` would reject with no
          // listener (an unhandled rejection). When it is awaited below, its
          // rejection is still observed and handled by the catch block.
          output.catch(() => { });
          for await (const chunk of stream) {
            sseStream.write(`data: ${JSON.stringify({ message: chunk })}\n\n`);
          }
          const result = await output;
          sseStream.write(`data: ${JSON.stringify({ result })}\n\n`);
          sseStream.end();
          if (opts.debug) {
            console.log(`[${flowName}] Streaming flow completed successfully`);
          }
        } else {
          // Non-streaming: buffered JSON response
          const { result } = await flow.run(input, {
            context: actionContext,
          });
          const httpStream = awslambda.HttpResponseStream.from(responseStream, {
            statusCode: 200,
            headers: corsHeaders,
          });
          httpStream.write(JSON.stringify({ result }));
          httpStream.end();
        }
      } catch (error) {
        console.error(`[${flowName}] Stream error:`, error);
        const { statusCode, wireError, body } = await buildErrorResponse(opts, error);
        if (sseStream) {
          // Headers are already on the wire; creating a second response
          // prelude would corrupt the stream, so emit an SSE error event.
          sseStream.write(`error: ${JSON.stringify({ error: wireError })}\n\n`);
          sseStream.end();
          return;
        }
        const httpStream = awslambda.HttpResponseStream.from(responseStream, {
          statusCode,
          headers: corsHeaders,
        });
        httpStream.write(body);
        httpStream.end();
      }
    });
    streamingHandler.flow = flow;
    streamingHandler.flowName = flowName;
    return streamingHandler;
  }

  const handler = async (event, lambdaContext) => {
    const corsHeaders = buildCorsHeaders(opts.cors, getRequestOrigin(event));
    // Answer CORS preflight for both v1 (REST) and v2 (HTTP API / URL) events.
    if (resolveHttpMethod(event) === "OPTIONS") {
      return {
        statusCode: 204,
        headers: corsHeaders,
        body: "",
      };
    }
    if (opts.debug) {
      // NOTE: logs the raw event, including headers such as Authorization.
      console.log(`[${flowName}] Event:`, JSON.stringify(event, null, 2));
      console.log(`[${flowName}] Context:`, JSON.stringify(lambdaContext, null, 2));
    }
    try {
      const input = parseRequestBody(event);
      const actionContext = await resolveActionContext(opts, event, lambdaContext, input);
      if (opts.debug) {
        console.log(`[${flowName}] Running flow with input:`, input);
      }
      const { result } = await flow.run(input, { context: actionContext });
      if (opts.debug) {
        console.log(`[${flowName}] Flow completed successfully`);
      }
      // Return success response (callable protocol)
      return {
        statusCode: 200,
        headers: corsHeaders,
        body: JSON.stringify({ result }),
      };
    } catch (error) {
      console.error(`[${flowName}] Error:`, error);
      const { statusCode, body } = await buildErrorResponse(opts, error);
      return {
        statusCode,
        headers: corsHeaders,
        body,
      };
    }
  };
  handler.flow = flow;
  handler.flowName = flowName;
  return handler;
}
/**
 * Compares two strings without an early exit on the first differing byte.
 * Both values are hashed to fixed-length SHA-256 digests first, so the
 * comparison also does not reveal the expected value's length.
 *
 * @param {string} actual
 * @param {string} expected
 * @returns {boolean} True when the strings are equal.
 */
function constantTimeEquals(actual, expected) {
  const actualDigest = createHash("sha256").update(actual).digest();
  const expectedDigest = createHash("sha256").update(expected).digest();
  return timingSafeEqual(actualDigest, expectedDigest);
}

/**
 * Creates a context provider that requires an API key in a specific header.
 *
 * When `expectedValueOrValidator` is a string, the key must equal it; the
 * comparison is constant-time to avoid leaking the secret through timing.
 * Otherwise it is called with the key and should throw to reject it.
 *
 * @example
 * ```typescript
 * // Require API key to match a specific value
 * export const handler = onCallGenkit(
 * { contextProvider: requireApiKey('X-API-Key', process.env.API_KEY!) },
 * myFlow
 * );
 *
 * // Or with a custom validation function
 * export const handler = onCallGenkit(
 * {
 * contextProvider: requireApiKey('X-API-Key', async (key) => {
 * const valid = await validateApiKey(key);
 * if (!valid) {
 * throw new UserFacingError('PERMISSION_DENIED', 'Invalid API key');
 * }
 * })
 * },
 * myFlow
 * );
 * ```
 *
 * @param {string} headerName - Header carrying the key (matched case-insensitively
 *   against the lower-cased request headers).
 * @param {string|((apiKey: string) => unknown)} expectedValueOrValidator - Expected
 *   key, or a validator that throws on an invalid key.
 * @returns {(request: object) => Promise<{ auth: { apiKey: string } }>}
 * @throws {UserFacingError} UNAUTHENTICATED if the header is missing;
 *   PERMISSION_DENIED if the key does not match the expected string.
 */
export function requireApiKey(headerName, expectedValueOrValidator) {
  const lowerHeaderName = headerName.toLowerCase();
  return async (request) => {
    const apiKey = request.headers[lowerHeaderName];
    if (!apiKey) {
      throw new UserFacingError("UNAUTHENTICATED", `Missing required header: ${headerName}`);
    }
    if (typeof expectedValueOrValidator === "string") {
      const matches = typeof apiKey === "string" && constantTimeEquals(apiKey, expectedValueOrValidator);
      if (!matches) {
        throw new UserFacingError("PERMISSION_DENIED", "Invalid API key");
      }
    } else {
      await expectedValueOrValidator(apiKey);
    }
    return {
      auth: { apiKey },
    };
  };
}
/**
 * Creates a context provider that requires Bearer token authentication.
 *
 * The token is taken from `Authorization: Bearer <token>` (scheme matched
 * case-insensitively), trimmed, and passed to `validateToken`. Whatever the
 * validator returns becomes the action context.
 *
 * @example
 * ```typescript
 * // With custom token validation
 * export const handler = onCallGenkit(
 * {
 * contextProvider: requireBearerToken(async (token) => {
 * const user = await verifyJWT(token);
 * return { auth: { user } };
 * })
 * },
 * myFlow
 * );
 * ```
 *
 * @param {(token: string) => unknown} validateToken - Validates the token and
 *   returns the context; should throw to reject it.
 * @returns {(request: object) => Promise<unknown>}
 * @throws {UserFacingError} UNAUTHENTICATED when the header is missing, is not
 *   a Bearer credential, or carries an empty/whitespace-only token.
 */
export function requireBearerToken(validateToken) {
  return async (request) => {
    const authHeader = request.headers["authorization"];
    if (!authHeader) {
      throw new UserFacingError("UNAUTHENTICATED", "Missing Authorization header");
    }
    const match = authHeader.match(/^Bearer\s+(.+)$/i);
    // `.+` can capture trailing (or only) whitespace, e.g. "Bearer   " yields
    // " "; trim so validators never see padded or blank tokens.
    const token = match ? match[1].trim() : "";
    if (!token) {
      throw new UserFacingError("UNAUTHENTICATED", "Invalid Authorization header format. Expected: Bearer <token>");
    }
    return await validateToken(token);
  };
}
/**
 * Creates a context provider that requires a specific header to be present,
 * and optionally to equal an exact value.
 *
 * @example
 * ```typescript
 * // Require header to exist
 * export const handler = onCallGenkit(
 * { contextProvider: requireHeader('X-Request-ID') },
 * myFlow
 * );
 *
 * // Require header to have specific value
 * export const handler = onCallGenkit(
 * { contextProvider: requireHeader('X-API-Version', '2.0') },
 * myFlow
 * );
 * ```
 *
 * @param {string} headerName - Header to check (looked up lower-cased).
 * @param {string} [expectedValue] - If given, the header must equal this value.
 * @returns {(request: object) => Promise<{}>} Provider resolving to an empty context.
 * @throws {UserFacingError} UNAUTHENTICATED when the header is missing or empty;
 *   PERMISSION_DENIED when it differs from `expectedValue`.
 */
export function requireHeader(headerName, expectedValue) {
  const key = headerName.toLowerCase();
  const mustMatch = expectedValue !== undefined;
  return async (request) => {
    const actual = request.headers[key];
    if (!actual) {
      throw new UserFacingError("UNAUTHENTICATED", `Missing required header: ${headerName}`);
    }
    const mismatched = mustMatch && actual !== expectedValue;
    if (mismatched) {
      throw new UserFacingError("PERMISSION_DENIED", `Invalid value for header: ${headerName}`);
    }
    return {};
  };
}
/**
 * Creates a context provider that accepts every request without any
 * authentication check and contributes an empty context.
 * Useful for public endpoints.
 *
 * @example
 * ```typescript
 * export const handler = onCallGenkit(
 * { contextProvider: allowAll() },
 * myPublicFlow
 * );
 * ```
 *
 * @returns {() => Promise<{}>} Provider that resolves to a fresh `{}` on each call.
 */
export function allowAll() {
  return async function () {
    return {};
  };
}
/**
 * Combines multiple context providers; every provider must succeed.
 *
 * Providers run one after another in the given order, and the first failure
 * propagates (later providers are not called). The resulting context is a
 * shallow merge of all provider results, with later providers overriding
 * earlier keys.
 *
 * @example
 * ```typescript
 * export const handler = onCallGenkit(
 * {
 * contextProvider: allOf(
 * requireHeader('X-Request-ID'),
 * requireApiKey('X-API-Key', process.env.API_KEY!)
 * )
 * },
 * myFlow
 * );
 * ```
 *
 * @param {...Function} providers - Context providers to run in sequence.
 * @returns {(request: object) => Promise<object>} Combined provider.
 */
export function allOf(...providers) {
  return async (request) => {
    let combined = {};
    for (let index = 0; index < providers.length; index += 1) {
      const provider = providers[index];
      const partial = await provider(request);
      combined = { ...combined, ...partial };
    }
    return combined;
  };
}
/**
 * Tries context providers in order and returns the context of the first one
 * that succeeds.
 *
 * Failures from earlier providers are suppressed. If every provider fails,
 * the last provider's error is rethrown (non-Error throwables are wrapped in
 * an `Error`). With no providers at all, an UNAUTHENTICATED error is thrown.
 *
 * @example
 * ```typescript
 * // Accept either API key or Bearer token
 * export const handler = onCallGenkit(
 * {
 * contextProvider: anyOf(
 * requireApiKey('X-API-Key', process.env.API_KEY!),
 * requireBearerToken(async (token) => {
 * const user = await verifyJWT(token);
 * return { auth: { user } };
 * })
 * )
 * },
 * myFlow
 * );
 * ```
 *
 * @param {...Function} providers - Candidate context providers.
 * @returns {(request: object) => Promise<object>} Provider yielding the first success.
 */
export function anyOf(...providers) {
  return async (request) => {
    const failures = [];
    for (const candidate of providers) {
      try {
        return await candidate(request);
      } catch (err) {
        failures.push(err instanceof Error ? err : new Error(String(err)));
      }
    }
    if (failures.length > 0) {
      throw failures[failures.length - 1];
    }
    throw new UserFacingError("UNAUTHENTICATED", "Unauthorized");
  };
}
// Expose onCallGenkit as the module's default export in addition to its named export.
export default onCallGenkit;
//# sourceMappingURL=aws_lambda.js.map