@ai-sdk/gateway
Version:
The Gateway provider for the [AI SDK](https://ai-sdk.dev/docs) allows the use of a wide variety of AI models and providers.
789 lines (763 loc) • 24.5 kB
JavaScript
// src/gateway-provider.ts
import { NoSuchModelError } from "@ai-sdk/provider";
import {
loadOptionalSetting,
withoutTrailingSlash
} from "@ai-sdk/provider-utils";
// src/errors/as-gateway-error.ts
import { APICallError } from "@ai-sdk/provider";
// src/errors/create-gateway-error.ts
import { z as z2 } from "zod/v4";
// src/errors/gateway-error.ts
var marker = "vercel.ai.gateway.error";
var symbol = Symbol.for(marker);
var _a, _b;
var GatewayError = class _GatewayError extends (_b = Error, _a = symbol, _b) {
constructor({
message,
statusCode = 500,
cause
}) {
super(message);
this[_a] = true;
this.statusCode = statusCode;
this.cause = cause;
}
/**
* Checks if the given error is a Gateway Error.
* @param {unknown} error - The error to check.
* @returns {boolean} True if the error is a Gateway Error, false otherwise.
*/
static isInstance(error) {
return _GatewayError.hasMarker(error);
}
static hasMarker(error) {
return typeof error === "object" && error !== null && symbol in error && error[symbol] === true;
}
};
// src/errors/gateway-authentication-error.ts
var name = "GatewayAuthenticationError";
var marker2 = `vercel.ai.gateway.error.${name}`;
var symbol2 = Symbol.for(marker2);
/**
 * Gateway error raised when authentication fails (`type` is "authentication_error").
 */
var GatewayAuthenticationError = class _GatewayAuthenticationError extends GatewayError {
  /**
   * @param {object} [options]
   * @param {string} [options.message="Authentication failed"]
   * @param {number} [options.statusCode=401]
   * @param {unknown} [options.cause]
   */
  constructor(options = {}) {
    const { message = "Authentication failed", statusCode = 401, cause } = options;
    super({ message, statusCode, cause });
    // Class-specific brand; looked up by isInstance.
    this[symbol2] = true;
    this.name = name;
    this.type = "authentication_error";
  }
  /**
   * @param {unknown} error - The value to inspect.
   * @returns {boolean} True when the value carries both the base Gateway brand and this class's brand.
   */
  static isInstance(error) {
    if (!GatewayError.hasMarker(error)) {
      return false;
    }
    return symbol2 in error;
  }
  /**
   * Creates a contextual error message when authentication fails.
   * The text depends on which credential kind was supplied; the API key case
   * is checked first, then the OIDC case, then the "nothing provided" case.
   * @param {object} options
   * @param {boolean} options.apiKeyProvided - Whether an API key was sent.
   * @param {boolean} options.oidcTokenProvided - Whether an OIDC token was sent.
   * @param {string} [options.message] - Accepted in the signature; the generated text does not use it.
   * @param {number} [options.statusCode=401]
   * @param {unknown} [options.cause]
   * @returns {GatewayAuthenticationError} Error carrying the contextual message.
   */
  static createContextualError({
    apiKeyProvided,
    oidcTokenProvided,
    message = "Authentication failed",
    statusCode = 401,
    cause
  }) {
    const apiKeyHint = "The token is expected to be provided via the 'apiKey' option or 'AI_GATEWAY_API_KEY' environment variable.";
    const oidcHints = [
      "The token is expected to be provided via the 'VERCEL_OIDC_TOKEN' environment variable. It expires every 12 hours.",
      "- make sure your Vercel project settings have OIDC enabled",
      "- if running locally with 'vercel dev', the token is automatically obtained and refreshed",
      "- if running locally with your own dev server, run 'vercel env pull' to fetch the token",
      "- in production/preview, the token is automatically obtained and refreshed"
    ];
    let lines;
    if (apiKeyProvided) {
      lines = [
        "AI Gateway authentication failed: Invalid API key provided.",
        apiKeyHint
      ];
    } else if (oidcTokenProvided) {
      lines = [
        "AI Gateway authentication failed: Invalid OIDC token provided.",
        ...oidcHints,
        "Alternative: Provide an API key via 'apiKey' option or 'AI_GATEWAY_API_KEY' environment variable."
      ];
    } else {
      lines = [
        "AI Gateway authentication failed: No authentication provided.",
        "Provide either an API key or OIDC token.",
        "API key instructions:",
        apiKeyHint,
        "OIDC token instructions:",
        ...oidcHints
      ];
    }
    return new _GatewayAuthenticationError({
      message: lines.join("\n"),
      statusCode,
      cause
    });
  }
};
// src/errors/gateway-invalid-request-error.ts
var name2 = "GatewayInvalidRequestError";
var marker3 = `vercel.ai.gateway.error.${name2}`;
var symbol3 = Symbol.for(marker3);
/**
 * Gateway error for requests the Gateway rejects as invalid
 * (`type` is "invalid_request_error").
 */
var GatewayInvalidRequestError = class extends GatewayError {
  /**
   * @param {object} [options]
   * @param {string} [options.message="Invalid request"]
   * @param {number} [options.statusCode=400]
   * @param {unknown} [options.cause]
   */
  constructor(options = {}) {
    const { message = "Invalid request", statusCode = 400, cause } = options;
    super({ message, statusCode, cause });
    // Class-specific brand; looked up by isInstance.
    this[symbol3] = true;
    this.name = name2;
    this.type = "invalid_request_error";
  }
  /**
   * @param {unknown} error - The value to inspect.
   * @returns {boolean} True when the value carries both the base Gateway brand and this class's brand.
   */
  static isInstance(error) {
    if (!GatewayError.hasMarker(error)) {
      return false;
    }
    return symbol3 in error;
  }
};
// src/errors/gateway-rate-limit-error.ts
var name3 = "GatewayRateLimitError";
var marker4 = `vercel.ai.gateway.error.${name3}`;
var symbol4 = Symbol.for(marker4);
/**
 * Gateway error for rate-limited requests (`type` is "rate_limit_exceeded").
 */
var GatewayRateLimitError = class extends GatewayError {
  /**
   * @param {object} [options]
   * @param {string} [options.message="Rate limit exceeded"]
   * @param {number} [options.statusCode=429]
   * @param {unknown} [options.cause]
   */
  constructor(options = {}) {
    const { message = "Rate limit exceeded", statusCode = 429, cause } = options;
    super({ message, statusCode, cause });
    // Class-specific brand; looked up by isInstance.
    this[symbol4] = true;
    this.name = name3;
    this.type = "rate_limit_exceeded";
  }
  /**
   * @param {unknown} error - The value to inspect.
   * @returns {boolean} True when the value carries both the base Gateway brand and this class's brand.
   */
  static isInstance(error) {
    if (!GatewayError.hasMarker(error)) {
      return false;
    }
    return symbol4 in error;
  }
};
// src/errors/gateway-model-not-found-error.ts
import { z } from "zod/v4";
var name4 = "GatewayModelNotFoundError";
var marker5 = `vercel.ai.gateway.error.${name4}`;
var symbol5 = Symbol.for(marker5);
// Shape expected for the `param` payload of a "model_not_found" error response.
var modelNotFoundParamSchema = z.object({ modelId: z.string() });
/**
 * Gateway error for an unknown model (`type` is "model_not_found").
 * Stores the offending model id on `modelId` when one is supplied.
 */
var GatewayModelNotFoundError = class extends GatewayError {
  /**
   * @param {object} [options]
   * @param {string} [options.message="Model not found"]
   * @param {number} [options.statusCode=404]
   * @param {string} [options.modelId] - Id of the model that was not found.
   * @param {unknown} [options.cause]
   */
  constructor(options = {}) {
    const { message = "Model not found", statusCode = 404, modelId, cause } = options;
    super({ message, statusCode, cause });
    // Class-specific brand; looked up by isInstance.
    this[symbol5] = true;
    this.name = name4;
    this.type = "model_not_found";
    this.modelId = modelId;
  }
  /**
   * @param {unknown} error - The value to inspect.
   * @returns {boolean} True when the value carries both the base Gateway brand and this class's brand.
   */
  static isInstance(error) {
    if (!GatewayError.hasMarker(error)) {
      return false;
    }
    return symbol5 in error;
  }
};
// src/errors/gateway-internal-server-error.ts
var name5 = "GatewayInternalServerError";
var marker6 = `vercel.ai.gateway.error.${name5}`;
var symbol6 = Symbol.for(marker6);
/**
 * Gateway error for server-side failures (`type` is "internal_server_error").
 */
var GatewayInternalServerError = class extends GatewayError {
  /**
   * @param {object} [options]
   * @param {string} [options.message="Internal server error"]
   * @param {number} [options.statusCode=500]
   * @param {unknown} [options.cause]
   */
  constructor(options = {}) {
    const { message = "Internal server error", statusCode = 500, cause } = options;
    super({ message, statusCode, cause });
    // Class-specific brand; looked up by isInstance.
    this[symbol6] = true;
    this.name = name5;
    this.type = "internal_server_error";
  }
  /**
   * @param {unknown} error - The value to inspect.
   * @returns {boolean} True when the value carries both the base Gateway brand and this class's brand.
   */
  static isInstance(error) {
    if (!GatewayError.hasMarker(error)) {
      return false;
    }
    return symbol6 in error;
  }
};
// src/errors/gateway-response-error.ts
var name6 = "GatewayResponseError";
var marker7 = `vercel.ai.gateway.error.${name6}`;
var symbol7 = Symbol.for(marker7);
/**
 * Gateway error for a response that could not be interpreted
 * (`type` is "response_error"). Keeps the raw response and the validation
 * failure on the instance for inspection.
 */
var GatewayResponseError = class extends GatewayError {
  /**
   * @param {object} [options]
   * @param {string} [options.message="Invalid response from Gateway"]
   * @param {number} [options.statusCode=502]
   * @param {unknown} [options.response] - The response payload that failed validation.
   * @param {unknown} [options.validationError] - The validation failure describing why.
   * @param {unknown} [options.cause]
   */
  constructor(options = {}) {
    const {
      message = "Invalid response from Gateway",
      statusCode = 502,
      response,
      validationError,
      cause
    } = options;
    super({ message, statusCode, cause });
    // Class-specific brand; looked up by isInstance.
    this[symbol7] = true;
    this.name = name6;
    this.type = "response_error";
    this.response = response;
    this.validationError = validationError;
  }
  /**
   * @param {unknown} error - The value to inspect.
   * @returns {boolean} True when the value carries both the base Gateway brand and this class's brand.
   */
  static isInstance(error) {
    if (!GatewayError.hasMarker(error)) {
      return false;
    }
    return symbol7 in error;
  }
};
// src/errors/create-gateway-error.ts
function createGatewayErrorFromResponse({
response,
statusCode,
defaultMessage = "Gateway request failed",
cause,
authMethod
}) {
const parseResult = gatewayErrorResponseSchema.safeParse(response);
if (!parseResult.success) {
return new GatewayResponseError({
message: `Invalid error response format: ${defaultMessage}`,
statusCode,
response,
validationError: parseResult.error,
cause
});
}
const validatedResponse = parseResult.data;
const errorType = validatedResponse.error.type;
const message = validatedResponse.error.message;
switch (errorType) {
case "authentication_error":
return GatewayAuthenticationError.createContextualError({
apiKeyProvided: authMethod === "api-key",
oidcTokenProvided: authMethod === "oidc",
statusCode,
cause
});
case "invalid_request_error":
return new GatewayInvalidRequestError({ message, statusCode, cause });
case "rate_limit_exceeded":
return new GatewayRateLimitError({ message, statusCode, cause });
case "model_not_found": {
const modelResult = modelNotFoundParamSchema.safeParse(
validatedResponse.error.param
);
return new GatewayModelNotFoundError({
message,
statusCode,
modelId: modelResult.success ? modelResult.data.modelId : void 0,
cause
});
}
case "internal_server_error":
return new GatewayInternalServerError({ message, statusCode, cause });
default:
return new GatewayInternalServerError({ message, statusCode, cause });
}
}
// Validation schema for the `error` member of a Gateway error response:
// `message` is required, while `type`, `param` and `code` may be missing or null.
var gatewayErrorDetailSchema = z2.object({
  message: z2.string(),
  type: z2.string().nullish(),
  param: z2.unknown().nullish(),
  code: z2.union([z2.string(), z2.number()]).nullish()
});
// Validation schema for a complete Gateway error response body.
var gatewayErrorResponseSchema = z2.object({
  error: gatewayErrorDetailSchema
});
// src/errors/as-gateway-error.ts
function asGatewayError(error, authMethod) {
var _a8;
if (GatewayError.isInstance(error)) {
return error;
}
if (APICallError.isInstance(error)) {
return createGatewayErrorFromResponse({
response: extractApiCallResponse(error),
statusCode: (_a8 = error.statusCode) != null ? _a8 : 500,
defaultMessage: "Gateway request failed",
cause: error,
authMethod
});
}
return createGatewayErrorFromResponse({
response: {},
statusCode: 500,
defaultMessage: error instanceof Error ? `Gateway request failed: ${error.message}` : "Unknown Gateway error",
cause: error,
authMethod
});
}
// src/errors/extract-api-call-response.ts
function extractApiCallResponse(error) {
if (error.data !== void 0) {
return error.data;
}
if (error.responseBody != null) {
try {
return JSON.parse(error.responseBody);
} catch (e) {
return error.responseBody;
}
}
return {};
}
// src/errors/parse-auth-method.ts
import { z as z3 } from "zod/v4";
// Name of the header that records which credential kind was used for a request.
var GATEWAY_AUTH_METHOD_HEADER = "ai-gateway-auth-method";
// The only accepted values for that header.
var gatewayAuthMethodSchema = z3.union([
  z3.literal("api-key"),
  z3.literal("oidc")
]);
/**
 * Reads the auth-method header from a headers object.
 * @param {Record<string, unknown>} headers - Headers to inspect (looked up by the exact lowercase key).
 * @returns {"api-key" | "oidc" | undefined} The auth method, or undefined when the
 *   header is absent or holds any other value.
 */
function parseAuthMethod(headers) {
  const headerValue = headers[GATEWAY_AUTH_METHOD_HEADER];
  const parsed = gatewayAuthMethodSchema.safeParse(headerValue);
  if (!parsed.success) {
    return void 0;
  }
  return parsed.data;
}
// src/gateway-fetch-metadata.ts
import {
createJsonErrorResponseHandler,
createJsonResponseHandler,
getFromApi,
resolve
} from "@ai-sdk/provider-utils";
import { z as z4 } from "zod/v4";
var GatewayFetchMetadata = class {
constructor(config) {
this.config = config;
}
async getAvailableModels() {
try {
const { value } = await getFromApi({
url: `${this.config.baseURL}/config`,
headers: await resolve(this.config.headers()),
successfulResponseHandler: createJsonResponseHandler(
gatewayFetchMetadataSchema
),
failedResponseHandler: createJsonErrorResponseHandler({
errorSchema: z4.any(),
errorToMessage: (data) => data
}),
fetch: this.config.fetch
});
return value;
} catch (error) {
throw asGatewayError(error);
}
}
};
// Pricing of a model entry: both amounts are transported as strings.
var gatewayLanguageModelPricingSchema = z4.object({
  input: z4.string(),
  output: z4.string()
});
// Identifies the model implementation behind an entry; only specification
// version "v2" is accepted.
var gatewayLanguageModelSpecificationSchema = z4.object({
  specificationVersion: z4.literal("v2"),
  provider: z4.string(),
  modelId: z4.string()
});
// One model entry of the metadata response; `description` and `pricing`
// may be missing or null.
var gatewayLanguageModelEntrySchema = z4.object({
  id: z4.string(),
  name: z4.string(),
  description: z4.string().nullish(),
  pricing: gatewayLanguageModelPricingSchema.nullish(),
  specification: gatewayLanguageModelSpecificationSchema
});
// Full body of the metadata response: a list of model entries.
var gatewayFetchMetadataSchema = z4.object({
  models: z4.array(gatewayLanguageModelEntrySchema)
});
// src/gateway-language-model.ts
import {
combineHeaders,
createEventSourceResponseHandler,
createJsonErrorResponseHandler as createJsonErrorResponseHandler2,
createJsonResponseHandler as createJsonResponseHandler2,
postJsonToApi,
resolve as resolve2
} from "@ai-sdk/provider-utils";
import { z as z5 } from "zod/v4";
var GatewayLanguageModel = class {
constructor(modelId, config) {
this.modelId = modelId;
this.config = config;
this.specificationVersion = "v2";
this.supportedUrls = { "*/*": [/.*/] };
}
get provider() {
return this.config.provider;
}
async getArgs(options) {
const { abortSignal: _abortSignal, ...optionsWithoutSignal } = options;
return {
args: this.maybeEncodeFileParts(optionsWithoutSignal),
warnings: []
};
}
async doGenerate(options) {
const { args, warnings } = await this.getArgs(options);
const { abortSignal } = options;
const resolvedHeaders = await resolve2(this.config.headers());
try {
const {
responseHeaders,
value: responseBody,
rawValue: rawResponse
} = await postJsonToApi({
url: this.getUrl(),
headers: combineHeaders(
resolvedHeaders,
options.headers,
this.getModelConfigHeaders(this.modelId, false),
await resolve2(this.config.o11yHeaders)
),
body: args,
successfulResponseHandler: createJsonResponseHandler2(z5.any()),
failedResponseHandler: createJsonErrorResponseHandler2({
errorSchema: z5.any(),
errorToMessage: (data) => data
}),
...abortSignal && { abortSignal },
fetch: this.config.fetch
});
return {
...responseBody,
request: { body: args },
response: { headers: responseHeaders, body: rawResponse },
warnings
};
} catch (error) {
throw asGatewayError(error, parseAuthMethod(resolvedHeaders));
}
}
async doStream(options) {
const { args, warnings } = await this.getArgs(options);
const { abortSignal } = options;
const resolvedHeaders = await resolve2(this.config.headers());
try {
const { value: response, responseHeaders } = await postJsonToApi({
url: this.getUrl(),
headers: combineHeaders(
resolvedHeaders,
options.headers,
this.getModelConfigHeaders(this.modelId, true),
await resolve2(this.config.o11yHeaders)
),
body: args,
successfulResponseHandler: createEventSourceResponseHandler(z5.any()),
failedResponseHandler: createJsonErrorResponseHandler2({
errorSchema: z5.any(),
errorToMessage: (data) => data
}),
...abortSignal && { abortSignal },
fetch: this.config.fetch
});
return {
stream: response.pipeThrough(
new TransformStream({
start(controller) {
if (warnings.length > 0) {
controller.enqueue({ type: "stream-start", warnings });
}
},
transform(chunk, controller) {
if (chunk.success) {
const streamPart = chunk.value;
if (streamPart.type === "raw" && !options.includeRawChunks) {
return;
}
if (streamPart.type === "response-metadata" && streamPart.timestamp && typeof streamPart.timestamp === "string") {
streamPart.timestamp = new Date(streamPart.timestamp);
}
controller.enqueue(streamPart);
} else {
controller.error(
chunk.error
);
}
}
})
),
request: { body: args },
response: { headers: responseHeaders }
};
} catch (error) {
throw asGatewayError(error, parseAuthMethod(resolvedHeaders));
}
}
isFilePart(part) {
return part && typeof part === "object" && "type" in part && part.type === "file";
}
/**
* Encodes file parts in the prompt to base64. Mutates the passed options
* instance directly to avoid copying the file data.
* @param options - The options to encode.
* @returns The options with the file parts encoded.
*/
maybeEncodeFileParts(options) {
for (const message of options.prompt) {
for (const part of message.content) {
if (this.isFilePart(part)) {
const filePart = part;
if (filePart.data instanceof Uint8Array) {
const buffer = Uint8Array.from(filePart.data);
const base64Data = Buffer.from(buffer).toString("base64");
filePart.data = new URL(
`data:${filePart.mediaType || "application/octet-stream"};base64,${base64Data}`
);
}
}
}
}
return options;
}
getUrl() {
return `${this.config.baseURL}/language-model`;
}
getModelConfigHeaders(modelId, streaming) {
return {
"ai-language-model-specification-version": "2",
"ai-language-model-id": modelId,
"ai-language-model-streaming": String(streaming)
};
}
};
// src/gateway-embedding-model.ts
import {
combineHeaders as combineHeaders2,
createJsonResponseHandler as createJsonResponseHandler3,
createJsonErrorResponseHandler as createJsonErrorResponseHandler3,
postJsonToApi as postJsonToApi2,
resolve as resolve3
} from "@ai-sdk/provider-utils";
import { z as z6 } from "zod/v4";
var GatewayEmbeddingModel = class {
constructor(modelId, config) {
this.modelId = modelId;
this.config = config;
this.specificationVersion = "v2";
this.maxEmbeddingsPerCall = 2048;
this.supportsParallelCalls = true;
}
get provider() {
return this.config.provider;
}
async doEmbed({
values,
headers,
abortSignal,
providerOptions
}) {
var _a8;
const resolvedHeaders = await resolve3(this.config.headers());
try {
const {
responseHeaders,
value: responseBody,
rawValue
} = await postJsonToApi2({
url: this.getUrl(),
headers: combineHeaders2(
resolvedHeaders,
headers != null ? headers : {},
this.getModelConfigHeaders(),
await resolve3(this.config.o11yHeaders)
),
body: {
input: values.length === 1 ? values[0] : values,
...providerOptions != null ? providerOptions : {}
},
successfulResponseHandler: createJsonResponseHandler3(
gatewayEmbeddingResponseSchema
),
failedResponseHandler: createJsonErrorResponseHandler3({
errorSchema: z6.any(),
errorToMessage: (data) => data
}),
...abortSignal && { abortSignal },
fetch: this.config.fetch
});
return {
embeddings: responseBody.embeddings,
usage: (_a8 = responseBody.usage) != null ? _a8 : void 0,
providerMetadata: responseBody.providerMetadata,
response: { headers: responseHeaders, body: rawValue }
};
} catch (error) {
throw asGatewayError(error, parseAuthMethod(resolvedHeaders));
}
}
getUrl() {
return `${this.config.baseURL}/embedding-model`;
}
getModelConfigHeaders() {
return {
"ai-embedding-model-specification-version": "2",
"ai-model-id": this.modelId
};
}
};
// A single embedding is a list of numbers.
var gatewayEmbeddingVectorSchema = z6.array(z6.number());
// Validation schema for the embedding endpoint response: `usage` may be
// missing or null, `providerMetadata` may be missing.
var gatewayEmbeddingResponseSchema = z6.object({
  embeddings: z6.array(gatewayEmbeddingVectorSchema),
  usage: z6.object({ tokens: z6.number() }).nullish(),
  providerMetadata: z6.record(z6.string(), z6.record(z6.string(), z6.unknown())).optional()
});
// src/vercel-environment.ts
async function getVercelOidcToken() {
var _a8, _b8;
const token = (_b8 = (_a8 = getContext().headers) == null ? void 0 : _a8["x-vercel-oidc-token"]) != null ? _b8 : process.env.VERCEL_OIDC_TOKEN;
if (!token) {
throw new GatewayAuthenticationError({
message: "OIDC token not available",
statusCode: 401
});
}
return token;
}
async function getVercelRequestId() {
var _a8;
return (_a8 = getContext().headers) == null ? void 0 : _a8["x-vercel-id"];
}
var SYMBOL_FOR_REQ_CONTEXT = Symbol.for("@vercel/request-context");
function getContext() {
var _a8, _b8, _c;
const fromSymbol = globalThis;
return (_c = (_b8 = (_a8 = fromSymbol[SYMBOL_FOR_REQ_CONTEXT]) == null ? void 0 : _a8.get) == null ? void 0 : _b8.call(_a8)) != null ? _c : {};
}
// src/gateway-provider.ts
var AI_GATEWAY_PROTOCOL_VERSION = "0.0.1";
function createGatewayProvider(options = {}) {
var _a8, _b8;
let pendingMetadata = null;
let metadataCache = null;
const cacheRefreshMillis = (_a8 = options.metadataCacheRefreshMillis) != null ? _a8 : 1e3 * 60 * 5;
let lastFetchTime = 0;
const baseURL = (_b8 = withoutTrailingSlash(options.baseURL)) != null ? _b8 : "https://ai-gateway.vercel.sh/v1/ai";
const getHeaders = async () => {
const auth = await getGatewayAuthToken(options);
if (auth) {
return {
Authorization: `Bearer ${auth.token}`,
"ai-gateway-protocol-version": AI_GATEWAY_PROTOCOL_VERSION,
[GATEWAY_AUTH_METHOD_HEADER]: auth.authMethod,
...options.headers
};
}
throw GatewayAuthenticationError.createContextualError({
apiKeyProvided: false,
oidcTokenProvided: false,
statusCode: 401
});
};
const createO11yHeaders = () => {
const deploymentId = loadOptionalSetting({
settingValue: void 0,
environmentVariableName: "VERCEL_DEPLOYMENT_ID"
});
const environment = loadOptionalSetting({
settingValue: void 0,
environmentVariableName: "VERCEL_ENV"
});
const region = loadOptionalSetting({
settingValue: void 0,
environmentVariableName: "VERCEL_REGION"
});
return async () => {
const requestId = await getVercelRequestId();
return {
...deploymentId && { "ai-o11y-deployment-id": deploymentId },
...environment && { "ai-o11y-environment": environment },
...region && { "ai-o11y-region": region },
...requestId && { "ai-o11y-request-id": requestId }
};
};
};
const createLanguageModel = (modelId) => {
return new GatewayLanguageModel(modelId, {
provider: "gateway",
baseURL,
headers: getHeaders,
fetch: options.fetch,
o11yHeaders: createO11yHeaders()
});
};
const getAvailableModels = async () => {
var _a9, _b9, _c;
const now = (_c = (_b9 = (_a9 = options._internal) == null ? void 0 : _a9.currentDate) == null ? void 0 : _b9.call(_a9).getTime()) != null ? _c : Date.now();
if (!pendingMetadata || now - lastFetchTime > cacheRefreshMillis) {
lastFetchTime = now;
pendingMetadata = new GatewayFetchMetadata({
baseURL,
headers: getHeaders,
fetch: options.fetch
}).getAvailableModels().then((metadata) => {
metadataCache = metadata;
return metadata;
}).catch(async (error) => {
throw asGatewayError(error, parseAuthMethod(await getHeaders()));
});
}
return metadataCache ? Promise.resolve(metadataCache) : pendingMetadata;
};
const provider = function(modelId) {
if (new.target) {
throw new Error(
"The Gateway Provider model function cannot be called with the new keyword."
);
}
return createLanguageModel(modelId);
};
provider.getAvailableModels = getAvailableModels;
provider.imageModel = (modelId) => {
throw new NoSuchModelError({ modelId, modelType: "imageModel" });
};
provider.languageModel = createLanguageModel;
provider.textEmbeddingModel = (modelId) => {
return new GatewayEmbeddingModel(modelId, {
provider: "gateway",
baseURL,
headers: getHeaders,
fetch: options.fetch,
o11yHeaders: createO11yHeaders()
});
};
return provider;
}
// Default provider instance, created with no options at module load.
var gateway = createGatewayProvider();
async function getGatewayAuthToken(options) {
const apiKey = loadOptionalSetting({
settingValue: options.apiKey,
environmentVariableName: "AI_GATEWAY_API_KEY"
});
if (apiKey) {
return {
token: apiKey,
authMethod: "api-key"
};
}
try {
const oidcToken = await getVercelOidcToken();
return {
token: oidcToken,
authMethod: "oidc"
};
} catch (e) {
return null;
}
}
export {
GatewayAuthenticationError,
GatewayError,
GatewayInternalServerError,
GatewayInvalidRequestError,
GatewayModelNotFoundError,
GatewayRateLimitError,
GatewayResponseError,
createGatewayProvider as createGateway,
createGatewayProvider,
gateway
};
//# sourceMappingURL=index.mjs.map