/*
 * UNPKG
 *
 * @ai-toolkit/mistral
 *
 * Version:
 *
 * The **[Mistral provider](https://sdk.khulnasoft.com/providers/ai-toolkit-providers/mistral)** for the [AI TOOLKIT](https://sdk.khulnasoft.com/docs) contains language model support for the Mistral chat API and text embedding model support for the Mistral embeddings API.
 *
 * 703 lines (694 loc) 21 kB
 */
// Bundled ES-module build of the Mistral provider. Each `// src/...` marker
// names the TypeScript source file the following section was compiled from.
// NOTE: every `//` comment below must stay on its own line; when the lines are
// collapsed, the comment swallows the code that follows it and the module no
// longer parses.

// src/mistral-provider.ts
import {
  loadApiKey,
  withoutTrailingSlash
} from "@ai-toolkit/provider-utils";

// src/mistral-chat-language-model.ts
import {
  combineHeaders,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  postJsonToApi
} from "@ai-toolkit/provider-utils";
import { z as z2 } from "zod";

// src/convert-to-mistral-chat-messages.ts
import {
  UnsupportedFunctionalityError
} from "@ai-toolkit/provider";
import { convertUint8ArrayToBase64 } from "@ai-toolkit/provider-utils";

/**
 * Converts a standardized prompt (an array of `{ role, content }` messages)
 * into the `messages` array sent to Mistral's `/chat/completions` endpoint.
 *
 * - `system`: content is passed through unchanged.
 * - `user`: text parts become `text`. Image parts become `image_url`: URLs are
 *   passed as strings, and binary data becomes a base64 data URL whose MIME
 *   type defaults to `image/jpeg`. File parts must be PDF URLs and become
 *   `document_url`.
 * - `assistant`: text parts are concatenated and reasoning parts are dropped.
 *   Tool calls become `tool_calls` with JSON-stringified arguments.
 * - `tool`: every tool result becomes its own `tool` message.
 *
 * @param {Array<{ role: string, content: any }>} prompt Standardized prompt.
 * @returns {Array<object>} Messages in Mistral's chat format.
 * @throws {UnsupportedFunctionalityError} For file parts that are not URLs or
 *   whose MIME type is not `application/pdf`.
 * @throws {Error} For an unknown role or an unknown assistant part type.
 */
function convertToMistralChatMessages(prompt) {
  const messages = [];
  for (let i = 0; i < prompt.length; i++) {
    const { role, content } = prompt[i];
    const isLastMessage = i === prompt.length - 1;
    switch (role) {
      case "system": {
        messages.push({ role: "system", content });
        break;
      }
      case "user": {
        messages.push({
          role: "user",
          // NOTE(review): a part type not handled below maps to `undefined`.
          content: content.map((part) => {
            var _a;
            switch (part.type) {
              case "text": {
                return { type: "text", text: part.text };
              }
              case "image": {
                return {
                  type: "image_url",
                  image_url: part.image instanceof URL ? part.image.toString() : `data:${(_a = part.mimeType) != null ? _a : "image/jpeg"};base64,${convertUint8ArrayToBase64(part.image)}`
                };
              }
              case "file": {
                // Only URL-referenced files are supported; inline file bytes are rejected.
                if (!(part.data instanceof URL)) {
                  throw new UnsupportedFunctionalityError({
                    functionality: "File content parts in user messages"
                  });
                }
                switch (part.mimeType) {
                  case "application/pdf": {
                    return {
                      type: "document_url",
                      document_url: part.data.toString()
                    };
                  }
                  default: {
                    throw new UnsupportedFunctionalityError({
                      functionality: "Only PDF files are supported in user messages"
                    });
                  }
                }
              }
            }
          })
        });
        break;
      }
      case "assistant": {
        let text = "";
        const toolCalls = [];
        for (const part of content) {
          switch (part.type) {
            case "text": {
              text += part.text;
              break;
            }
            case "redacted-reasoning":
            case "reasoning": {
              // Reasoning is not sent back to Mistral.
              break;
            }
            case "tool-call": {
              toolCalls.push({
                id: part.toolCallId,
                type: "function",
                function: {
                  name: part.toolName,
                  arguments: JSON.stringify(part.args)
                }
              });
              break;
            }
            default: {
              const _exhaustiveCheck = part;
              throw new Error(`Unsupported part: ${_exhaustiveCheck}`);
            }
          }
        }
        messages.push({
          role: "assistant",
          content: text,
          // A trailing assistant message is flagged as a prefix for the model
          // to continue (Mistral's assistant `prefix` option).
          prefix: isLastMessage ? true : void 0,
          tool_calls: toolCalls.length > 0 ? toolCalls : void 0
        });
        break;
      }
      case "tool": {
        for (const toolResponse of content) {
          messages.push({
            role: "tool",
            name: toolResponse.toolName,
            content: JSON.stringify(toolResponse.result),
            tool_call_id: toolResponse.toolCallId
          });
        }
        break;
      }
      default: {
        const _exhaustiveCheck = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  }
  return messages;
}

// src/map-mistral-finish-reason.ts
/**
 * Maps a Mistral `finish_reason` string to the standardized finish reason.
 *
 * @param {string | null | undefined} finishReason Value reported by Mistral.
 * @returns {"stop" | "length" | "tool-calls" | "unknown"} Standardized reason.
 *   Any unrecognized or missing value yields `"unknown"`.
 */
function mapMistralFinishReason(finishReason) {
  switch (finishReason) {
    case "stop":
      return "stop";
    case "length":
    case "model_length":
      return "length";
    case "tool_calls":
      return "tool-calls";
    default:
      return "unknown";
  }
}

// src/mistral-error.ts
import { createJsonErrorResponseHandler } from "@ai-toolkit/provider-utils";
import { z } from "zod";
// Shape of the JSON body Mistral returns for failed requests.
var mistralErrorDataSchema = z.object({
  object: z.literal("error"),
  message: z.string(),
  type: z.string(),
  param: z.string().nullable(),
  code: z.string().nullable()
});
// Shared error handler for the chat and embedding models; it surfaces
// Mistral's `message` field as the error message.
var mistralFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: mistralErrorDataSchema,
  errorToMessage: (data) => data.message
});

// src/get-response-metadata.ts
/**
 * Extracts standardized response metadata from a Mistral response or chunk.
 *
 * @param {{ id?: string | null, model?: string | null, created?: number | null }} response
 * @returns {{ id: string | undefined, modelId: string | undefined, timestamp: Date | undefined }}
 *   `null` values become `undefined`. `created` is treated as seconds and
 *   converted to a `Date` (milliseconds).
 */
function getResponseMetadata({
  id,
  model,
  created
}) {
  return {
    id: id != null ? id : void 0,
    modelId: model != null ? model : void 0,
    timestamp: created != null ? new Date(created * 1e3) : void 0
  };
}

// src/mistral-prepare-tools.ts
import {
  UnsupportedFunctionalityError as UnsupportedFunctionalityError2
} from "@ai-toolkit/provider";

/**
 * Translates the standardized `mode.tools` / `mode.toolChoice` into Mistral's
 * `tools` / `tool_choice` request fields.
 *
 * Provider-defined tools are not sent; each one produces an `unsupported-tool`
 * warning. Tool choices `"required"` and `"tool"` both map to Mistral's
 * `"any"`. For `"tool"`, the tool list is also narrowed to the named tool.
 *
 * @param {{ tools?: Array<object>, toolChoice?: { type: string, toolName?: string } }} mode
 * @returns {{ tools: Array<object> | undefined, tool_choice: string | undefined, toolWarnings: Array<object> }}
 * @throws {UnsupportedFunctionalityError} For an unknown tool choice type.
 */
function prepareTools(mode) {
  var _a;
  // An empty tool list is treated the same as no tools.
  const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
  const toolWarnings = [];
  if (tools == null) {
    return { tools: void 0, tool_choice: void 0, toolWarnings };
  }
  const mistralTools = [];
  for (const tool of tools) {
    if (tool.type === "provider-defined") {
      toolWarnings.push({ type: "unsupported-tool", tool });
    } else {
      mistralTools.push({
        type: "function",
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.parameters
        }
      });
    }
  }
  const toolChoice = mode.toolChoice;
  if (toolChoice == null) {
    return { tools: mistralTools, tool_choice: void 0, toolWarnings };
  }
  const type = toolChoice.type;
  switch (type) {
    case "auto":
    case "none":
      return { tools: mistralTools, tool_choice: type, toolWarnings };
    case "required":
      return { tools: mistralTools, tool_choice: "any", toolWarnings };
    case "tool":
      // Mistral has no "call this specific tool" option; emulate it by only
      // offering the chosen tool and forcing a tool call.
      return {
        tools: mistralTools.filter(
          (tool) => tool.function.name === toolChoice.toolName
        ),
        tool_choice: "any",
        toolWarnings
      };
    default: {
      const _exhaustiveCheck = type;
      throw new UnsupportedFunctionalityError2({
        functionality: `Unsupported tool choice type: ${_exhaustiveCheck}`
      });
    }
  }
}

// src/mistral-chat-language-model.ts
/**
 * Language model (specification `v1`) backed by Mistral's
 * `POST {baseURL}/chat/completions` endpoint.
 *
 * @param {string} modelId Mistral model id, sent as `model`.
 * @param {{ safePrompt?: boolean }} settings Model settings (`safePrompt` is
 *   sent as `safe_prompt`).
 * @param {{ provider: string, baseURL: string, headers: () => object, fetch?: Function }} config
 *   Provider configuration. `headers` is called once per request.
 */
var MistralChatLanguageModel = class {
  constructor(modelId, settings, config) {
    this.specificationVersion = "v1";
    this.defaultObjectGenerationMode = "json";
    this.supportsImageUrls = false;
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }
  get provider() {
    return this.config.provider;
  }
  /**
   * Reports whether a URL can be sent to Mistral as-is; only `https:` URLs
   * qualify. (Presumably the caller downloads other URLs first — confirm
   * against the AI TOOLKIT spec.)
   *
   * @param {URL} url
   * @returns {boolean}
   */
  supportsUrl(url) {
    return url.protocol === "https:";
  }
  /**
   * Builds the Mistral request body from standardized call options and
   * collects warnings for settings Mistral does not support (`topK`,
   * `frequencyPenalty`, `presencePenalty`, `stopSequences`, and a JSON
   * response-format schema).
   *
   * @returns {{ args: object, warnings: Array<object> }}
   * @throws {Error} For an unknown `mode.type`.
   */
  getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    topK,
    frequencyPenalty,
    presencePenalty,
    stopSequences,
    responseFormat,
    seed,
    providerMetadata
  }) {
    var _a, _b;
    const type = mode.type;
    const warnings = [];
    if (topK != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "topK"
      });
    }
    if (frequencyPenalty != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "frequencyPenalty"
      });
    }
    if (presencePenalty != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "presencePenalty"
      });
    }
    if (stopSequences != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "stopSequences"
      });
    }
    // JSON mode is still requested below; only the schema itself is dropped.
    if (responseFormat != null && responseFormat.type === "json" && responseFormat.schema != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "responseFormat",
        details: "JSON response format schema is not supported"
      });
    }
    const baseArgs = {
      // model id:
      model: this.modelId,
      // model specific settings:
      safe_prompt: this.settings.safePrompt,
      // standardized settings:
      max_tokens: maxTokens,
      temperature,
      top_p: topP,
      random_seed: seed,
      // response format:
      response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object" } : void 0,
      // mistral-specific provider options:
      document_image_limit: (_a = providerMetadata == null ? void 0 : providerMetadata.mistral) == null ? void 0 : _a.documentImageLimit,
      document_page_limit: (_b = providerMetadata == null ? void 0 : providerMetadata.mistral) == null ? void 0 : _b.documentPageLimit,
      // messages:
      messages: convertToMistralChatMessages(prompt)
    };
    switch (type) {
      case "regular": {
        const { tools, tool_choice, toolWarnings } = prepareTools(mode);
        return {
          args: { ...baseArgs, tools, tool_choice },
          warnings: [...warnings, ...toolWarnings]
        };
      }
      case "object-json": {
        return {
          args: {
            ...baseArgs,
            response_format: { type: "json_object" }
          },
          warnings
        };
      }
      case "object-tool": {
        // Object generation via a single forced tool call.
        return {
          args: {
            ...baseArgs,
            tool_choice: "any",
            tools: [{ type: "function", function: mode.tool }]
          },
          warnings
        };
      }
      default: {
        const _exhaustiveCheck = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }
  /**
   * Performs a non-streaming chat completion.
   *
   * If the prompt ends with an assistant (prefix) message and the returned
   * text starts with that prefix, the prefix is removed so only the newly
   * generated continuation is returned.
   *
   * @param {object} options Standardized call options (see `getArgs`), plus
   *   `headers` and `abortSignal`.
   * @returns {Promise<object>} Text, tool calls, finish reason, usage, raw call
   *   and response data, request body, response metadata and warnings.
   */
  async doGenerate(options) {
    var _a;
    const { args, warnings } = this.getArgs(options);
    const {
      responseHeaders,
      value: response,
      rawValue: rawResponse
    } = await postJsonToApi({
      url: `${this.config.baseURL}/chat/completions`,
      headers: combineHeaders(this.config.headers(), options.headers),
      body: args,
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        mistralChatResponseSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages: rawPrompt, ...rawSettings } = args;
    const choice = response.choices[0];
    let text = extractTextContent(choice.message.content);
    const lastMessage = rawPrompt[rawPrompt.length - 1];
    if (lastMessage.role === "assistant" && (text == null ? void 0 : text.startsWith(lastMessage.content))) {
      text = text.slice(lastMessage.content.length);
    }
    return {
      text,
      toolCalls: (_a = choice.message.tool_calls) == null ? void 0 : _a.map((toolCall) => ({
        toolCallType: "function",
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        args: toolCall.function.arguments
      })),
      finishReason: mapMistralFinishReason(choice.finish_reason),
      usage: {
        promptTokens: response.usage.prompt_tokens,
        completionTokens: response.usage.completion_tokens
      },
      rawCall: { rawPrompt, rawSettings },
      rawResponse: {
        headers: responseHeaders,
        body: rawResponse
      },
      request: { body: JSON.stringify(args) },
      response: getResponseMetadata(response),
      warnings
    };
  }
  /**
   * Performs a streaming chat completion (`stream: true`).
   *
   * The returned stream emits, in order:
   * - one `response-metadata` part for the first successfully parsed chunk;
   * - `text-delta`, `tool-call-delta` and `tool-call` parts as chunks arrive;
   * - `error` parts for chunks that fail schema validation;
   * - one final `finish` part carrying the last seen finish reason and usage.
   *   Usage stays `NaN` if no chunk reported it.
   *
   * @param {object} options Standardized call options (see `getArgs`), plus
   *   `headers` and `abortSignal`.
   * @returns {Promise<object>} The part stream plus raw call/response data,
   *   the request body and warnings.
   */
  async doStream(options) {
    const { args, warnings } = this.getArgs(options);
    const body = { ...args, stream: true };
    const { responseHeaders, value: response } = await postJsonToApi({
      url: `${this.config.baseURL}/chat/completions`,
      headers: combineHeaders(this.config.headers(), options.headers),
      body,
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        mistralChatChunkSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages: rawPrompt, ...rawSettings } = args;
    let finishReason = "unknown";
    let usage = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN
    };
    let chunkNumber = 0;
    // Set when a suppressed prefix echo had trailing whitespace, so the next
    // text delta is left-trimmed.
    let trimLeadingSpace = false;
    return {
      stream: response.pipeThrough(
        new TransformStream({
          transform(chunk, controller) {
            if (!chunk.success) {
              controller.enqueue({ type: "error", error: chunk.error });
              return;
            }
            chunkNumber++;
            const value = chunk.value;
            if (chunkNumber === 1) {
              controller.enqueue({
                type: "response-metadata",
                ...getResponseMetadata(value)
              });
            }
            if (value.usage != null) {
              usage = {
                promptTokens: value.usage.prompt_tokens,
                completionTokens: value.usage.completion_tokens
              };
            }
            const choice = value.choices[0];
            if ((choice == null ? void 0 : choice.finish_reason) != null) {
              finishReason = mapMistralFinishReason(choice.finish_reason);
            }
            if ((choice == null ? void 0 : choice.delta) == null) {
              return;
            }
            const delta = choice.delta;
            const textContent = extractTextContent(delta.content);
            // Suppress a delta that repeats the assistant prefix. Only the first
            // two chunks are checked (presumably the echo arrives in one of
            // them — confirm against Mistral's streaming behavior).
            // NOTE(review): the early return also skips any `tool_calls` in
            // that same delta.
            if (chunkNumber <= 2) {
              const lastMessage = rawPrompt[rawPrompt.length - 1];
              if (lastMessage.role === "assistant" && textContent === lastMessage.content.trimEnd()) {
                if (textContent.length < lastMessage.content.length) {
                  trimLeadingSpace = true;
                }
                return;
              }
            }
            if (textContent != null) {
              controller.enqueue({
                type: "text-delta",
                textDelta: trimLeadingSpace ? textContent.trimStart() : textContent
              });
              trimLeadingSpace = false;
            }
            if (delta.tool_calls != null) {
              // Each tool call is emitted as a delta and then immediately as a
              // complete call. This assumes a tool call arrives whole within
              // one chunk.
              for (const toolCall of delta.tool_calls) {
                controller.enqueue({
                  type: "tool-call-delta",
                  toolCallType: "function",
                  toolCallId: toolCall.id,
                  toolName: toolCall.function.name,
                  argsTextDelta: toolCall.function.arguments
                });
                controller.enqueue({
                  type: "tool-call",
                  toolCallType: "function",
                  toolCallId: toolCall.id,
                  toolName: toolCall.function.name,
                  args: toolCall.function.arguments
                });
              }
            }
          },
          flush(controller) {
            controller.enqueue({ type: "finish", finishReason, usage });
          }
        })
      ),
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      request: { body: JSON.stringify(body) },
      warnings
    };
  }
};

/**
 * Extracts the plain text from Mistral message content.
 *
 * @param {string | Array<{ type: string }> | null | undefined} content
 *   Either a string or an array of typed content chunks.
 * @returns {string | undefined} The string itself; the concatenated `text`
 *   chunks (ignoring `image_url` and `reference`); or `undefined` when content
 *   is missing or has no text chunks.
 * @throws {Error} For a chunk type other than `text`, `image_url` or `reference`.
 */
function extractTextContent(content) {
  if (typeof content === "string") {
    return content;
  }
  if (content == null) {
    return void 0;
  }
  const textContent = [];
  for (const chunk of content) {
    const { type } = chunk;
    switch (type) {
      case "text":
        textContent.push(chunk.text);
        break;
      case "image_url":
      case "reference":
        break;
      default: {
        const _exhaustiveCheck = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }
  return textContent.length ? textContent.join("") : void 0;
}

// Message content: a plain string, or an array of text / image_url /
// reference chunks. May be null or absent.
var mistralContentSchema = z2.union([
  z2.string(),
  z2.array(
    z2.discriminatedUnion("type", [
      z2.object({
        type: z2.literal("text"),
        text: z2.string()
      }),
      z2.object({
        type: z2.literal("image_url"),
        image_url: z2.union([
          z2.string(),
          z2.object({
            url: z2.string(),
            detail: z2.string().nullable()
          })
        ])
      }),
      z2.object({
        type: z2.literal("reference"),
        reference_ids: z2.array(z2.number())
      })
    ])
  )
]).nullish();

// Non-streaming `/chat/completions` response body.
var mistralChatResponseSchema = z2.object({
  id: z2.string().nullish(),
  created: z2.number().nullish(),
  model: z2.string().nullish(),
  choices: z2.array(
    z2.object({
      message: z2.object({
        role: z2.literal("assistant"),
        content: mistralContentSchema,
        tool_calls: z2.array(
          z2.object({
            id: z2.string(),
            function: z2.object({ name: z2.string(), arguments: z2.string() })
          })
        ).nullish()
      }),
      index: z2.number(),
      finish_reason: z2.string().nullish()
    })
  ),
  object: z2.literal("chat.completion"),
  usage: z2.object({
    prompt_tokens: z2.number(),
    completion_tokens: z2.number()
  })
});

// A single event of the streaming `/chat/completions` response. Usage is
// optional per chunk.
var mistralChatChunkSchema = z2.object({
  id: z2.string().nullish(),
  created: z2.number().nullish(),
  model: z2.string().nullish(),
  choices: z2.array(
    z2.object({
      delta: z2.object({
        role: z2.enum(["assistant"]).optional(),
        content: mistralContentSchema,
        tool_calls: z2.array(
          z2.object({
            id: z2.string(),
            function: z2.object({ name: z2.string(), arguments: z2.string() })
          })
        ).nullish()
      }),
      finish_reason: z2.string().nullish(),
      index: z2.number()
    })
  ),
  usage: z2.object({
    prompt_tokens: z2.number(),
    completion_tokens: z2.number()
  }).nullish()
});

// src/mistral-embedding-model.ts
import {
  TooManyEmbeddingValuesForCallError
} from "@ai-toolkit/provider";
import {
  combineHeaders as combineHeaders2,
  createJsonResponseHandler as createJsonResponseHandler2,
  postJsonToApi as postJsonToApi2
} from "@ai-toolkit/provider-utils";
import { z as z3 } from "zod";

/**
 * Text embedding model (specification `v1`) backed by Mistral's
 * `POST {baseURL}/embeddings` endpoint.
 *
 * @param {string} modelId Mistral embedding model id, sent as `model`.
 * @param {{ maxEmbeddingsPerCall?: number, supportsParallelCalls?: boolean }} settings
 *   Defaults: 32 values per call, parallel calls not supported.
 * @param {{ provider: string, baseURL: string, headers: () => object, fetch?: Function }} config
 */
var MistralEmbeddingModel = class {
  constructor(modelId, settings, config) {
    this.specificationVersion = "v1";
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }
  get provider() {
    return this.config.provider;
  }
  get maxEmbeddingsPerCall() {
    var _a;
    return (_a = this.settings.maxEmbeddingsPerCall) != null ? _a : 32;
  }
  get supportsParallelCalls() {
    var _a;
    return (_a = this.settings.supportsParallelCalls) != null ? _a : false;
  }
  /**
   * Embeds `values` in a single request (`encoding_format: "float"`).
   *
   * @param {{ values: string[], abortSignal?: AbortSignal, headers?: object }} options
   * @returns {Promise<{ embeddings: number[][], usage: { tokens: number } | undefined, rawResponse: { headers: object } }>}
   * @throws {TooManyEmbeddingValuesForCallError} When `values.length` exceeds
   *   `maxEmbeddingsPerCall`.
   */
  async doEmbed({
    values,
    abortSignal,
    headers
  }) {
    if (values.length > this.maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
        values
      });
    }
    const { responseHeaders, value: response } = await postJsonToApi2({
      url: `${this.config.baseURL}/embeddings`,
      headers: combineHeaders2(this.config.headers(), headers),
      body: {
        model: this.modelId,
        input: values,
        encoding_format: "float"
      },
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler2(
        MistralTextEmbeddingResponseSchema
      ),
      abortSignal,
      fetch: this.config.fetch
    });
    return {
      embeddings: response.data.map((item) => item.embedding),
      usage: response.usage ? { tokens: response.usage.prompt_tokens } : void 0,
      rawResponse: { headers: responseHeaders }
    };
  }
};

// `/embeddings` response body; `usage` may be null or absent.
var MistralTextEmbeddingResponseSchema = z3.object({
  data: z3.array(z3.object({ embedding: z3.array(z3.number()) })),
  usage: z3.object({ prompt_tokens: z3.number() }).nullish()
});

// src/mistral-provider.ts
/**
 * Creates a Mistral provider.
 *
 * The returned function creates a chat model when called as
 * `provider(modelId, settings)`. It also exposes `languageModel` / `chat`
 * (chat models) and `embedding` / `textEmbedding` / `textEmbeddingModel`
 * (embedding models).
 *
 * @param {{ baseURL?: string, apiKey?: string, headers?: object, fetch?: Function }} [options]
 *   - `baseURL`: trailing slash removed; defaults to `https://api.mistral.ai/v1`.
 *   - `apiKey`: resolved through `loadApiKey`, with the `MISTRAL_API_KEY`
 *     environment variable as the fallback.
 *   - `headers`: merged after `Authorization`, so they can override it.
 *   - `fetch`: custom fetch implementation passed to the models.
 * @returns {Function} The provider function.
 */
function createMistral(options = {}) {
  var _a;
  const baseURL = (_a = withoutTrailingSlash(options.baseURL)) != null ? _a : "https://api.mistral.ai/v1";
  // Passed to the models as a function and called per request, so the API key
  // is only resolved when a request is made.
  const getHeaders = () => ({
    Authorization: `Bearer ${loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: "MISTRAL_API_KEY",
      description: "Mistral"
    })}`,
    ...options.headers
  });
  const createChatModel = (modelId, settings = {}) => new MistralChatLanguageModel(modelId, settings, {
    provider: "mistral.chat",
    baseURL,
    headers: getHeaders,
    fetch: options.fetch
  });
  const createEmbeddingModel = (modelId, settings = {}) => new MistralEmbeddingModel(modelId, settings, {
    provider: "mistral.embedding",
    baseURL,
    headers: getHeaders,
    fetch: options.fetch
  });
  const provider = function(modelId, settings) {
    if (new.target) {
      throw new Error(
        "The Mistral model function cannot be called with the new keyword."
      );
    }
    return createChatModel(modelId, settings);
  };
  provider.languageModel = createChatModel;
  provider.chat = createChatModel;
  provider.embedding = createEmbeddingModel;
  provider.textEmbedding = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;
  return provider;
}

// Default provider instance using the default options.
var mistral = createMistral();
export {
  createMistral,
  mistral
};
//# sourceMappingURL=index.mjs.map