UNPKG

@ai-sdk/deepinfra

Version:

The **[DeepInfra provider](https://ai-sdk.dev/providers/ai-sdk-providers/deepinfra)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the DeepInfra API, giving you access to models like Llama 3, Mixtral, and other state-of-the-art models.

357 lines (351 loc) 13.1 kB
"use strict"; var __defProp = Object.defineProperty; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __getOwnPropNames = Object.getOwnPropertyNames; var __hasOwnProp = Object.prototype.hasOwnProperty; var __export = (target, all) => { for (var name in all) __defProp(target, name, { get: all[name], enumerable: true }); }; var __copyProps = (to, from, except, desc) => { if (from && typeof from === "object" || typeof from === "function") { for (let key of __getOwnPropNames(from)) if (!__hasOwnProp.call(to, key) && key !== except) __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable }); } return to; }; var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod); // src/index.ts var index_exports = {}; __export(index_exports, { VERSION: () => VERSION, createDeepInfra: () => createDeepInfra, deepinfra: () => deepinfra }); module.exports = __toCommonJS(index_exports); // src/deepinfra-provider.ts var import_openai_compatible2 = require("@ai-sdk/openai-compatible"); var import_provider_utils2 = require("@ai-sdk/provider-utils"); // src/deepinfra-image-model.ts var import_provider_utils = require("@ai-sdk/provider-utils"); var import_v4 = require("zod/v4"); var DeepInfraImageModel = class { constructor(modelId, config) { this.modelId = modelId; this.config = config; this.specificationVersion = "v3"; this.maxImagesPerCall = 1; } get provider() { return this.config.provider; } async doGenerate({ prompt, n, size, aspectRatio, seed, providerOptions, headers, abortSignal, files, mask }) { var _a, _b, _c, _d, _e; const warnings = []; const currentDate = (_c = (_b = (_a = this.config._internal) == null ? void 0 : _a.currentDate) == null ? void 0 : _b.call(_a)) != null ? 
_c : /* @__PURE__ */ new Date(); if (files != null && files.length > 0) { const { value: response2, responseHeaders: responseHeaders2 } = await (0, import_provider_utils.postFormDataToApi)({ url: this.getEditUrl(), headers: (0, import_provider_utils.combineHeaders)(this.config.headers(), headers), formData: (0, import_provider_utils.convertToFormData)( { model: this.modelId, prompt, image: await Promise.all(files.map((file) => fileToBlob(file))), mask: mask != null ? await fileToBlob(mask) : void 0, n, size, ...(_d = providerOptions.deepinfra) != null ? _d : {} }, { useArrayBrackets: false } ), failedResponseHandler: (0, import_provider_utils.createJsonErrorResponseHandler)({ errorSchema: deepInfraEditErrorSchema, errorToMessage: (error) => { var _a2, _b2; return (_b2 = (_a2 = error.error) == null ? void 0 : _a2.message) != null ? _b2 : "Unknown error"; } }), successfulResponseHandler: (0, import_provider_utils.createJsonResponseHandler)( deepInfraEditResponseSchema ), abortSignal, fetch: this.config.fetch }); return { images: response2.data.map((item) => item.b64_json), warnings, response: { timestamp: currentDate, modelId: this.modelId, headers: responseHeaders2 } }; } const splitSize = size == null ? void 0 : size.split("x"); const { value: response, responseHeaders } = await (0, import_provider_utils.postJsonToApi)({ url: `${this.config.baseURL}/${this.modelId}`, headers: (0, import_provider_utils.combineHeaders)(this.config.headers(), headers), body: { prompt, num_images: n, ...aspectRatio && { aspect_ratio: aspectRatio }, ...splitSize && { width: splitSize[0], height: splitSize[1] }, ...seed != null && { seed }, ...(_e = providerOptions.deepinfra) != null ? 
_e : {} }, failedResponseHandler: (0, import_provider_utils.createJsonErrorResponseHandler)({ errorSchema: deepInfraErrorSchema, errorToMessage: (error) => error.detail.error }), successfulResponseHandler: (0, import_provider_utils.createJsonResponseHandler)( deepInfraImageResponseSchema ), abortSignal, fetch: this.config.fetch }); return { images: response.images.map( (image) => image.replace(/^data:image\/\w+;base64,/, "") ), warnings, response: { timestamp: currentDate, modelId: this.modelId, headers: responseHeaders } }; } getEditUrl() { const baseUrl = this.config.baseURL.replace("/inference", "/openai"); return `${baseUrl}/images/edits`; } }; var deepInfraErrorSchema = import_v4.z.object({ detail: import_v4.z.object({ error: import_v4.z.string() }) }); var deepInfraImageResponseSchema = import_v4.z.object({ images: import_v4.z.array(import_v4.z.string()) }); var deepInfraEditErrorSchema = import_v4.z.object({ error: import_v4.z.object({ message: import_v4.z.string() }).optional() }); var deepInfraEditResponseSchema = import_v4.z.object({ data: import_v4.z.array(import_v4.z.object({ b64_json: import_v4.z.string() })) }); async function fileToBlob(file) { if (file.type === "url") { return (0, import_provider_utils.downloadBlob)(file.url); } const data = file.data instanceof Uint8Array ? file.data : (0, import_provider_utils.convertBase64ToUint8Array)(file.data); return new Blob([data], { type: file.mediaType }); } // src/deepinfra-chat-language-model.ts var import_openai_compatible = require("@ai-sdk/openai-compatible"); var DeepInfraChatLanguageModel = class extends import_openai_compatible.OpenAICompatibleChatLanguageModel { constructor(modelId, config) { super(modelId, config); } /** * Fixes incorrect token usage for Gemini/Gemma models from DeepInfra. * * DeepInfra's API returns completion_tokens that don't include reasoning_tokens * for Gemini/Gemma models, which violates the OpenAI-compatible spec. 
* According to the spec, completion_tokens should include reasoning_tokens. * * Example of incorrect data from DeepInfra: * { * "completion_tokens": 84, // text-only tokens * "completion_tokens_details": { * "reasoning_tokens": 1081 // reasoning tokens not included above * } * } * * This would result in negative text tokens: 84 - 1081 = -997 * * The fix: If reasoning_tokens > completion_tokens, add reasoning_tokens * to completion_tokens: 84 + 1081 = 1165 */ fixUsageForGeminiModels(usage) { var _a, _b; if (!usage || !((_a = usage.completion_tokens_details) == null ? void 0 : _a.reasoning_tokens)) { return usage; } const completionTokens = (_b = usage.completion_tokens) != null ? _b : 0; const reasoningTokens = usage.completion_tokens_details.reasoning_tokens; if (reasoningTokens > completionTokens) { const correctedCompletionTokens = completionTokens + reasoningTokens; return { ...usage, // Add reasoning_tokens to completion_tokens to get the correct total completion_tokens: correctedCompletionTokens, // Update total_tokens if present total_tokens: usage.total_tokens != null ? usage.total_tokens + reasoningTokens : void 0 }; } return usage; } async doGenerate(options) { var _a, _b, _c, _d, _e, _f, _g; const result = await super.doGenerate(options); if ((_a = result.usage) == null ? void 0 : _a.raw) { const fixedRawUsage = this.fixUsageForGeminiModels(result.usage.raw); if (fixedRawUsage !== result.usage.raw) { const promptTokens = (_b = fixedRawUsage.prompt_tokens) != null ? _b : 0; const completionTokens = (_c = fixedRawUsage.completion_tokens) != null ? _c : 0; const cacheReadTokens = (_e = (_d = fixedRawUsage.prompt_tokens_details) == null ? void 0 : _d.cached_tokens) != null ? _e : 0; const reasoningTokens = (_g = (_f = fixedRawUsage.completion_tokens_details) == null ? void 0 : _f.reasoning_tokens) != null ? 
_g : 0; return { ...result, usage: { inputTokens: { total: promptTokens, noCache: promptTokens - cacheReadTokens, cacheRead: cacheReadTokens, cacheWrite: void 0 }, outputTokens: { total: completionTokens, text: completionTokens - reasoningTokens, reasoning: reasoningTokens }, raw: fixedRawUsage } }; } } return result; } async doStream(options) { const result = await super.doStream(options); const originalStream = result.stream; const fixUsage = this.fixUsageForGeminiModels.bind(this); const transformedStream = new ReadableStream({ async start(controller) { var _a, _b, _c, _d, _e, _f, _g; const reader = originalStream.getReader(); try { while (true) { const { done, value } = await reader.read(); if (done) break; if (value.type === "finish" && ((_a = value.usage) == null ? void 0 : _a.raw)) { const fixedRawUsage = fixUsage(value.usage.raw); if (fixedRawUsage !== value.usage.raw) { const promptTokens = (_b = fixedRawUsage.prompt_tokens) != null ? _b : 0; const completionTokens = (_c = fixedRawUsage.completion_tokens) != null ? _c : 0; const cacheReadTokens = (_e = (_d = fixedRawUsage.prompt_tokens_details) == null ? void 0 : _d.cached_tokens) != null ? _e : 0; const reasoningTokens = (_g = (_f = fixedRawUsage.completion_tokens_details) == null ? void 0 : _f.reasoning_tokens) != null ? _g : 0; controller.enqueue({ ...value, usage: { inputTokens: { total: promptTokens, noCache: promptTokens - cacheReadTokens, cacheRead: cacheReadTokens, cacheWrite: void 0 }, outputTokens: { total: completionTokens, text: completionTokens - reasoningTokens, reasoning: reasoningTokens }, raw: fixedRawUsage } }); } else { controller.enqueue(value); } } else { controller.enqueue(value); } } controller.close(); } catch (error) { controller.error(error); } } }); return { ...result, stream: transformedStream }; } }; // src/version.ts var VERSION = true ? 
"2.0.52" : "0.0.0-test"; // src/deepinfra-provider.ts function createDeepInfra(options = {}) { var _a; const baseURL = (0, import_provider_utils2.withoutTrailingSlash)( (_a = options.baseURL) != null ? _a : "https://api.deepinfra.com/v1" ); const getHeaders = () => (0, import_provider_utils2.withUserAgentSuffix)( { Authorization: `Bearer ${(0, import_provider_utils2.loadApiKey)({ apiKey: options.apiKey, environmentVariableName: "DEEPINFRA_API_KEY", description: "DeepInfra's API key" })}`, ...options.headers }, `ai-sdk/deepinfra/${VERSION}` ); const getCommonModelConfig = (modelType) => ({ provider: `deepinfra.${modelType}`, url: ({ path }) => `${baseURL}/openai${path}`, headers: getHeaders, fetch: options.fetch }); const createChatModel = (modelId) => { return new DeepInfraChatLanguageModel( modelId, getCommonModelConfig("chat") ); }; const createCompletionModel = (modelId) => new import_openai_compatible2.OpenAICompatibleCompletionLanguageModel( modelId, getCommonModelConfig("completion") ); const createEmbeddingModel = (modelId) => new import_openai_compatible2.OpenAICompatibleEmbeddingModel( modelId, getCommonModelConfig("embedding") ); const createImageModel = (modelId) => new DeepInfraImageModel(modelId, { ...getCommonModelConfig("image"), baseURL: baseURL ? `${baseURL}/inference` : "https://api.deepinfra.com/v1/inference" }); const provider = (modelId) => createChatModel(modelId); provider.specificationVersion = "v3"; provider.completionModel = createCompletionModel; provider.chatModel = createChatModel; provider.image = createImageModel; provider.imageModel = createImageModel; provider.languageModel = createChatModel; provider.embeddingModel = createEmbeddingModel; provider.textEmbeddingModel = createEmbeddingModel; return provider; } var deepinfra = createDeepInfra(); // Annotate the CommonJS export names for ESM import in node: 0 && (module.exports = { VERSION, createDeepInfra, deepinfra }); //# sourceMappingURL=index.js.map