UNPKG

@ai-sdk/mistral

Version:

The **[Mistral provider](https://ai-sdk.dev/providers/ai-sdk-providers/mistral)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the Mistral chat API.

818 lines (807 loc) 24.2 kB
// src/mistral-provider.ts
import { NoSuchModelError } from "@ai-sdk/provider";
import { loadApiKey, withoutTrailingSlash } from "@ai-sdk/provider-utils";

// src/mistral-chat-language-model.ts
import {
  combineHeaders,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  generateId,
  injectJsonInstructionIntoMessages,
  parseProviderOptions,
  postJsonToApi
} from "@ai-sdk/provider-utils";
import { z as z3 } from "zod/v4";

// src/convert-to-mistral-chat-messages.ts
import { UnsupportedFunctionalityError } from "@ai-sdk/provider";
import { convertToBase64 } from "@ai-sdk/provider-utils";

// Matches a leading URI scheme ("https:", "data:", ...). Base64 text never
// contains ":", so this separates URL-like strings from base64 payloads.
const URL_SCHEME_PATTERN = /^[a-z][a-z\d+.-]*:/i;

/**
 * Builds the `document_url` value for a PDF file part.
 *
 * `URL` objects and strings that already carry a URI scheme are passed through
 * as text. Anything else (binary data or a bare base64 string) is wrapped in a
 * `data:application/pdf;base64,...` URL instead of being stringified, because
 * `Uint8Array.prototype.toString()` produces a comma-separated byte list.
 *
 * NOTE(review): assumes the API accepts `data:` URLs for `document_url` the
 * same way this file already sends them for `image_url` - confirm.
 *
 * @param {URL | string | Uint8Array} data - the file part payload.
 * @returns {string} the value to send as `document_url`.
 */
function toPdfDocumentUrl(data) {
  if (data instanceof URL) {
    return data.toString();
  }
  if (typeof data === "string" && URL_SCHEME_PATTERN.test(data)) {
    return data;
  }
  return `data:application/pdf;base64,${convertToBase64(data)}`;
}

/**
 * Converts one part of a user message into a Mistral content chunk.
 *
 * @param {object} part - a prompt part with a `type` of "text" or "file".
 * @returns {object | undefined} the Mistral chunk; `undefined` for part types
 *   other than "text" and "file" (matching the original `switch` without a
 *   default branch).
 * @throws {UnsupportedFunctionalityError} for file parts that are neither
 *   images nor PDFs.
 */
function convertUserContentPart(part) {
  switch (part.type) {
    case "text": {
      return { type: "text", text: part.text };
    }
    case "file": {
      if (part.mediaType.startsWith("image/")) {
        // A wildcard image type cannot be used in a data URL; fall back to JPEG.
        const mediaType = part.mediaType === "image/*" ? "image/jpeg" : part.mediaType;
        return {
          type: "image_url",
          image_url: part.data instanceof URL
            ? part.data.toString()
            : `data:${mediaType};base64,${convertToBase64(part.data)}`
        };
      }
      if (part.mediaType === "application/pdf") {
        return {
          type: "document_url",
          document_url: toPdfDocumentUrl(part.data)
        };
      }
      throw new UnsupportedFunctionalityError({
        functionality: "Only images and PDF file parts are supported"
      });
    }
    default: {
      return void 0;
    }
  }
}

/**
 * Converts an assistant message into a single Mistral assistant message.
 * Text and reasoning parts are concatenated into `content`; tool-call parts
 * become `tool_calls` with JSON-stringified arguments.
 *
 * @param {object[]} content - the parts of the assistant message.
 * @param {boolean} isLastMessage - whether this is the final prompt message;
 *   when true the message is sent with `prefix: true`.
 * @returns {object} the Mistral assistant message.
 * @throws {Error} for part types other than text, tool-call and reasoning.
 */
function convertAssistantMessage(content, isLastMessage) {
  let text = "";
  const toolCalls = [];
  for (const part of content) {
    switch (part.type) {
      case "text":
      case "reasoning": {
        text += part.text;
        break;
      }
      case "tool-call": {
        toolCalls.push({
          id: part.toolCallId,
          type: "function",
          function: {
            name: part.toolName,
            arguments: JSON.stringify(part.input)
          }
        });
        break;
      }
      default: {
        throw new Error(
          `Unsupported content type in assistant message: ${part.type}`
        );
      }
    }
  }
  return {
    role: "assistant",
    content: text,
    prefix: isLastMessage ? true : void 0,
    tool_calls: toolCalls.length > 0 ? toolCalls : void 0
  };
}

/**
 * Converts one tool result into a Mistral `tool` message. Textual outputs are
 * sent as-is, structured outputs are JSON-stringified; any other output type
 * leaves `content` undefined.
 *
 * @param {object} toolResponse - a tool result part with `toolName`,
 *   `toolCallId` and `output`.
 * @returns {object} the Mistral tool message.
 */
function convertToolResult(toolResponse) {
  const output = toolResponse.output;
  let contentValue;
  switch (output.type) {
    case "text":
    case "error-text":
      contentValue = output.value;
      break;
    case "content":
    case "json":
    case "error-json":
      contentValue = JSON.stringify(output.value);
      break;
  }
  return {
    role: "tool",
    name: toolResponse.toolName,
    tool_call_id: toolResponse.toolCallId,
    content: contentValue
  };
}

/**
 * Converts an AI SDK prompt (array of `{ role, content }` messages) into the
 * message list sent to the Mistral chat API.
 *
 * @param {object[]} prompt - the prompt messages.
 * @returns {object[]} the Mistral chat messages, in prompt order; a tool
 *   message with several results expands to one Mistral message per result.
 * @throws {UnsupportedFunctionalityError} for unsupported file parts.
 * @throws {Error} for unsupported roles or assistant part types.
 */
function convertToMistralChatMessages(prompt) {
  const messages = [];
  const lastIndex = prompt.length - 1;
  prompt.forEach(({ role, content }, index) => {
    switch (role) {
      case "system": {
        messages.push({ role: "system", content });
        break;
      }
      case "user": {
        messages.push({
          role: "user",
          content: content.map((part) => convertUserContentPart(part))
        });
        break;
      }
      case "assistant": {
        messages.push(convertAssistantMessage(content, index === lastIndex));
        break;
      }
      case "tool": {
        for (const toolResponse of content) {
          messages.push(convertToolResult(toolResponse));
        }
        break;
      }
      default: {
        const _exhaustiveCheck = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  });
  return messages;
}

// src/get-response-metadata.ts
/**
 * Maps the `id`, `model` and `created` fields of a response (or chunk) to the
 * response metadata shape. `created` is multiplied by 1000 before being passed
 * to `Date`, i.e. it is treated as seconds. Null fields become `undefined`.
 *
 * @param {{ id?: string | null, model?: string | null, created?: number | null }} response
 * @returns {{ id: string | undefined, modelId: string | undefined, timestamp: Date | undefined }}
 */
function getResponseMetadata({ id, model, created }) {
  return {
    id: id != null ? id : void 0,
    modelId: model != null ? model : void 0,
    timestamp: created != null ? new Date(created * 1e3) : void 0
  };
}

// src/map-mistral-finish-reason.ts
/**
 * Maps a Mistral `finish_reason` string to the finish reason reported by this
 * provider. Unrecognized (or missing) values map to "unknown".
 *
 * @param {string | null | undefined} finishReason
 * @returns {"stop" | "length" | "tool-calls" | "unknown"}
 */
function mapMistralFinishReason(finishReason) {
  switch (finishReason) {
    case "stop":
      return "stop";
    case "length":
    case "model_length":
      return "length";
    case "tool_calls":
      return "tool-calls";
    default:
      return "unknown";
  }
}

// src/mistral-chat-options.ts
import { z } from "zod/v4";

/** Schema of the `providerOptions.mistral` object accepted by the chat model. */
const mistralLanguageModelOptions = z.object({
  /**
  Whether to inject a safety prompt before all conversations.

  Defaults to `false`.
     */
  safePrompt: z.boolean().optional(),
  /** Forwarded as `document_image_limit` in the request body. */
  documentImageLimit: z.number().optional(),
  /** Forwarded as `document_page_limit` in the request body. */
  documentPageLimit: z.number().optional(),
  /**
   * Whether to use structured outputs.
   *
   * @default true
   */
  structuredOutputs: z.boolean().optional(),
  /**
   * Whether to use strict JSON schema validation.
   *
   * @default false
   */
  strictJsonSchema: z.boolean().optional()
});

// src/mistral-error.ts
import { createJsonErrorResponseHandler } from "@ai-sdk/provider-utils";
import { z as z2 } from "zod/v4";

/** Shape of the JSON body expected on failed Mistral API responses. */
const mistralErrorDataSchema = z2.object({
  object: z2.literal("error"),
  message: z2.string(),
  type: z2.string(),
  param: z2.string().nullable(),
  code: z2.string().nullable()
});

/** Failed-response handler that uses the error body's `message` as the error message. */
const mistralFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: mistralErrorDataSchema,
  errorToMessage: (data) => data.message
});

// src/mistral-prepare-tools.ts
import { UnsupportedFunctionalityError as UnsupportedFunctionalityError2 } from "@ai-sdk/provider";

/**
 * Translates AI SDK tool definitions and tool choice into the Mistral request
 * fields.
 *
 * - An empty or missing tool list yields no tools and no tool choice.
 * - Provider-defined tools are skipped and reported as "unsupported-tool"
 *   warnings.
 * - "required" maps to "any". A specific "tool" choice is expressed by sending
 *   only that tool together with "any".
 *
 * @param {{ tools?: object[], toolChoice?: object }} params
 * @returns {{ tools: object[] | undefined, toolChoice: string | undefined, toolWarnings: object[] }}
 * @throws {UnsupportedFunctionalityError} for unknown tool choice types.
 */
function prepareTools({ tools, toolChoice }) {
  const toolWarnings = [];
  const hasTools = tools != null && Boolean(tools.length);
  if (!hasTools) {
    return { tools: void 0, toolChoice: void 0, toolWarnings };
  }
  const mistralTools = [];
  for (const tool of tools) {
    if (tool.type === "provider-defined") {
      toolWarnings.push({ type: "unsupported-tool", tool });
      continue;
    }
    mistralTools.push({
      type: "function",
      function: {
        name: tool.name,
        description: tool.description,
        parameters: tool.inputSchema
      }
    });
  }
  if (toolChoice == null) {
    return { tools: mistralTools, toolChoice: void 0, toolWarnings };
  }
  const type = toolChoice.type;
  switch (type) {
    case "auto":
    case "none":
      return { tools: mistralTools, toolChoice: type, toolWarnings };
    case "required":
      return { tools: mistralTools, toolChoice: "any", toolWarnings };
    case "tool":
      return {
        tools: mistralTools.filter(
          (tool) => tool.function.name === toolChoice.toolName
        ),
        toolChoice: "any",
        toolWarnings
      };
    default: {
      const _exhaustiveCheck = type;
      throw new UnsupportedFunctionalityError2({
        functionality: `tool choice type: ${_exhaustiveCheck}`
      });
    }
  }
}

// src/mistral-chat-language-model.ts
/**
 * Language model (specification version "v2") backed by the
 * `${baseURL}/chat/completions` endpoint.
 */
const MistralChatLanguageModel = class {
  /**
   * @param {string} modelId - model identifier sent as `model`.
   * @param {object} config - `provider`, `baseURL`, `headers()` and optional
   *   `fetch` / `generateId` overrides.
   */
  constructor(modelId, config) {
    this.specificationVersion = "v2";
    // Only https PDF URLs are declared as natively supported.
    this.supportedUrls = {
      "application/pdf": [/^https:\/\/.*$/]
    };
    this.modelId = modelId;
    this.config = config;
    this.generateId = config.generateId != null ? config.generateId : generateId;
  }

  /** Provider name taken from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Builds the request body and collects warnings for unsupported settings
   * (`topK`, `frequencyPenalty`, `presencePenalty`, `stopSequences`) and
   * unsupported tools.
   *
   * @param {object} callOptions - the call options of `doGenerate` / `doStream`.
   * @returns {Promise<{ args: object, warnings: object[] }>}
   */
  async getArgs({
    prompt,
    maxOutputTokens,
    temperature,
    topP,
    topK,
    frequencyPenalty,
    presencePenalty,
    stopSequences,
    responseFormat,
    seed,
    providerOptions,
    tools,
    toolChoice
  }) {
    const warnings = [];
    const parsedOptions = await parseProviderOptions({
      provider: "mistral",
      providerOptions,
      schema: mistralLanguageModelOptions
    });
    const options = parsedOptions != null ? parsedOptions : {};

    const unsupportedSettings = [
      ["topK", topK],
      ["frequencyPenalty", frequencyPenalty],
      ["presencePenalty", presencePenalty],
      ["stopSequences", stopSequences]
    ];
    for (const [setting, value] of unsupportedSettings) {
      if (value != null) {
        warnings.push({ type: "unsupported-setting", setting });
      }
    }

    const structuredOutputs = options.structuredOutputs != null ? options.structuredOutputs : true;
    const strictJsonSchema = options.strictJsonSchema != null ? options.strictJsonSchema : false;
    const isJsonResponse = responseFormat != null && responseFormat.type === "json";
    const jsonSchema = isJsonResponse ? responseFormat.schema : void 0;

    // Inject a JSON instruction whenever the request will not carry the schema
    // itself: either no schema was given, or structured outputs are disabled
    // and the request falls back to `json_object` mode. Previously the second
    // case sent neither an instruction nor the schema.
    if (isJsonResponse && (!jsonSchema || !structuredOutputs)) {
      prompt = injectJsonInstructionIntoMessages({
        messages: prompt,
        schema: jsonSchema
      });
    }

    let responseFormatArg;
    if (isJsonResponse) {
      responseFormatArg = structuredOutputs && jsonSchema != null
        ? {
            type: "json_schema",
            json_schema: {
              schema: jsonSchema,
              strict: strictJsonSchema,
              name: responseFormat.name != null ? responseFormat.name : "response",
              description: responseFormat.description
            }
          }
        : { type: "json_object" };
    }

    const baseArgs = {
      // model id:
      model: this.modelId,
      // model specific settings:
      safe_prompt: options.safePrompt,
      // standardized settings:
      max_tokens: maxOutputTokens,
      temperature,
      top_p: topP,
      random_seed: seed,
      // response format:
      response_format: responseFormatArg,
      // mistral-specific provider options:
      document_image_limit: options.documentImageLimit,
      document_page_limit: options.documentPageLimit,
      // messages:
      messages: convertToMistralChatMessages(prompt)
    };
    const {
      tools: mistralTools,
      toolChoice: mistralToolChoice,
      toolWarnings
    } = prepareTools({ tools, toolChoice });
    return {
      args: {
        ...baseArgs,
        tools: mistralTools,
        tool_choice: mistralToolChoice
      },
      warnings: [...warnings, ...toolWarnings]
    };
  }

  /**
   * Performs a non-streaming chat completion and converts the first choice
   * into content parts (reasoning, text, tool calls).
   *
   * @param {object} options - call options; `headers` and `abortSignal` are
   *   forwarded to the HTTP request.
   * @returns {Promise<object>} content, finish reason, usage, request/response
   *   details and warnings.
   */
  async doGenerate(options) {
    const { args: body, warnings } = await this.getArgs(options);
    const {
      responseHeaders,
      value: response,
      rawValue: rawResponse
    } = await postJsonToApi({
      url: `${this.config.baseURL}/chat/completions`,
      headers: combineHeaders(this.config.headers(), options.headers),
      body,
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        mistralChatResponseSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const choice = response.choices[0];
    const message = choice.message;
    const content = [];

    if (Array.isArray(message.content)) {
      // Chunked content: keep thinking and text chunks as separate parts, in
      // order, and drop empty ones.
      for (const part of message.content) {
        if (part.type === "thinking") {
          const reasoningText = extractReasoningContent(part.thinking);
          if (reasoningText.length > 0) {
            content.push({ type: "reasoning", text: reasoningText });
          }
        } else if (part.type === "text") {
          if (part.text.length > 0) {
            content.push({ type: "text", text: part.text });
          }
        }
      }
    } else {
      const text = extractTextContent(message.content);
      if (text != null && text.length > 0) {
        content.push({ type: "text", text });
      }
    }

    if (message.tool_calls != null) {
      for (const toolCall of message.tool_calls) {
        content.push({
          type: "tool-call",
          toolCallId: toolCall.id,
          toolName: toolCall.function.name,
          input: toolCall.function.arguments
        });
      }
    }

    return {
      content,
      finishReason: mapMistralFinishReason(choice.finish_reason),
      usage: {
        inputTokens: response.usage.prompt_tokens,
        outputTokens: response.usage.completion_tokens,
        totalTokens: response.usage.total_tokens
      },
      request: { body },
      response: {
        ...getResponseMetadata(response),
        headers: responseHeaders,
        body: rawResponse
      },
      warnings
    };
  }

  /**
   * Performs a streaming chat completion (`stream: true`) and transforms the
   * event stream into stream parts: `stream-start`, optional `raw`,
   * `response-metadata` (first chunk), reasoning/text start-delta-end
   * sequences, tool input/call parts, and a final `finish`.
   *
   * @param {object} options - call options; `includeRawChunks` additionally
   *   emits each raw chunk.
   * @returns {Promise<object>} the transformed stream plus request/response
   *   details.
   */
  async doStream(options) {
    const { args, warnings } = await this.getArgs(options);
    const body = { ...args, stream: true };
    const { responseHeaders, value: response } = await postJsonToApi({
      url: `${this.config.baseURL}/chat/completions`,
      headers: combineHeaders(this.config.headers(), options.headers),
      body,
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        mistralChatChunkSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });

    let finishReason = "unknown";
    const usage = {
      inputTokens: void 0,
      outputTokens: void 0,
      totalTokens: void 0
    };
    let isFirstChunk = true;
    // Text always uses the fixed id "0"; reasoning blocks get a generated id.
    let activeText = false;
    let activeReasoningId = null;
    const generateId2 = this.generateId;

    return {
      stream: response.pipeThrough(
        new TransformStream({
          start(controller) {
            controller.enqueue({ type: "stream-start", warnings });
          },
          transform(chunk, controller) {
            if (options.includeRawChunks) {
              controller.enqueue({ type: "raw", rawValue: chunk.rawValue });
            }
            if (!chunk.success) {
              controller.enqueue({ type: "error", error: chunk.error });
              return;
            }
            const value = chunk.value;
            if (isFirstChunk) {
              isFirstChunk = false;
              controller.enqueue({
                type: "response-metadata",
                ...getResponseMetadata(value)
              });
            }
            if (value.usage != null) {
              usage.inputTokens = value.usage.prompt_tokens;
              usage.outputTokens = value.usage.completion_tokens;
              usage.totalTokens = value.usage.total_tokens;
            }

            // The chunk schema permits an empty `choices` array; such a chunk
            // carries nothing beyond the usage recorded above.
            const choice = value.choices[0];
            if (choice == null) {
              return;
            }
            const delta = choice.delta;
            const textContent = extractTextContent(delta.content);

            if (Array.isArray(delta.content)) {
              for (const part of delta.content) {
                if (part.type !== "thinking") {
                  continue;
                }
                const reasoningDelta = extractReasoningContent(part.thinking);
                if (reasoningDelta.length === 0) {
                  continue;
                }
                if (activeReasoningId == null) {
                  // Close an open text block before starting reasoning.
                  if (activeText) {
                    controller.enqueue({ type: "text-end", id: "0" });
                    activeText = false;
                  }
                  activeReasoningId = generateId2();
                  controller.enqueue({
                    type: "reasoning-start",
                    id: activeReasoningId
                  });
                }
                controller.enqueue({
                  type: "reasoning-delta",
                  id: activeReasoningId,
                  delta: reasoningDelta
                });
              }
            }

            if (textContent != null && textContent.length > 0) {
              if (!activeText) {
                // Close an open reasoning block before starting text.
                if (activeReasoningId != null) {
                  controller.enqueue({
                    type: "reasoning-end",
                    id: activeReasoningId
                  });
                  activeReasoningId = null;
                }
                controller.enqueue({ type: "text-start", id: "0" });
                activeText = true;
              }
              controller.enqueue({
                type: "text-delta",
                id: "0",
                delta: textContent
              });
            }

            if (delta.tool_calls != null) {
              // Each streamed tool call is emitted as a complete
              // start/delta/end sequence followed by the tool call itself.
              for (const toolCall of delta.tool_calls) {
                const toolCallId = toolCall.id;
                const toolName = toolCall.function.name;
                const input = toolCall.function.arguments;
                controller.enqueue({
                  type: "tool-input-start",
                  id: toolCallId,
                  toolName
                });
                controller.enqueue({
                  type: "tool-input-delta",
                  id: toolCallId,
                  delta: input
                });
                controller.enqueue({
                  type: "tool-input-end",
                  id: toolCallId
                });
                controller.enqueue({
                  type: "tool-call",
                  toolCallId,
                  toolName,
                  input
                });
              }
            }

            if (choice.finish_reason != null) {
              finishReason = mapMistralFinishReason(choice.finish_reason);
            }
          },
          flush(controller) {
            if (activeReasoningId != null) {
              controller.enqueue({
                type: "reasoning-end",
                id: activeReasoningId
              });
            }
            if (activeText) {
              controller.enqueue({ type: "text-end", id: "0" });
            }
            controller.enqueue({
              type: "finish",
              finishReason,
              usage
            });
          }
        })
      ),
      request: { body },
      response: { headers: responseHeaders }
    };
  }
};

/**
 * Joins the text of all "text" chunks inside a `thinking` array.
 *
 * @param {{ type: string, text: string }[]} thinking
 * @returns {string} the concatenated text ("" when there is none).
 */
function extractReasoningContent(thinking) {
  return thinking
    .filter((chunk) => chunk.type === "text")
    .map((chunk) => chunk.text)
    .join("");
}

/**
 * Extracts the plain text from a message `content` value.
 *
 * @param {string | object[] | null | undefined} content - a string, an array
 *   of content chunks, or nothing.
 * @returns {string | undefined} the string itself, the joined text of all
 *   "text" chunks, or `undefined` when content is nullish or has no text
 *   chunks. "thinking", "image_url" and "reference" chunks are ignored.
 * @throws {Error} for chunk types outside the four known ones.
 */
function extractTextContent(content) {
  if (typeof content === "string") {
    return content;
  }
  if (content == null) {
    return void 0;
  }
  const textContent = [];
  for (const chunk of content) {
    const { type } = chunk;
    switch (type) {
      case "text":
        textContent.push(chunk.text);
        break;
      case "thinking":
      case "image_url":
      case "reference":
        break;
      default: {
        const _exhaustiveCheck = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }
  return textContent.length ? textContent.join("") : void 0;
}

/** Message content: a plain string or an array of typed chunks (nullish allowed). */
const mistralContentSchema = z3.union([
  z3.string(),
  z3.array(
    z3.discriminatedUnion("type", [
      z3.object({
        type: z3.literal("text"),
        text: z3.string()
      }),
      z3.object({
        type: z3.literal("image_url"),
        image_url: z3.union([
          z3.string(),
          z3.object({
            url: z3.string(),
            detail: z3.string().nullable()
          })
        ])
      }),
      z3.object({
        type: z3.literal("reference"),
        reference_ids: z3.array(z3.number())
      }),
      z3.object({
        type: z3.literal("thinking"),
        thinking: z3.array(
          z3.object({
            type: z3.literal("text"),
            text: z3.string()
          })
        )
      })
    ])
  )
]).nullish();

/** Token usage block shared by full responses and stream chunks. */
const mistralUsageSchema = z3.object({
  prompt_tokens: z3.number(),
  completion_tokens: z3.number(),
  total_tokens: z3.number()
});

/** Schema of a non-streaming chat completion response. */
const mistralChatResponseSchema = z3.object({
  id: z3.string().nullish(),
  created: z3.number().nullish(),
  model: z3.string().nullish(),
  choices: z3.array(
    z3.object({
      message: z3.object({
        role: z3.literal("assistant"),
        content: mistralContentSchema,
        tool_calls: z3.array(
          z3.object({
            id: z3.string(),
            function: z3.object({ name: z3.string(), arguments: z3.string() })
          })
        ).nullish()
      }),
      index: z3.number(),
      finish_reason: z3.string().nullish()
    })
  ),
  object: z3.literal("chat.completion"),
  usage: mistralUsageSchema
});

/** Schema of one streamed chat completion chunk. */
const mistralChatChunkSchema = z3.object({
  id: z3.string().nullish(),
  created: z3.number().nullish(),
  model: z3.string().nullish(),
  choices: z3.array(
    z3.object({
      delta: z3.object({
        role: z3.enum(["assistant"]).optional(),
        content: mistralContentSchema,
        tool_calls: z3.array(
          z3.object({
            id: z3.string(),
            function: z3.object({ name: z3.string(), arguments: z3.string() })
          })
        ).nullish()
      }),
      finish_reason: z3.string().nullish(),
      index: z3.number()
    })
  ),
  usage: mistralUsageSchema.nullish()
});

// src/mistral-embedding-model.ts
import { TooManyEmbeddingValuesForCallError } from "@ai-sdk/provider";
import {
  combineHeaders as combineHeaders2,
  createJsonResponseHandler as createJsonResponseHandler2,
  postJsonToApi as postJsonToApi2
} from "@ai-sdk/provider-utils";
import { z as z4 } from "zod/v4";

/**
 * Embedding model (specification version "v2") backed by the
 * `${baseURL}/embeddings` endpoint. At most 32 values per call; parallel calls
 * are declared unsupported.
 */
const MistralEmbeddingModel = class {
  /**
   * @param {string} modelId - model identifier sent as `model`.
   * @param {object} config - `provider`, `baseURL`, `headers()` and optional
   *   `fetch` override.
   */
  constructor(modelId, config) {
    this.specificationVersion = "v2";
    this.maxEmbeddingsPerCall = 32;
    this.supportsParallelCalls = false;
    this.modelId = modelId;
    this.config = config;
  }

  /** Provider name taken from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Embeds the given values with a single request.
   *
   * @param {{ values: string[], abortSignal?: AbortSignal, headers?: object }} params
   * @returns {Promise<object>} embeddings (one per value), token usage when the
   *   response reports it, and response headers/body.
   * @throws {TooManyEmbeddingValuesForCallError} when more than
   *   `maxEmbeddingsPerCall` values are passed.
   */
  async doEmbed({ values, abortSignal, headers }) {
    if (values.length > this.maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
        values
      });
    }
    const {
      responseHeaders,
      value: response,
      rawValue
    } = await postJsonToApi2({
      url: `${this.config.baseURL}/embeddings`,
      headers: combineHeaders2(this.config.headers(), headers),
      body: {
        model: this.modelId,
        input: values,
        encoding_format: "float"
      },
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler2(
        MistralTextEmbeddingResponseSchema
      ),
      abortSignal,
      fetch: this.config.fetch
    });
    return {
      embeddings: response.data.map((item) => item.embedding),
      usage: response.usage ? { tokens: response.usage.prompt_tokens } : void 0,
      response: { headers: responseHeaders, body: rawValue }
    };
  }
};

/** Schema of the embeddings endpoint response (only the fields read above). */
const MistralTextEmbeddingResponseSchema = z4.object({
  data: z4.array(z4.object({ embedding: z4.array(z4.number()) })),
  usage: z4.object({ prompt_tokens: z4.number() }).nullish()
});

// src/mistral-provider.ts
/**
 * Creates a Mistral provider instance.
 *
 * The returned function creates a chat model when called with a model id (it
 * throws when invoked with `new`). It also exposes `languageModel` / `chat`
 * for chat models, `embedding` / `textEmbedding` / `textEmbeddingModel` for
 * embedding models, and `imageModel`, which always throws `NoSuchModelError`.
 *
 * @param {object} [options]
 * @param {string} [options.baseURL] - API base URL; a trailing slash is
 *   removed. Defaults to "https://api.mistral.ai/v1".
 * @param {string} [options.apiKey] - API key; when omitted it is loaded from
 *   the `MISTRAL_API_KEY` environment variable.
 * @param {object} [options.headers] - extra headers; these are spread after
 *   `Authorization` and therefore override it on conflict.
 * @param {Function} [options.fetch] - custom fetch implementation.
 * @param {Function} [options.generateId] - custom id generator for chat models.
 * @returns {Function} the provider.
 */
function createMistral(options = {}) {
  const trimmedBaseURL = withoutTrailingSlash(options.baseURL);
  const baseURL = trimmedBaseURL != null ? trimmedBaseURL : "https://api.mistral.ai/v1";

  // Evaluated per request so the API key is resolved lazily.
  const getHeaders = () => ({
    Authorization: `Bearer ${loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: "MISTRAL_API_KEY",
      description: "Mistral"
    })}`,
    ...options.headers
  });

  const createChatModel = (modelId) => new MistralChatLanguageModel(modelId, {
    provider: "mistral.chat",
    baseURL,
    headers: getHeaders,
    fetch: options.fetch,
    generateId: options.generateId
  });

  const createEmbeddingModel = (modelId) => new MistralEmbeddingModel(modelId, {
    provider: "mistral.embedding",
    baseURL,
    headers: getHeaders,
    fetch: options.fetch
  });

  // A regular function (not an arrow) so that `new.target` can be inspected.
  const provider = function(modelId) {
    if (new.target) {
      throw new Error(
        "The Mistral model function cannot be called with the new keyword."
      );
    }
    return createChatModel(modelId);
  };
  provider.languageModel = createChatModel;
  provider.chat = createChatModel;
  provider.embedding = createEmbeddingModel;
  provider.textEmbedding = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;
  provider.imageModel = (modelId) => {
    throw new NoSuchModelError({ modelId, modelType: "imageModel" });
  };
  return provider;
}

/** Default provider instance created with no options. */
const mistral = createMistral();
export {
  createMistral,
  mistral
};
//# sourceMappingURL=index.mjs.map