UNPKG

ai-gateway-provider

Version:

AI Gateway Provider for AI-SDK

274 lines (271 loc) 9.39 kB
// src/providers.ts

/**
 * Builds a provider entry whose gateway endpoint is the request URL with the
 * provider's base URL stripped off.
 *
 * @param {string} name - Provider slug sent to AI Gateway.
 * @param {RegExp} baseUrl - Anchored, non-global pattern matching the provider base URL.
 * @param {string} [headerKey] - Header carrying the provider API key when it is not `authorization`.
 * @returns {{name: string, regex: RegExp, transformEndpoint: (url: string) => string, headerKey?: string}}
 */
function baseUrlProvider(name, baseUrl, headerKey) {
  const entry = {
    name,
    regex: baseUrl,
    transformEndpoint: (url) => url.replace(baseUrl, ""),
  };
  if (headerKey !== undefined) {
    entry.headerKey = headerKey;
  }
  return entry;
}

// Azure embeds the resource and deployment in the URL; both become part of the gateway endpoint.
const AZURE_OPENAI_URL =
  /^https:\/\/(?<resource>[^.]+)\.openai\.azure\.com\/openai\/deployments\/(?<deployment>[^/]+)\/(?<rest>.*)$/;

/** Upstream providers AI Gateway can proxy, matched against the URL the provider SDK fetches. */
const providers = [
  baseUrlProvider("openai", /^https:\/\/api\.openai\.com\//),
  baseUrlProvider("deepseek", /^https:\/\/api\.deepseek\.com\//),
  baseUrlProvider("anthropic", /^https:\/\/api\.anthropic\.com\//, "x-api-key"),
  baseUrlProvider(
    "google-ai-studio",
    /^https:\/\/generativelanguage\.googleapis\.com\//,
    "x-goog-api-key",
  ),
  baseUrlProvider("grok", /^https:\/\/api\.x\.ai\//),
  baseUrlProvider("mistral", /^https:\/\/api\.mistral\.ai\//),
  baseUrlProvider("perplexity-ai", /^https:\/\/api\.perplexity\.ai\//),
  baseUrlProvider("replicate", /^https:\/\/api\.replicate\.com\//),
  baseUrlProvider("groq", /^https:\/\/api\.groq\.com\/openai\/v1\//),
  {
    name: "azure-openai",
    regex: AZURE_OPENAI_URL,
    transformEndpoint: (url) => {
      const match = url.match(AZURE_OPENAI_URL);
      if (!match || !match.groups) return url;
      const { resource, deployment, rest } = match.groups;
      if (!resource || !deployment || !rest) {
        throw new Error("Failed to parse Azure OpenAI endpoint URL.");
      }
      return `${resource}/${deployment}/${rest}`;
    },
    headerKey: "api-key",
  },
];

// src/auth.ts

// Placeholder API key; an auth header containing it is stripped before forwarding.
const CF_TEMP_TOKEN = "CF_TEMP_TOKEN";

// src/index.ts

/** Thrown by the capturing fetch to abort a provider call once its request is recorded. */
class AiGatewayInternalFetchError extends Error {
  constructor(...args) {
    super(...args);
    this.name = "AiGatewayInternalFetchError";
  }
}

/** The configured gateway does not exist (gateway error code 2001). */
class AiGatewayDoesNotExist extends Error {
  constructor(...args) {
    super(...args);
    this.name = "AiGatewayDoesNotExist";
  }
}

/** The gateway requires authentication and the apiKey was missing or invalid (code 2009). */
class AiGatewayUnauthorizedError extends Error {
  constructor(...args) {
    super(...args);
    this.name = "AiGatewayUnauthorizedError";
  }
}

/** Reads a request body (string, stream, buffer, ...) and parses it as JSON. */
async function streamToObject(stream) {
  const response = new Response(stream);
  return await response.json();
}

/**
 * Copies request headers into a `Headers` instance. Lookups and deletions are
 * then case-insensitive, whichever shape the provider SDK used: a `Headers`
 * instance, an array of `[name, value]` pairs, or a plain record.
 * Entries with a `null`/`undefined` value are skipped.
 */
function toHeaders(init) {
  const headers = new Headers();
  if (init == null) return headers;
  let entries;
  if (Array.isArray(init)) {
    entries = init;
  } else if (typeof init.entries === "function") {
    entries = init.entries();
  } else {
    entries = Object.entries(init);
  }
  for (const [key, value] of entries) {
    if (value !== undefined && value !== null) {
      headers.append(key, String(value));
    }
  }
  return headers;
}

/**
 * Converts one captured provider fetch into an AI Gateway universal-endpoint step.
 *
 * @param {{modelProvider: string, request: RequestInit | undefined, url: string | URL}} captured
 * @returns {Promise<{endpoint: string, headers: Record<string, string>, provider: string, query: unknown}>}
 * @throws {Error} when the URL matches no known provider or the request has no body.
 */
async function toGatewayStep({ modelProvider, request, url }) {
  const urlString = String(url);
  const providerConfig = providers.find((provider) => provider.regex.test(urlString));
  if (!providerConfig) {
    throw new Error(
      `Sorry, but provider "${modelProvider}" is currently not supported, please open a issue in the github repo!`
    );
  }
  if (!request?.body) {
    throw new Error("Ai Gateway provider received an unexpected empty body");
  }
  const headers = toHeaders(request.headers);
  const authHeader = providerConfig.headerKey ?? "authorization";
  // Only the placeholder key is removed; real provider keys are forwarded unchanged.
  if (headers.get(authHeader)?.includes(CF_TEMP_TOKEN)) {
    headers.delete(authHeader);
  }
  return {
    endpoint: providerConfig.transformEndpoint(urlString),
    // A plain record: Headers instances serialize to `{}` and do not spread.
    headers: Object.fromEntries(headers.entries()),
    provider: providerConfig.name,
    query: await streamToObject(request.body),
  };
}

/**
 * Best-effort read of the first gateway error code from an error response.
 * Returns `undefined` when the body is not the expected JSON shape. The
 * response is then handed to the provider model, which reports it itself.
 */
async function readGatewayErrorCode(resp) {
  try {
    const result = await resp.clone().json();
    if (result?.success === false && Array.isArray(result.error)) {
      return result.error[0]?.code;
    }
  } catch {
    // Non-JSON error body: deliberately ignored, see doc comment.
  }
  return undefined;
}

/**
 * Posts the gateway steps, either through a Workers AI binding
 * (`config.binding.run`) or over HTTPS to the universal endpoint.
 */
async function sendToGateway(config, body) {
  const headers = parseAiGatewayOptions(config.options ?? {});
  if ("binding" in config) {
    const gatewayHeaders = Object.fromEntries(headers.entries());
    const steps = body.map((step) => ({
      ...step,
      headers: { ...(step.headers ?? {}), ...gatewayHeaders },
    }));
    return await config.binding.run(steps);
  }
  headers.set("Content-Type", "application/json");
  if (config.apiKey) {
    headers.set("cf-aig-authorization", `Bearer ${config.apiKey}`);
  }
  return await fetch(
    `https://gateway.ai.cloudflare.com/v1/${config.accountId}/${config.gateway}`,
    { body: JSON.stringify(body), headers, method: "POST" }
  );
}

/** Returns the first model, which defines the public modelId/provider. */
function primaryModel(models) {
  const [first] = models;
  if (!first) {
    throw new Error("models cannot be empty array");
  }
  return first;
}

/**
 * AI SDK (specification v2) language model that routes one or more provider
 * models through Cloudflare AI Gateway, using them as an ordered fallback chain.
 */
class AiGatewayChatLanguageModel {
  /**
   * @param {Array<object>} models - Provider models; each must expose a `config` with a `fetch` key.
   * @param {object} config - Either `{ binding, options? }` or `{ accountId, gateway, apiKey?, options? }`.
   */
  constructor(models, config) {
    this.specificationVersion = "v2";
    this.defaultObjectGenerationMode = "json";
    // No URLs are supported for this language model.
    this.supportedUrls = {};
    this.models = models;
    this.config = config;
  }

  get modelId() {
    return primaryModel(this.models).modelId;
  }

  get provider() {
    return primaryModel(this.models).provider;
  }

  /**
   * Runs `modelMethod` against every model with a capturing fetch to record the
   * HTTP request each would make. It then sends all of them to AI Gateway in
   * one call. Finally it replays the gateway response through the model chosen
   * by the `cf-aig-step` response header.
   *
   * @param {object} options - Call options forwarded to the provider models.
   * @param {"doGenerate" | "doStream"} modelMethod
   * @throws {AiGatewayDoesNotExist | AiGatewayUnauthorizedError | Error}
   */
  async processModelRequest(options, modelMethod) {
    primaryModel(this.models); // fail fast before any network traffic

    const captured = [];
    for (const model of this.models) {
      if (!model.config || !Object.keys(model.config).includes("fetch")) {
        throw new Error(
          `Sorry, but provider "${model.provider}" is currently not supported, please open a issue in the github repo!`
        );
      }
      model.config.fetch = (url, request) => {
        captured.push({ modelProvider: model.provider, request, url });
        throw new AiGatewayInternalFetchError("Stopping provider execution...");
      };
      try {
        await model[modelMethod](options);
      } catch (e) {
        if (!(e instanceof AiGatewayInternalFetchError)) {
          throw e;
        }
      }
    }

    const body = await Promise.all(captured.map((req) => toGatewayStep(req)));
    const resp = await sendToGateway(this.config, body);

    if (resp.status === 400 && (await readGatewayErrorCode(resp)) === 2001) {
      throw new AiGatewayDoesNotExist("This AI gateway does not exist");
    }
    if (resp.status === 401 && (await readGatewayErrorCode(resp)) === 2009) {
      throw new AiGatewayUnauthorizedError(
        "Your AI Gateway has authentication active, but you didn't provide a valid apiKey"
      );
    }

    // The gateway reports which fallback step produced the response.
    const step = Number.parseInt(resp.headers.get("cf-aig-step") ?? "0", 10);
    const model = this.models[step];
    if (!model) {
      throw new Error("Unexpected AI Gateway Error");
    }
    model.config = { ...model.config, fetch: (_url, _req) => resp };
    return model[modelMethod](options);
  }

  async doStream(options) {
    return this.processModelRequest(options, "doStream");
  }

  async doGenerate(options) {
    return this.processModelRequest(options, "doGenerate");
  }
}

/**
 * Creates a provider function that wraps a model (or an ordered array of
 * fallback models) in an AiGatewayChatLanguageModel. `provider.chat` is an alias.
 */
function createAiGateway(options) {
  const createChatModel = (models) =>
    new AiGatewayChatLanguageModel(Array.isArray(models) ? models : [models], options);
  const provider = (models) => createChatModel(models);
  provider.chat = createChatModel;
  return provider;
}

/**
 * Maps AI Gateway options to the corresponding `cf-*` request headers.
 * Note: `cacheTtl` and `cacheKey` are sent only when truthy.
 *
 * @param {object} options
 * @returns {Headers}
 */
function parseAiGatewayOptions(options) {
  const headers = new Headers();
  if (options.skipCache === true) {
    headers.set("cf-skip-cache", "true");
  }
  if (options.cacheTtl) {
    headers.set("cf-cache-ttl", options.cacheTtl.toString());
  }
  if (options.metadata) {
    headers.set("cf-aig-metadata", JSON.stringify(options.metadata));
  }
  if (options.cacheKey) {
    headers.set("cf-aig-cache-key", options.cacheKey);
  }
  if (options.collectLog !== void 0) {
    headers.set("cf-aig-collect-log", options.collectLog === true ? "true" : "false");
  }
  if (options.eventId !== void 0) {
    headers.set("cf-aig-event-id", options.eventId);
  }
  if (options.requestTimeoutMs !== void 0) {
    headers.set("cf-aig-request-timeout", options.requestTimeoutMs.toString());
  }
  if (options.retries !== void 0) {
    const { maxAttempts, retryDelayMs, backoff } = options.retries;
    if (maxAttempts !== void 0) {
      headers.set("cf-aig-max-attempts", maxAttempts.toString());
    }
    if (retryDelayMs !== void 0) {
      headers.set("cf-aig-retry-delay", retryDelayMs.toString());
    }
    if (backoff !== void 0) {
      headers.set("cf-aig-backoff", backoff);
    }
  }
  return headers;
}

export {
  AiGatewayChatLanguageModel,
  AiGatewayDoesNotExist,
  AiGatewayInternalFetchError,
  AiGatewayUnauthorizedError,
  createAiGateway,
  parseAiGatewayOptions,
};