@ai-sdk/deepinfra
Version:
The **[DeepInfra provider](https://ai-sdk.dev/providers/ai-sdk-providers/deepinfra)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the DeepInfra API, giving you access to models like Llama 3, Mixtral, and other state-of-th
175 lines (157 loc) • 5.88 kB
text/typescript
import type {
LanguageModelV3CallOptions,
LanguageModelV3GenerateResult,
LanguageModelV3StreamResult,
} from '@ai-sdk/provider';
import { OpenAICompatibleChatLanguageModel } from '@ai-sdk/openai-compatible';
import type { FetchFunction } from '@ai-sdk/provider-utils';
type DeepInfraChatConfig = {
provider: string;
url: (options: { path: string; modelId?: string }) => string;
headers: () => Record<string, string | undefined>;
fetch?: FetchFunction;
};
export class DeepInfraChatLanguageModel extends OpenAICompatibleChatLanguageModel {
constructor(modelId: string, config: DeepInfraChatConfig) {
super(modelId, config);
}
/**
* Fixes incorrect token usage for Gemini/Gemma models from DeepInfra.
*
* DeepInfra's API returns completion_tokens that don't include reasoning_tokens
* for Gemini/Gemma models, which violates the OpenAI-compatible spec.
* According to the spec, completion_tokens should include reasoning_tokens.
*
* Example of incorrect data from DeepInfra:
* {
* "completion_tokens": 84, // text-only tokens
* "completion_tokens_details": {
* "reasoning_tokens": 1081 // reasoning tokens not included above
* }
* }
*
* This would result in negative text tokens: 84 - 1081 = -997
*
* The fix: If reasoning_tokens > completion_tokens, add reasoning_tokens
* to completion_tokens: 84 + 1081 = 1165
*/
private fixUsageForGeminiModels(usage: any): typeof usage {
if (!usage || !usage.completion_tokens_details?.reasoning_tokens) {
return usage;
}
const completionTokens = usage.completion_tokens ?? 0;
const reasoningTokens = usage.completion_tokens_details.reasoning_tokens;
// If reasoning tokens exceed completion tokens, the API data is incorrect
// DeepInfra is returning only text tokens in completion_tokens, not including reasoning
if (reasoningTokens > completionTokens) {
const correctedCompletionTokens = completionTokens + reasoningTokens;
return {
...usage,
// Add reasoning_tokens to completion_tokens to get the correct total
completion_tokens: correctedCompletionTokens,
// Update total_tokens if present
total_tokens:
usage.total_tokens != null
? usage.total_tokens + reasoningTokens
: undefined,
};
}
return usage;
}
async doGenerate(
options: LanguageModelV3CallOptions,
): Promise<LanguageModelV3GenerateResult> {
const result = await super.doGenerate(options);
// Fix usage if needed
if (result.usage?.raw) {
const fixedRawUsage = this.fixUsageForGeminiModels(result.usage.raw);
if (fixedRawUsage !== result.usage.raw) {
// Recalculate usage with fixed data
const promptTokens = fixedRawUsage.prompt_tokens ?? 0;
const completionTokens = fixedRawUsage.completion_tokens ?? 0;
const cacheReadTokens =
fixedRawUsage.prompt_tokens_details?.cached_tokens ?? 0;
const reasoningTokens =
fixedRawUsage.completion_tokens_details?.reasoning_tokens ?? 0;
return {
...result,
usage: {
inputTokens: {
total: promptTokens,
noCache: promptTokens - cacheReadTokens,
cacheRead: cacheReadTokens,
cacheWrite: undefined,
},
outputTokens: {
total: completionTokens,
text: completionTokens - reasoningTokens,
reasoning: reasoningTokens,
},
raw: fixedRawUsage,
},
};
}
}
return result;
}
async doStream(
options: LanguageModelV3CallOptions,
): Promise<LanguageModelV3StreamResult> {
const result = await super.doStream(options);
// Wrap the stream to fix usage in the final chunk
const originalStream = result.stream;
const fixUsage = this.fixUsageForGeminiModels.bind(this);
const transformedStream = new ReadableStream({
async start(controller) {
const reader = originalStream.getReader();
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
// Fix usage in finish chunks
if (value.type === 'finish' && value.usage?.raw) {
const fixedRawUsage = fixUsage(value.usage.raw);
if (fixedRawUsage !== value.usage.raw) {
const promptTokens = fixedRawUsage.prompt_tokens ?? 0;
const completionTokens = fixedRawUsage.completion_tokens ?? 0;
const cacheReadTokens =
fixedRawUsage.prompt_tokens_details?.cached_tokens ?? 0;
const reasoningTokens =
fixedRawUsage.completion_tokens_details?.reasoning_tokens ??
0;
controller.enqueue({
...value,
usage: {
inputTokens: {
total: promptTokens,
noCache: promptTokens - cacheReadTokens,
cacheRead: cacheReadTokens,
cacheWrite: undefined,
},
outputTokens: {
total: completionTokens,
text: completionTokens - reasoningTokens,
reasoning: reasoningTokens,
},
raw: fixedRawUsage,
},
});
} else {
controller.enqueue(value);
}
} else {
controller.enqueue(value);
}
}
controller.close();
} catch (error) {
controller.error(error);
}
},
});
return {
...result,
stream: transformedStream,
};
}
}