UNPKG

@ai-sdk/deepinfra

Version:

The **[DeepInfra provider](https://ai-sdk.dev/providers/ai-sdk-providers/deepinfra)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the DeepInfra API, giving you access to models like Llama 3, Mixtral, and other state-of-the-art models.

344 lines (340 loc) 11.7 kB
// src/deepinfra-provider.ts
import {
  OpenAICompatibleCompletionLanguageModel,
  OpenAICompatibleEmbeddingModel
} from "@ai-sdk/openai-compatible";
import {
  loadApiKey,
  withoutTrailingSlash,
  withUserAgentSuffix
} from "@ai-sdk/provider-utils";

// src/deepinfra-image-model.ts
import {
  combineHeaders,
  convertBase64ToUint8Array,
  convertToFormData,
  createJsonErrorResponseHandler,
  createJsonResponseHandler,
  downloadBlob,
  postFormDataToApi,
  postJsonToApi
} from "@ai-sdk/provider-utils";
import { z } from "zod/v4";

/**
 * Image model backed by DeepInfra.
 *
 * Two request shapes are used:
 * - generation: JSON POST to `${config.baseURL}/${modelId}`
 * - editing (when `files` are supplied): multipart POST to the URL returned
 *   by `getEditUrl()`
 */
var DeepInfraImageModel = class {
  /**
   * @param modelId Model identifier, appended to the base URL for generation
   *   and sent as the `model` form field for edits.
   * @param config Object providing `provider`, `baseURL`, `headers()`,
   *   optional `fetch`, and optional `_internal.currentDate()`.
   */
  constructor(modelId, config) {
    this.modelId = modelId;
    this.config = config;
    this.specificationVersion = "v3";
    this.maxImagesPerCall = 1;
  }

  /** Provider name taken from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Generates images, or edits the supplied `files` when any are given.
   *
   * In the edit path only `prompt`, `files`, `mask`, `n`, `size` and the
   * `providerOptions.deepinfra` entries are sent; `aspectRatio` and `seed`
   * are not forwarded.
   *
   * @returns `{ images, warnings, response }` where `images` are base64
   *   strings and `response` carries the timestamp, model id and headers.
   */
  async doGenerate({
    prompt,
    n,
    size,
    aspectRatio,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask
  }) {
    const warnings = [];

    // `_internal.currentDate` is an optional injection point; fall back to
    // the real clock when it is absent or returns nothing.
    const internal = this.config._internal;
    const injectedDate =
      internal != null && internal.currentDate != null
        ? internal.currentDate()
        : void 0;
    const currentDate = injectedDate != null ? injectedDate : new Date();

    if (files != null && files.length > 0) {
      const { value: editResponse, responseHeaders: editResponseHeaders } =
        await postFormDataToApi({
          url: this.getEditUrl(),
          headers: combineHeaders(this.config.headers(), headers),
          formData: convertToFormData(
            {
              model: this.modelId,
              prompt,
              image: await Promise.all(files.map((file) => fileToBlob(file))),
              mask: mask != null ? await fileToBlob(mask) : void 0,
              n,
              size,
              // Provider options are spread last so they can override the
              // fields above.
              ...(providerOptions.deepinfra != null
                ? providerOptions.deepinfra
                : {})
            },
            { useArrayBrackets: false }
          ),
          failedResponseHandler: createJsonErrorResponseHandler({
            errorSchema: deepInfraEditErrorSchema,
            errorToMessage: (error) => {
              const message =
                error.error != null ? error.error.message : void 0;
              return message != null ? message : "Unknown error";
            }
          }),
          successfulResponseHandler: createJsonResponseHandler(
            deepInfraEditResponseSchema
          ),
          abortSignal,
          fetch: this.config.fetch
        });

      return {
        images: editResponse.data.map((item) => item.b64_json),
        warnings,
        response: {
          timestamp: currentDate,
          modelId: this.modelId,
          headers: editResponseHeaders
        }
      };
    }

    // `size` is expected as "<width>x<height>"; the parts are forwarded
    // exactly as split (i.e. as strings).
    const splitSize = size == null ? void 0 : size.split("x");

    const { value: response, responseHeaders } = await postJsonToApi({
      url: `${this.config.baseURL}/${this.modelId}`,
      headers: combineHeaders(this.config.headers(), headers),
      body: {
        prompt,
        num_images: n,
        ...(aspectRatio && { aspect_ratio: aspectRatio }),
        ...(splitSize && { width: splitSize[0], height: splitSize[1] }),
        ...(seed != null && { seed }),
        ...(providerOptions.deepinfra != null ? providerOptions.deepinfra : {})
      },
      failedResponseHandler: createJsonErrorResponseHandler({
        errorSchema: deepInfraErrorSchema,
        errorToMessage: (error) => error.detail.error
      }),
      successfulResponseHandler: createJsonResponseHandler(
        deepInfraImageResponseSchema
      ),
      abortSignal,
      fetch: this.config.fetch
    });

    return {
      // Strip a leading `data:image/<type>;base64,` prefix if present so that
      // only the base64 payload is returned.
      images: response.images.map((image) =>
        image.replace(/^data:image\/\w+;base64,/, "")
      ),
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders
      }
    };
  }

  /**
   * Builds the image-edit endpoint URL by swapping the trailing `/inference`
   * segment of the configured base URL for `/openai`.
   *
   * Only the trailing segment is replaced, so a base URL that happens to
   * contain `/inference` earlier in its host or path is left intact.
   */
  getEditUrl() {
    const baseUrl = this.config.baseURL.replace(/\/inference$/, "/openai");
    return `${baseUrl}/images/edits`;
  }
};

/** Error body of the generation endpoint: `{ detail: { error } }`. */
var deepInfraErrorSchema = z.object({
  detail: z.object({
    error: z.string()
  })
});

/** Success body of the generation endpoint: `{ images: string[] }`. */
var deepInfraImageResponseSchema = z.object({
  images: z.array(z.string())
});

/** Error body of the edit endpoint: `{ error?: { message } }`. */
var deepInfraEditErrorSchema = z.object({
  error: z.object({ message: z.string() }).optional()
});

/** Success body of the edit endpoint: `{ data: [{ b64_json }] }`. */
var deepInfraEditResponseSchema = z.object({
  data: z.array(z.object({ b64_json: z.string() }))
});

/**
 * Converts a file descriptor into a Blob.
 *
 * `type: "url"` files are downloaded via `downloadBlob`; otherwise `data` is
 * used directly when it is a `Uint8Array` and base64-decoded when it is not.
 */
async function fileToBlob(file) {
  if (file.type === "url") {
    return downloadBlob(file.url);
  }
  const data =
    file.data instanceof Uint8Array
      ? file.data
      : convertBase64ToUint8Array(file.data);
  return new Blob([data], { type: file.mediaType });
}

// src/deepinfra-chat-language-model.ts
import { OpenAICompatibleChatLanguageModel } from "@ai-sdk/openai-compatible";

/**
 * Rebuilds the structured usage object from a corrected raw usage payload.
 *
 * Missing counters are treated as 0. `cacheWrite` is always undefined.
 *
 * @param fixedRawUsage Raw usage with OpenAI-style field names
 *   (`prompt_tokens`, `completion_tokens`, `*_details`).
 * @returns `{ inputTokens, outputTokens, raw }`.
 */
function convertFixedDeepInfraUsage(fixedRawUsage) {
  const promptTokens =
    fixedRawUsage.prompt_tokens != null ? fixedRawUsage.prompt_tokens : 0;
  const completionTokens =
    fixedRawUsage.completion_tokens != null
      ? fixedRawUsage.completion_tokens
      : 0;

  const promptDetails = fixedRawUsage.prompt_tokens_details;
  const cacheReadTokens =
    promptDetails != null && promptDetails.cached_tokens != null
      ? promptDetails.cached_tokens
      : 0;

  const completionDetails = fixedRawUsage.completion_tokens_details;
  const reasoningTokens =
    completionDetails != null && completionDetails.reasoning_tokens != null
      ? completionDetails.reasoning_tokens
      : 0;

  return {
    inputTokens: {
      total: promptTokens,
      noCache: promptTokens - cacheReadTokens,
      cacheRead: cacheReadTokens,
      cacheWrite: void 0
    },
    outputTokens: {
      total: completionTokens,
      text: completionTokens - reasoningTokens,
      reasoning: reasoningTokens
    },
    raw: fixedRawUsage
  };
}

/**
 * Chat model that wraps the OpenAI-compatible implementation and corrects
 * token usage reported by DeepInfra (see `fixUsageForGeminiModels`).
 */
var DeepInfraChatLanguageModel = class extends OpenAICompatibleChatLanguageModel {
  constructor(modelId, config) {
    super(modelId, config);
  }

  /**
   * Fixes incorrect token usage for Gemini/Gemma models from DeepInfra.
   *
   * DeepInfra's API returns completion_tokens that don't include reasoning_tokens
   * for Gemini/Gemma models, which violates the OpenAI-compatible spec.
   * According to the spec, completion_tokens should include reasoning_tokens.
   *
   * Example of incorrect data from DeepInfra:
   * {
   *   "completion_tokens": 84,        // text-only tokens
   *   "completion_tokens_details": {
   *     "reasoning_tokens": 1081      // reasoning tokens not included above
   *   }
   * }
   *
   * This would result in negative text tokens: 84 - 1081 = -997
   *
   * The fix: If reasoning_tokens > completion_tokens, add reasoning_tokens
   * to completion_tokens: 84 + 1081 = 1165
   *
   * @param usage Raw usage object, possibly null/undefined.
   * @returns The same object when no correction is needed (callers rely on
   *   identity to detect this), otherwise a corrected shallow copy.
   */
  fixUsageForGeminiModels(usage) {
    const details = usage ? usage.completion_tokens_details : void 0;
    if (!usage || !(details == null ? void 0 : details.reasoning_tokens)) {
      return usage;
    }

    const completionTokens =
      usage.completion_tokens != null ? usage.completion_tokens : 0;
    const reasoningTokens = usage.completion_tokens_details.reasoning_tokens;

    if (reasoningTokens > completionTokens) {
      const correctedCompletionTokens = completionTokens + reasoningTokens;
      return {
        ...usage,
        // Add reasoning_tokens to completion_tokens to get the correct total
        completion_tokens: correctedCompletionTokens,
        // Update total_tokens if present
        total_tokens:
          usage.total_tokens != null
            ? usage.total_tokens + reasoningTokens
            : void 0
      };
    }

    return usage;
  }

  /**
   * Delegates to the base implementation and replaces `usage` with a
   * corrected version when the raw usage needed fixing.
   */
  async doGenerate(options) {
    const result = await super.doGenerate(options);

    const rawUsage = result.usage != null ? result.usage.raw : void 0;
    if (!rawUsage) {
      return result;
    }

    const fixedRawUsage = this.fixUsageForGeminiModels(rawUsage);
    if (fixedRawUsage === rawUsage) {
      return result;
    }

    return {
      ...result,
      usage: convertFixedDeepInfraUsage(fixedRawUsage)
    };
  }

  /**
   * Delegates to the base implementation and rewrites the `usage` of
   * `finish` chunks when the raw usage needed fixing. All other chunks pass
   * through untouched.
   *
   * The stream is piped through a TransformStream so that backpressure,
   * upstream errors and consumer cancellation propagate to the underlying
   * stream.
   */
  async doStream(options) {
    const result = await super.doStream(options);

    const transformedStream = result.stream.pipeThrough(
      new TransformStream({
        transform: (chunk, controller) => {
          const rawUsage =
            chunk.type === "finish" && chunk.usage != null
              ? chunk.usage.raw
              : void 0;

          if (rawUsage) {
            const fixedRawUsage = this.fixUsageForGeminiModels(rawUsage);
            if (fixedRawUsage !== rawUsage) {
              controller.enqueue({
                ...chunk,
                usage: convertFixedDeepInfraUsage(fixedRawUsage)
              });
              return;
            }
          }

          controller.enqueue(chunk);
        }
      })
    );

    return {
      ...result,
      stream: transformedStream
    };
  }
};

// src/version.ts
var VERSION = true ? "2.0.52" : "0.0.0-test";

// src/deepinfra-provider.ts
/**
 * Creates a DeepInfra provider.
 *
 * @param options Optional settings:
 *   - `baseURL`: API root; defaults to "https://api.deepinfra.com/v1". A
 *     trailing slash is removed.
 *   - `apiKey`: API key; when omitted `loadApiKey` is asked to read the
 *     `DEEPINFRA_API_KEY` environment variable.
 *   - `headers`: extra headers merged after the Authorization header.
 *   - `fetch`: custom fetch implementation passed to every model.
 * @returns A callable provider (`provider(modelId)` creates a chat model)
 *   with factory methods for chat, completion, embedding and image models.
 */
function createDeepInfra(options = {}) {
  const baseURL = withoutTrailingSlash(
    options.baseURL != null ? options.baseURL : "https://api.deepinfra.com/v1"
  );

  // Evaluated per request so the API key is resolved lazily.
  const getHeaders = () =>
    withUserAgentSuffix(
      {
        Authorization: `Bearer ${loadApiKey({
          apiKey: options.apiKey,
          environmentVariableName: "DEEPINFRA_API_KEY",
          description: "DeepInfra's API key"
        })}`,
        ...options.headers
      },
      `ai-sdk/deepinfra/${VERSION}`
    );

  // Shared config for the OpenAI-compatible models, which live under
  // `${baseURL}/openai`.
  const getCommonModelConfig = (modelType) => ({
    provider: `deepinfra.${modelType}`,
    url: ({ path }) => `${baseURL}/openai${path}`,
    headers: getHeaders,
    fetch: options.fetch
  });

  const createChatModel = (modelId) => {
    return new DeepInfraChatLanguageModel(
      modelId,
      getCommonModelConfig("chat")
    );
  };

  const createCompletionModel = (modelId) =>
    new OpenAICompatibleCompletionLanguageModel(
      modelId,
      getCommonModelConfig("completion")
    );

  const createEmbeddingModel = (modelId) =>
    new OpenAICompatibleEmbeddingModel(
      modelId,
      getCommonModelConfig("embedding")
    );

  // Image generation uses the `/inference` root instead of `/openai`.
  const createImageModel = (modelId) =>
    new DeepInfraImageModel(modelId, {
      ...getCommonModelConfig("image"),
      baseURL: baseURL
        ? `${baseURL}/inference`
        : "https://api.deepinfra.com/v1/inference"
    });

  const provider = (modelId) => createChatModel(modelId);
  provider.specificationVersion = "v3";
  provider.completionModel = createCompletionModel;
  provider.chatModel = createChatModel;
  provider.image = createImageModel;
  provider.imageModel = createImageModel;
  provider.languageModel = createChatModel;
  provider.embeddingModel = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;
  return provider;
}

/** Default provider instance created with no options. */
var deepinfra = createDeepInfra();

export { VERSION, createDeepInfra, deepinfra };
//# sourceMappingURL=index.mjs.map