UNPKG

@ai-sdk/deepinfra

Version:

The **[DeepInfra provider](https://ai-sdk.dev/providers/ai-sdk-providers/deepinfra)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the DeepInfra API, giving you access to models like Llama 3, Mixtral, and other state-of-the-art models.

213 lines (210 loc) 6.47 kB
// src/deepinfra-provider.ts
import {
  OpenAICompatibleChatLanguageModel,
  OpenAICompatibleCompletionLanguageModel,
  OpenAICompatibleEmbeddingModel
} from "@ai-sdk/openai-compatible";
import {
  loadApiKey,
  withoutTrailingSlash,
  withUserAgentSuffix
} from "@ai-sdk/provider-utils";

// src/deepinfra-image-model.ts
import {
  combineHeaders,
  convertBase64ToUint8Array,
  convertToFormData,
  createJsonErrorResponseHandler,
  createJsonResponseHandler,
  downloadBlob,
  postFormDataToApi,
  postJsonToApi
} from "@ai-sdk/provider-utils";
import { z } from "zod/v4";

/**
 * Image model backed by the DeepInfra HTTP API.
 *
 * Two request shapes are supported by `doGenerate`:
 *  - generation: a JSON POST to `${config.baseURL}/${modelId}`;
 *  - editing: a multipart form POST to the URL returned by `getEditUrl()`,
 *    used whenever at least one input file is supplied.
 */
var DeepInfraImageModel = class {
  /**
   * @param modelId Model identifier; used as the URL path segment for
   *   generation and as the `model` form field for edits.
   * @param config Object read for `provider`, `baseURL`, `headers()`, `fetch`
   *   and the optional `_internal.currentDate` hook.
   */
  constructor(modelId, config) {
    this.modelId = modelId;
    this.config = config;
    this.specificationVersion = "v3";
    this.maxImagesPerCall = 1;
  }

  /** Provider name, taken verbatim from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Generates or edits images.
   *
   * @param options.prompt Text prompt forwarded to the API.
   * @param options.n Image count (`num_images` for generation, `n` for edits).
   * @param options.size Size string; for generation it is split on "x" and the
   *   two parts are sent as `width` and `height` (as strings, unparsed).
   * @param options.aspectRatio Sent as `aspect_ratio` when truthy (generation only).
   * @param options.seed Sent as `seed` when not null/undefined (generation only).
   * @param options.providerOptions Its `deepinfra` entry, when present, is spread
   *   last into the request payload and so can override the fields above.
   * @param options.headers Per-call headers combined with `config.headers()`.
   * @param options.abortSignal Forwarded to the HTTP helper.
   * @param options.files Input images; a non-empty array selects the edit endpoint.
   * @param options.mask Optional mask file for edits.
   * @returns `{ images, warnings, response }` where `images` holds base64
   *   strings, `warnings` is an array this method never populates, and
   *   `response` carries the timestamp, model id and response headers.
   */
  async doGenerate({
    prompt,
    n,
    size,
    aspectRatio,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask
  }) {
    const warnings = [];

    // Use the injected clock when the config provides one (it is invoked with
    // `_internal` as `this`); otherwise fall back to the current time.
    const internal = this.config._internal;
    const clock = internal == null ? void 0 : internal.currentDate;
    const injectedDate = clock == null ? void 0 : clock.call(internal);
    const currentDate = injectedDate != null ? injectedDate : /* @__PURE__ */ new Date();

    const deepinfraOptions =
      providerOptions.deepinfra != null ? providerOptions.deepinfra : {};

    // --- Edit path: any input file routes the call to the edit endpoint. ---
    if (files != null && files.length > 0) {
      const { value: editResponse, responseHeaders: editResponseHeaders } =
        await postFormDataToApi({
          url: this.getEditUrl(),
          headers: combineHeaders(this.config.headers(), headers),
          formData: convertToFormData(
            {
              model: this.modelId,
              prompt,
              image: await Promise.all(files.map((file) => fileToBlob(file))),
              mask: mask != null ? await fileToBlob(mask) : void 0,
              n,
              size,
              ...deepinfraOptions
            },
            { useArrayBrackets: false }
          ),
          failedResponseHandler: createJsonErrorResponseHandler({
            errorSchema: deepInfraEditErrorSchema,
            // The `error` object is optional in the edit error schema, so a
            // generic message is used when it is absent.
            errorToMessage: (error) => {
              const message = error.error == null ? void 0 : error.error.message;
              return message != null ? message : "Unknown error";
            }
          }),
          successfulResponseHandler: createJsonResponseHandler(
            deepInfraEditResponseSchema
          ),
          abortSignal,
          fetch: this.config.fetch
        });

      return {
        images: editResponse.data.map((item) => item.b64_json),
        warnings,
        response: {
          timestamp: currentDate,
          modelId: this.modelId,
          headers: editResponseHeaders
        }
      };
    }

    // --- Generation path: JSON request against the model endpoint. ---
    const splitSize = size == null ? void 0 : size.split("x");
    const { value: response, responseHeaders } = await postJsonToApi({
      url: `${this.config.baseURL}/${this.modelId}`,
      headers: combineHeaders(this.config.headers(), headers),
      body: {
        prompt,
        num_images: n,
        ...aspectRatio && { aspect_ratio: aspectRatio },
        ...splitSize && { width: splitSize[0], height: splitSize[1] },
        ...seed != null && { seed },
        ...deepinfraOptions
      },
      failedResponseHandler: createJsonErrorResponseHandler({
        errorSchema: deepInfraErrorSchema,
        errorToMessage: (error) => error.detail.error
      }),
      successfulResponseHandler: createJsonResponseHandler(
        deepInfraImageResponseSchema
      ),
      abortSignal,
      fetch: this.config.fetch
    });

    return {
      // Strip a leading `data:image/<type>;base64,` prefix so that only the
      // base64 payload is returned.
      images: response.images.map(
        (image) => image.replace(/^data:image\/\w+;base64,/, "")
      ),
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders
      }
    };
  }

  /**
   * Builds the image-edit endpoint URL by swapping the `/inference` segment of
   * the configured base URL for `/openai` and appending `/images/edits`.
   *
   * The trailing `/inference` segment is preferred. A plain string pattern
   * passed to `String.prototype.replace` rewrites only the FIRST occurrence,
   * which corrupts base URLs that contain "/inference" earlier in the path or
   * host (e.g. `https://proxy.example.com/inference-gateway/v1/inference`).
   * When the base URL does not end in `/inference`, the previous
   * first-occurrence behavior is kept for backward compatibility.
   */
  getEditUrl() {
    const configuredBaseUrl = this.config.baseURL;
    const trailingInferenceSegment = /\/inference$/;
    const baseUrl = trailingInferenceSegment.test(configuredBaseUrl)
      ? configuredBaseUrl.replace(trailingInferenceSegment, "/openai")
      : configuredBaseUrl.replace("/inference", "/openai");
    return `${baseUrl}/images/edits`;
  }
};

/** Error body expected from the generation endpoint: `{ detail: { error } }`. */
var deepInfraErrorSchema = z.object({
  detail: z.object({
    error: z.string()
  })
});

/** Success body expected from the generation endpoint: `{ images: string[] }`. */
var deepInfraImageResponseSchema = z.object({
  images: z.array(z.string())
});

/** Error body expected from the edit endpoint; the `error` object is optional. */
var deepInfraEditErrorSchema = z.object({
  error: z.object({
    message: z.string()
  }).optional()
});

/** Success body expected from the edit endpoint: `{ data: [{ b64_json }] }`. */
var deepInfraEditResponseSchema = z.object({
  data: z.array(z.object({ b64_json: z.string() }))
});

/**
 * Converts an input file descriptor into a Blob for multipart upload.
 *
 * @param file Descriptor with `type`; when `type === "url"` its `url` is passed
 *   to `downloadBlob`, otherwise `data` (a Uint8Array, or a value handed to
 *   `convertBase64ToUint8Array`) and `mediaType` are used.
 * @returns The result of `downloadBlob(file.url)` for URL files, otherwise a
 *   Blob built from the file bytes and typed with `file.mediaType`.
 */
async function fileToBlob(file) {
  if (file.type === "url") {
    return downloadBlob(file.url);
  }
  const data =
    file.data instanceof Uint8Array
      ? file.data
      : convertBase64ToUint8Array(file.data);
  return new Blob([data], { type: file.mediaType });
}

// src/version.ts
// The condition is constant, so this always evaluates to "2.0.19".
var VERSION = true ? "2.0.19" : "0.0.0-test";

// src/deepinfra-provider.ts
/**
 * Creates a DeepInfra provider.
 *
 * @param options.baseURL API root; defaults to "https://api.deepinfra.com/v1"
 *   and is passed through `withoutTrailingSlash`.
 * @param options.apiKey API key handed to `loadApiKey`, together with the
 *   environment variable name `DEEPINFRA_API_KEY`.
 * @param options.headers Extra headers spread after `Authorization`, so they
 *   can override it.
 * @param options.fetch Custom fetch implementation forwarded to every model.
 * @returns A callable provider: invoking it with a model id returns a chat
 *   model, and it also exposes factories for completion, chat, image and
 *   embedding models.
 */
function createDeepInfra(options = {}) {
  const baseURL = withoutTrailingSlash(
    options.baseURL != null ? options.baseURL : "https://api.deepinfra.com/v1"
  );

  // Headers are built on each call rather than once, so `loadApiKey` runs
  // every time a model requests its headers.
  const getHeaders = () => withUserAgentSuffix(
    {
      Authorization: `Bearer ${loadApiKey({
        apiKey: options.apiKey,
        environmentVariableName: "DEEPINFRA_API_KEY",
        description: "DeepInfra's API key"
      })}`,
      ...options.headers
    },
    `ai-sdk/deepinfra/${VERSION}`
  );

  // Shared configuration for all models; request paths are resolved under the
  // `/openai` prefix of the base URL.
  const getCommonModelConfig = (modelType) => ({
    provider: `deepinfra.${modelType}`,
    url: ({ path }) => `${baseURL}/openai${path}`,
    headers: getHeaders,
    fetch: options.fetch
  });

  const createChatModel = (modelId) => {
    return new OpenAICompatibleChatLanguageModel(
      modelId,
      getCommonModelConfig("chat")
    );
  };

  const createCompletionModel = (modelId) => new OpenAICompatibleCompletionLanguageModel(
    modelId,
    getCommonModelConfig("completion")
  );

  const createEmbeddingModel = (modelId) => new OpenAICompatibleEmbeddingModel(
    modelId,
    getCommonModelConfig("embedding")
  );

  // The image model receives its own base URL under `/inference` in addition
  // to the common configuration.
  const createImageModel = (modelId) => new DeepInfraImageModel(modelId, {
    ...getCommonModelConfig("image"),
    baseURL: baseURL ? `${baseURL}/inference` : "https://api.deepinfra.com/v1/inference"
  });

  const provider = (modelId) => createChatModel(modelId);
  provider.specificationVersion = "v3";
  provider.completionModel = createCompletionModel;
  provider.chatModel = createChatModel;
  provider.image = createImageModel;
  provider.imageModel = createImageModel;
  provider.languageModel = createChatModel;
  provider.embeddingModel = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;
  return provider;
}

/** Default provider instance created with no options. */
var deepinfra = createDeepInfra();

export {
  VERSION,
  createDeepInfra,
  deepinfra
};
//# sourceMappingURL=index.mjs.map