@ai-sdk/deepinfra
Version: 2.0.52
The **[DeepInfra provider](https://ai-sdk.dev/providers/ai-sdk-providers/deepinfra)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the DeepInfra API, giving you access to models like Llama 3, Mixtral, and other state-of-the-art models.
357 lines (351 loc) • 13.1 kB
JavaScript
// Stray empty statement at the top of the bundle; it has no runtime effect.
;
// Cached Object reflection helpers used by the CommonJS interop shims below.
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
/**
 * Defines every key of `all` on `target` as an enumerable getter.
 * Reading `target[name]` therefore calls `all[name]()` at access time; the
 * value is not copied when this helper runs.
 */
var __export = (target, all) => {
for (var name in all)
__defProp(target, name, { get: all[name], enumerable: true });
};
/**
 * Copies the own string-keyed properties of `from` onto `to` as getters that
 * read `from[key]` on every access. It only does so when `from` is a non-null
 * object or a function.
 *
 * Keys already owned by `to`, and the `except` key, are skipped. A copied
 * property is enumerable when `from` has no descriptor for the key or that
 * descriptor is enumerable. `desc` is not a real argument; it is a scratch
 * variable reused to hold that descriptor.
 *
 * Returns `to`.
 */
var __copyProps = (to, from, except, desc) => {
if (from && typeof from === "object" || typeof from === "function") {
for (let key of __getOwnPropNames(from))
if (!__hasOwnProp.call(to, key) && key !== except)
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
}
return to;
};
/**
 * Wraps `mod` for CommonJS consumers. It returns a new object with an
 * `__esModule: true` flag, which is non-enumerable and non-writable because
 * those are the defineProperty defaults. It also carries getter copies of
 * every own property of `mod` (via __copyProps).
 */
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
// src/index.ts
// Export object for the package's public API. __export installs each entry
// as an enumerable getter. VERSION, createDeepInfra and deepinfra are declared
// later in this file, so they are looked up when a property is read, not here.
var index_exports = {};
__export(index_exports, {
VERSION: () => VERSION,
createDeepInfra: () => createDeepInfra,
deepinfra: () => deepinfra
});
// Publish the API as CommonJS. The export is a fresh object flagged
// `__esModule` whose properties are getters forwarding to index_exports.
module.exports = __toCommonJS(index_exports);
// src/deepinfra-provider.ts
var import_openai_compatible2 = require("@ai-sdk/openai-compatible");
var import_provider_utils2 = require("@ai-sdk/provider-utils");
// src/deepinfra-image-model.ts
var import_provider_utils = require("@ai-sdk/provider-utils");
var import_v4 = require("zod/v4");
/**
 * Image model for DeepInfra.
 *
 * `doGenerate` has two request paths:
 * - When `files` is non-empty, it sends a multipart POST to the
 *   OpenAI-compatible edit endpoint returned by `getEditUrl()`. Images are
 *   read from `data[].b64_json`.
 * - Otherwise, it sends a JSON POST to `${config.baseURL}/${modelId}`. Images
 *   are read from `images[]`, with any `data:image/...;base64,` prefix
 *   stripped.
 *
 * Options under `providerOptions.deepinfra` are spread last into the request
 * body / form data, so they override the fields computed here.
 */
var DeepInfraImageModel = class {
  /**
   * @param {string} modelId - model identifier. It is appended to
   *   `config.baseURL` for generation and sent as `model` for edits.
   * @param {object} config - provider config. Reads `provider`, `baseURL`,
   *   `headers()`, `fetch` and, optionally, `_internal.currentDate()` (used as
   *   the response timestamp).
   */
  constructor(modelId, config) {
    this.modelId = modelId;
    this.config = config;
    this.specificationVersion = "v3";
    this.maxImagesPerCall = 1;
  }
  /** Provider id taken from the config. */
  get provider() {
    return this.config.provider;
  }
  /**
   * Generates images, or edits the given `files` when that array is non-empty.
   *
   * In the edit path, `aspectRatio` and `seed` are not sent. `warnings` is
   * always returned empty.
   *
   * @returns {Promise<{images: string[], warnings: Array, response: {timestamp: Date, modelId: string, headers: unknown}}>}
   *   base64 image payloads without a data-URL prefix.
   */
  async doGenerate({
    prompt,
    n,
    size,
    aspectRatio,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask
  }) {
    const warnings = [];
    // An injected clock (config._internal.currentDate) wins; otherwise "now".
    const internal = this.config._internal;
    const getCurrentDate = internal == null ? void 0 : internal.currentDate;
    const injectedDate = getCurrentDate == null ? void 0 : getCurrentDate.call(internal);
    const currentDate = injectedDate != null ? injectedDate : /* @__PURE__ */ new Date();
    // Tolerate a missing providerOptions object instead of throwing a TypeError.
    const deepInfraOptions = providerOptions != null && providerOptions.deepinfra != null ? providerOptions.deepinfra : {};
    if (files != null && files.length > 0) {
      const { value: editResponse, responseHeaders: editResponseHeaders } = await (0, import_provider_utils.postFormDataToApi)({
        url: this.getEditUrl(),
        headers: (0, import_provider_utils.combineHeaders)(this.config.headers(), headers),
        formData: (0, import_provider_utils.convertToFormData)(
          {
            model: this.modelId,
            prompt,
            image: await Promise.all(files.map((file) => fileToBlob(file))),
            mask: mask != null ? await fileToBlob(mask) : void 0,
            n,
            size,
            ...deepInfraOptions
          },
          { useArrayBrackets: false }
        ),
        failedResponseHandler: (0, import_provider_utils.createJsonErrorResponseHandler)({
          errorSchema: deepInfraEditErrorSchema,
          errorToMessage: (error) => {
            const message = error.error == null ? void 0 : error.error.message;
            return message != null ? message : "Unknown error";
          }
        }),
        successfulResponseHandler: (0, import_provider_utils.createJsonResponseHandler)(
          deepInfraEditResponseSchema
        ),
        abortSignal,
        fetch: this.config.fetch
      });
      return {
        images: editResponse.data.map((item) => item.b64_json),
        warnings,
        response: {
          timestamp: currentDate,
          modelId: this.modelId,
          headers: editResponseHeaders
        }
      };
    }
    // "WxH" -> ["W", "H"]; the parts are sent as-is (strings) as width/height.
    const splitSize = size == null ? void 0 : size.split("x");
    const { value: response, responseHeaders } = await (0, import_provider_utils.postJsonToApi)({
      url: `${this.config.baseURL}/${this.modelId}`,
      headers: (0, import_provider_utils.combineHeaders)(this.config.headers(), headers),
      body: {
        prompt,
        num_images: n,
        ...aspectRatio && { aspect_ratio: aspectRatio },
        ...splitSize && { width: splitSize[0], height: splitSize[1] },
        ...seed != null && { seed },
        ...deepInfraOptions
      },
      failedResponseHandler: (0, import_provider_utils.createJsonErrorResponseHandler)({
        errorSchema: deepInfraErrorSchema,
        errorToMessage: (error) => error.detail.error
      }),
      successfulResponseHandler: (0, import_provider_utils.createJsonResponseHandler)(
        deepInfraImageResponseSchema
      ),
      abortSignal,
      fetch: this.config.fetch
    });
    return {
      // Strip any data-URL prefix. `[^;]+` also covers subtypes like
      // `svg+xml`, which `\w+` did not match.
      images: response.images.map(
        (image) => image.replace(/^data:image\/[^;]+;base64,/, "")
      ),
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders
      }
    };
  }
  /**
   * Builds the OpenAI-compatible image-edit URL from the inference base URL.
   *
   * Only a trailing `/inference` segment is swapped for `/openai`. A custom
   * base URL that contains "/inference" earlier in its path is therefore not
   * rewritten in the wrong place.
   */
  getEditUrl() {
    const baseUrl = this.config.baseURL.replace(/\/inference$/, "/openai");
    return `${baseUrl}/images/edits`;
  }
};
// Zod v4 namespace shared by the DeepInfra response/error schemas below.
var zodV4 = import_v4.z;
/** Error body of the JSON generation request: `{ detail: { error: string } }`. */
var deepInfraErrorSchema = zodV4.object({
  detail: zodV4.object({ error: zodV4.string() })
});
/** Success body of the JSON generation request: `{ images: string[] }`. */
var deepInfraImageResponseSchema = zodV4.object({
  images: zodV4.array(zodV4.string())
});
/** Error body of the image-edit request; the `error` object may be absent. */
var deepInfraEditErrorSchema = zodV4.object({
  error: zodV4.object({ message: zodV4.string() }).optional()
});
/** Success body of the image-edit request: `{ data: [{ b64_json: string }] }`. */
var deepInfraEditResponseSchema = zodV4.object({
  data: zodV4.array(zodV4.object({ b64_json: zodV4.string() }))
});
/**
 * Turns an image-file input into a Blob suitable for multipart upload.
 *
 * - `type: "url"` inputs are fetched with `downloadBlob(file.url)`.
 * - Any other input uses `file.data`: as-is when it is a Uint8Array, otherwise
 *   decoded from base64. The result is wrapped in a Blob whose type is
 *   `file.mediaType`.
 */
async function fileToBlob(file) {
  switch (file.type) {
    case "url":
      return (0, import_provider_utils.downloadBlob)(file.url);
    default: {
      const { data, mediaType } = file;
      const bytes = data instanceof Uint8Array ? data : (0, import_provider_utils.convertBase64ToUint8Array)(data);
      return new Blob([bytes], { type: mediaType });
    }
  }
}
// src/deepinfra-chat-language-model.ts
var import_openai_compatible = require("@ai-sdk/openai-compatible");
/**
 * Builds an AI SDK usage object from a raw OpenAI-compatible `usage` payload.
 * Missing token counts are treated as 0, and `cacheWrite` is always undefined.
 *
 * @param {object} rawUsage - raw usage. Reads `prompt_tokens`,
 *   `completion_tokens`, `prompt_tokens_details.cached_tokens` and
 *   `completion_tokens_details.reasoning_tokens`.
 * @returns {object} `{ inputTokens, outputTokens, raw }`, with `raw` set to `rawUsage`.
 */
function convertDeepInfraRawUsage(rawUsage) {
  const promptDetails = rawUsage.prompt_tokens_details;
  const completionDetails = rawUsage.completion_tokens_details;
  const promptTokens = rawUsage.prompt_tokens != null ? rawUsage.prompt_tokens : 0;
  const completionTokens = rawUsage.completion_tokens != null ? rawUsage.completion_tokens : 0;
  const cacheReadTokens = promptDetails != null && promptDetails.cached_tokens != null ? promptDetails.cached_tokens : 0;
  const reasoningTokens = completionDetails != null && completionDetails.reasoning_tokens != null ? completionDetails.reasoning_tokens : 0;
  return {
    inputTokens: {
      total: promptTokens,
      noCache: promptTokens - cacheReadTokens,
      cacheRead: cacheReadTokens,
      cacheWrite: void 0
    },
    outputTokens: {
      total: completionTokens,
      text: completionTokens - reasoningTokens,
      reasoning: reasoningTokens
    },
    raw: rawUsage
  };
}
/**
 * Chat model for DeepInfra's OpenAI-compatible chat endpoint.
 *
 * Behaves like OpenAICompatibleChatLanguageModel, but post-processes token
 * usage with `fixUsageForGeminiModels`. This happens in `doGenerate` results
 * and in the `finish` part of `doStream`. When the raw usage is corrected, the
 * AI SDK usage object is rebuilt from the corrected raw values.
 */
var DeepInfraChatLanguageModel = class extends import_openai_compatible.OpenAICompatibleChatLanguageModel {
  constructor(modelId, config) {
    super(modelId, config);
  }
  /**
   * Fixes incorrect token usage for Gemini/Gemma models from DeepInfra.
   *
   * DeepInfra's API returns completion_tokens that don't include reasoning_tokens
   * for Gemini/Gemma models, which violates the OpenAI-compatible spec.
   * According to the spec, completion_tokens should include reasoning_tokens.
   *
   * Example of incorrect data from DeepInfra:
   * {
   *   "completion_tokens": 84, // text-only tokens
   *   "completion_tokens_details": {
   *     "reasoning_tokens": 1081 // reasoning tokens not included above
   *   }
   * }
   *
   * This would result in negative text tokens: 84 - 1081 = -997
   *
   * The fix: If reasoning_tokens > completion_tokens, add reasoning_tokens
   * to completion_tokens: 84 + 1081 = 1165
   *
   * Note: the check does not look at the model id. It runs on every response
   * and only triggers when reasoning_tokens > completion_tokens.
   *
   * @param {object|undefined} usage - raw OpenAI-compatible usage object.
   * @returns the same object when no correction applies. Otherwise, a new
   *   object with corrected `completion_tokens`, and `total_tokens` increased
   *   by reasoning_tokens when present (undefined otherwise).
   */
  fixUsageForGeminiModels(usage) {
    if (!usage) {
      return usage;
    }
    const details = usage.completion_tokens_details;
    const reasoningTokens = details == null ? void 0 : details.reasoning_tokens;
    if (!reasoningTokens) {
      return usage;
    }
    const completionTokens = usage.completion_tokens != null ? usage.completion_tokens : 0;
    if (reasoningTokens <= completionTokens) {
      return usage;
    }
    return {
      ...usage,
      // Add reasoning_tokens to completion_tokens to get the correct total
      completion_tokens: completionTokens + reasoningTokens,
      // Update total_tokens if present
      total_tokens: usage.total_tokens != null ? usage.total_tokens + reasoningTokens : void 0
    };
  }
  /**
   * Returns a rebuilt AI SDK usage object when `fixUsageForGeminiModels`
   * corrected `usage.raw`. Returns `undefined` when there is no raw usage or
   * no correction was needed; callers then keep the original usage.
   */
  getCorrectedDeepInfraUsage(usage) {
    if (usage == null || !usage.raw) {
      return void 0;
    }
    const fixedRawUsage = this.fixUsageForGeminiModels(usage.raw);
    return fixedRawUsage === usage.raw ? void 0 : convertDeepInfraRawUsage(fixedRawUsage);
  }
  async doGenerate(options) {
    const result = await super.doGenerate(options);
    const correctedUsage = this.getCorrectedDeepInfraUsage(result.usage);
    return correctedUsage === void 0 ? result : { ...result, usage: correctedUsage };
  }
  async doStream(options) {
    const result = await super.doStream(options);
    // pipeThrough propagates backpressure, upstream errors and downstream
    // cancellation between the original stream and the returned one. A
    // hand-rolled drain loop in ReadableStream#start would do none of these.
    const stream = result.stream.pipeThrough(
      new TransformStream({
        transform: (part, controller) => {
          if (part.type === "finish") {
            const correctedUsage = this.getCorrectedDeepInfraUsage(part.usage);
            if (correctedUsage !== void 0) {
              controller.enqueue({ ...part, usage: correctedUsage });
              return;
            }
          }
          controller.enqueue(part);
        }
      })
    );
    return {
      ...result,
      stream
    };
  }
};
// src/version.ts
// Package version. createDeepInfra appends it to the user-agent as
// `ai-sdk/deepinfra/${VERSION}`. The constant `true` condition makes the
// "0.0.0-test" branch unreachable; presumably the build substituted its
// version check here — confirm against the build config.
var VERSION = true ? "2.0.52" : "0.0.0-test";
// src/deepinfra-provider.ts
/**
 * Creates a DeepInfra provider.
 *
 * @param {object} [options]
 * @param {string} [options.baseURL] - API base URL. Defaults to
 *   "https://api.deepinfra.com/v1"; a trailing slash is removed.
 * @param {string} [options.apiKey] - API key. It is passed to loadApiKey (env
 *   fallback DEEPINFRA_API_KEY) each time the headers function runs.
 * @param {object} [options.headers] - extra request headers, spread after
 *   Authorization.
 * @param {Function} [options.fetch] - custom fetch implementation handed to
 *   every model.
 * @returns {Function} a callable provider. Calling it with a model id returns a
 *   chat model. It also exposes chatModel/languageModel, completionModel,
 *   embeddingModel/textEmbeddingModel and image/imageModel factories.
 */
function createDeepInfra(options = {}) {
  const configuredBaseURL = options.baseURL;
  const baseURL = (0, import_provider_utils2.withoutTrailingSlash)(
    configuredBaseURL != null ? configuredBaseURL : "https://api.deepinfra.com/v1"
  );
  function getHeaders() {
    const apiKey = (0, import_provider_utils2.loadApiKey)({
      apiKey: options.apiKey,
      environmentVariableName: "DEEPINFRA_API_KEY",
      description: "DeepInfra's API key"
    });
    return (0, import_provider_utils2.withUserAgentSuffix)(
      { Authorization: `Bearer ${apiKey}`, ...options.headers },
      `ai-sdk/deepinfra/${VERSION}`
    );
  }
  // Settings shared by every model type. OpenAI-compatible requests go to
  // `${baseURL}/openai${path}`.
  const modelConfigFor = (modelType) => ({
    provider: `deepinfra.${modelType}`,
    url: ({ path }) => `${baseURL}/openai${path}`,
    headers: getHeaders,
    fetch: options.fetch
  });
  const createChatModel = (modelId) => new DeepInfraChatLanguageModel(modelId, modelConfigFor("chat"));
  const createCompletionModel = (modelId) => new import_openai_compatible2.OpenAICompatibleCompletionLanguageModel(modelId, modelConfigFor("completion"));
  const createEmbeddingModel = (modelId) => new import_openai_compatible2.OpenAICompatibleEmbeddingModel(modelId, modelConfigFor("embedding"));
  // Image models talk to the `/inference` API instead of `/openai`.
  const createImageModel = (modelId) => new DeepInfraImageModel(modelId, {
    ...modelConfigFor("image"),
    baseURL: baseURL ? `${baseURL}/inference` : "https://api.deepinfra.com/v1/inference"
  });
  // Invoking the provider directly yields a chat model.
  const provider = (modelId) => createChatModel(modelId);
  return Object.assign(provider, {
    specificationVersion: "v3",
    completionModel: createCompletionModel,
    chatModel: createChatModel,
    image: createImageModel,
    imageModel: createImageModel,
    languageModel: createChatModel,
    embeddingModel: createEmbeddingModel,
    textEmbeddingModel: createEmbeddingModel
  });
}
// Default provider instance created with no options. It uses the base URL
// "https://api.deepinfra.com/v1". The API key is resolved via loadApiKey
// (DEEPINFRA_API_KEY) when request headers are built, not at module load.
var deepinfra = createDeepInfra();
// Annotate the CommonJS export names for ESM import in node:
// `0 && ...` short-circuits, so this assignment never executes. The object
// literal only lists the export names so they can be discovered without
// running the module.
// NOTE(review): presumably consumed by Node's static CommonJS named-export
// detection — confirm.
0 && (module.exports = {
VERSION,
createDeepInfra,
deepinfra
});
//# sourceMappingURL=index.js.map