@ai-sdk/deepinfra
Version:
The **[DeepInfra provider](https://ai-sdk.dev/providers/ai-sdk-providers/deepinfra)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the DeepInfra API, giving you access to models like Llama 3, Mixtral, and other state-of-the-art models.
162 lines (141 loc) • 4.44 kB
text/typescript
import {
LanguageModelV3,
EmbeddingModelV3,
ProviderV3,
ImageModelV3,
} from '@ai-sdk/provider';
import {
OpenAICompatibleChatLanguageModel,
OpenAICompatibleCompletionLanguageModel,
OpenAICompatibleEmbeddingModel,
} from '@ai-sdk/openai-compatible';
import {
FetchFunction,
loadApiKey,
withoutTrailingSlash,
withUserAgentSuffix,
} from '@ai-sdk/provider-utils';
import { DeepInfraChatModelId } from './deepinfra-chat-options';
import { DeepInfraEmbeddingModelId } from './deepinfra-embedding-options';
import { DeepInfraCompletionModelId } from './deepinfra-completion-options';
import { DeepInfraImageModelId } from './deepinfra-image-settings';
import { DeepInfraImageModel } from './deepinfra-image-model';
import { VERSION } from './version';
/**
 * Options accepted by `createDeepInfra`. Every field is optional.
 */
export interface DeepInfraProviderSettings {
  /**
   * Custom `fetch` implementation. Can act as middleware that intercepts
   * requests, or replace the network layer entirely (e.g. in tests).
   */
  fetch?: FetchFunction;

  /**
   * Extra headers added to every request.
   */
  headers?: Record<string, string>;

  /**
   * Root URL for API calls. When omitted, `createDeepInfra` uses
   * `https://api.deepinfra.com/v1`; a trailing slash is removed.
   */
  baseURL?: string;

  /**
   * DeepInfra API key. When omitted, `createDeepInfra` resolves it via
   * `loadApiKey` using the `DEEPINFRA_API_KEY` environment variable.
   */
  apiKey?: string;
}
/**
 * A DeepInfra provider instance. Calling it directly creates a chat model.
 * It also exposes one factory method per supported model type.
 */
export interface DeepInfraProvider extends ProviderV3 {
  /**
   * Creates a chat model for text generation (same as `chatModel`).
   */
  (modelId: DeepInfraChatModelId): LanguageModelV3;

  /**
   * Creates a chat model for text generation.
   */
  chatModel(modelId: DeepInfraChatModelId): LanguageModelV3;

  /**
   * Creates a model for image generation (alias of `imageModel`).
   */
  image(modelId: DeepInfraImageModelId): ImageModelV3;

  /**
   * Creates a model for image generation.
   */
  imageModel(modelId: DeepInfraImageModelId): ImageModelV3;

  /**
   * Creates a chat model for text generation (same as `chatModel`).
   */
  languageModel(modelId: DeepInfraChatModelId): LanguageModelV3;

  /**
   * Creates a completion model for text generation.
   */
  completionModel(modelId: DeepInfraCompletionModelId): LanguageModelV3;

  /**
   * Creates an embedding model for computing text embeddings.
   */
  embeddingModel(modelId: DeepInfraEmbeddingModelId): EmbeddingModelV3;

  /**
   * Creates an embedding model for computing text embeddings.
   *
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: DeepInfraEmbeddingModelId): EmbeddingModelV3;
}
/**
 * Default root URL for DeepInfra API calls, used when
 * `DeepInfraProviderSettings.baseURL` is not supplied.
 */
const DEFAULT_DEEPINFRA_BASE_URL = 'https://api.deepinfra.com/v1';

/**
 * Creates a DeepInfra provider instance.
 *
 * Chat, completion and embedding models are built on the OpenAI-compatible
 * model classes and address `<baseURL>/openai<path>`. Image models also
 * receive `<baseURL>/inference` as their base URL.
 *
 * @param options - Provider settings; every field is optional.
 * @returns A callable provider. Calling it directly creates a chat model.
 */
export function createDeepInfra(
  options: DeepInfraProviderSettings = {},
): DeepInfraProvider {
  const baseURL = withoutTrailingSlash(
    options.baseURL ?? DEFAULT_DEEPINFRA_BASE_URL,
  );

  // The API key is resolved inside this function, i.e. each time a model asks
  // for headers, rather than once when the provider is created.
  const getHeaders = () =>
    withUserAgentSuffix(
      {
        Authorization: `Bearer ${loadApiKey({
          apiKey: options.apiKey,
          environmentVariableName: 'DEEPINFRA_API_KEY',
          description: "DeepInfra's API key",
        })}`,
        // Spread last so caller-supplied headers (including Authorization)
        // take precedence over the generated ones.
        ...options.headers,
      },
      `ai-sdk/deepinfra/${VERSION}`,
    );

  interface CommonModelConfig {
    provider: string;
    url: ({ path }: { path: string }) => string;
    headers: () => Record<string, string>;
    fetch?: FetchFunction;
  }

  /** Shared configuration for every model type; `modelType` tags the provider id. */
  const getCommonModelConfig = (modelType: string): CommonModelConfig => ({
    provider: `deepinfra.${modelType}`,
    url: ({ path }) => `${baseURL}/openai${path}`,
    headers: getHeaders,
    fetch: options.fetch,
  });

  const createChatModel = (modelId: DeepInfraChatModelId) =>
    new OpenAICompatibleChatLanguageModel(
      modelId,
      getCommonModelConfig('chat'),
    );

  const createCompletionModel = (modelId: DeepInfraCompletionModelId) =>
    new OpenAICompatibleCompletionLanguageModel(
      modelId,
      getCommonModelConfig('completion'),
    );

  const createEmbeddingModel = (modelId: DeepInfraEmbeddingModelId) =>
    new OpenAICompatibleEmbeddingModel(
      modelId,
      getCommonModelConfig('embedding'),
    );

  const createImageModel = (modelId: DeepInfraImageModelId) =>
    new DeepInfraImageModel(modelId, {
      ...getCommonModelConfig('image'),
      // An explicit empty-string baseURL passes through `??` above; in that
      // case the image model falls back to the default inference route.
      baseURL: baseURL
        ? `${baseURL}/inference`
        : `${DEFAULT_DEEPINFRA_BASE_URL}/inference`,
    });

  // The provider is a function (chat model factory) with the remaining
  // factories attached as properties.
  const provider = (modelId: DeepInfraChatModelId) => createChatModel(modelId);

  provider.specificationVersion = 'v3' as const;
  provider.completionModel = createCompletionModel;
  provider.chatModel = createChatModel;
  provider.image = createImageModel;
  provider.imageModel = createImageModel;
  provider.languageModel = createChatModel;
  provider.embeddingModel = createEmbeddingModel;
  // Deprecated alias kept for backward compatibility.
  provider.textEmbeddingModel = createEmbeddingModel;

  return provider;
}
/**
 * Default DeepInfra provider instance, created with no explicit settings
 * (default base URL; API key looked up via the `DEEPINFRA_API_KEY`
 * environment variable).
 */
export const deepinfra: DeepInfraProvider = createDeepInfra({});