@ai-toolkit/mistral
The **[Mistral provider](https://sdk.khulnasoft.com/providers/ai-toolkit-providers/mistral)** for the [AI TOOLKIT](https://sdk.khulnasoft.com/docs) contains language model support for the Mistral chat API and embedding model support for the Mistral embeddings API.
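
The provider is designed to plug into the AI TOOLKIT's higher-level generation and embedding helpers (see the linked docs). The sketch below sticks to what this file itself defines, namely creating provider instances and model objects; the model IDs are illustrative:

```js
import { mistral, createMistral } from "@ai-toolkit/mistral";

// Default provider instance: reads the API key from the MISTRAL_API_KEY
// environment variable and targets https://api.mistral.ai/v1.
const chatModel = mistral("mistral-large-latest"); // illustrative model ID

// Explicitly configured provider instance.
const myMistral = createMistral({
  apiKey: process.env.MISTRAL_API_KEY, // optional; falls back to MISTRAL_API_KEY
  baseURL: "https://api.mistral.ai/v1", // optional; a trailing slash is stripped
  headers: { "x-example-header": "value" }, // optional; merged into every request
});

// Chat model with the Mistral-specific safe_prompt setting enabled.
const safeModel = myMistral.chat("mistral-small-latest", { safePrompt: true });

// Embedding model (default maxEmbeddingsPerCall is 32).
const embeddingModel = myMistral.textEmbeddingModel("mistral-embed");
```

The full bundled source of the package (`index.mjs`) follows.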
```js
// src/mistral-provider.ts
import {
loadApiKey,
withoutTrailingSlash
} from "@ai-toolkit/provider-utils";
// src/mistral-chat-language-model.ts
import {
combineHeaders,
createEventSourceResponseHandler,
createJsonResponseHandler,
postJsonToApi
} from "@ai-toolkit/provider-utils";
import { z as z2 } from "zod";
// src/convert-to-mistral-chat-messages.ts
import {
UnsupportedFunctionalityError
} from "@ai-toolkit/provider";
import { convertUint8ArrayToBase64 } from "@ai-toolkit/provider-utils";
function convertToMistralChatMessages(prompt) {
const messages = [];
for (let i = 0; i < prompt.length; i++) {
const { role, content } = prompt[i];
const isLastMessage = i === prompt.length - 1;
switch (role) {
case "system": {
messages.push({ role: "system", content });
break;
}
case "user": {
messages.push({
role: "user",
content: content.map((part) => {
var _a;
switch (part.type) {
case "text": {
return { type: "text", text: part.text };
}
case "image": {
return {
type: "image_url",
image_url: part.image instanceof URL ? part.image.toString() : `data:${(_a = part.mimeType) != null ? _a : "image/jpeg"};base64,${convertUint8ArrayToBase64(part.image)}`
};
}
case "file": {
if (!(part.data instanceof URL)) {
throw new UnsupportedFunctionalityError({
functionality: "File content parts in user messages"
});
}
switch (part.mimeType) {
case "application/pdf": {
return {
type: "document_url",
document_url: part.data.toString()
};
}
default: {
throw new UnsupportedFunctionalityError({
functionality: "Only PDF files are supported in user messages"
});
}
}
}
}
})
});
break;
}
case "assistant": {
let text = "";
const toolCalls = [];
for (const part of content) {
switch (part.type) {
case "text": {
text += part.text;
break;
}
case "redacted-reasoning":
case "reasoning": {
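              // Reasoning parts are not sent to Mistral; they are dropped from assistant messages.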
break;
}
case "tool-call": {
toolCalls.push({
id: part.toolCallId,
type: "function",
function: {
name: part.toolName,
arguments: JSON.stringify(part.args)
}
});
break;
}
default: {
const _exhaustiveCheck = part;
throw new Error(`Unsupported part: ${_exhaustiveCheck}`);
}
}
}
messages.push({
role: "assistant",
content: text,
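          // Mistral's prefix feature: a trailing assistant message is continued rather than answered.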
prefix: isLastMessage ? true : void 0,
tool_calls: toolCalls.length > 0 ? toolCalls : void 0
});
break;
}
case "tool": {
for (const toolResponse of content) {
messages.push({
role: "tool",
name: toolResponse.toolName,
content: JSON.stringify(toolResponse.result),
tool_call_id: toolResponse.toolCallId
});
}
break;
}
default: {
const _exhaustiveCheck = role;
throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
}
}
}
return messages;
}
// src/map-mistral-finish-reason.ts
function mapMistralFinishReason(finishReason) {
switch (finishReason) {
case "stop":
return "stop";
case "length":
case "model_length":
return "length";
case "tool_calls":
return "tool-calls";
default:
return "unknown";
}
}
// src/mistral-error.ts
import { createJsonErrorResponseHandler } from "@ai-toolkit/provider-utils";
import { z } from "zod";
var mistralErrorDataSchema = z.object({
object: z.literal("error"),
message: z.string(),
type: z.string(),
param: z.string().nullable(),
code: z.string().nullable()
});
var mistralFailedResponseHandler = createJsonErrorResponseHandler({
errorSchema: mistralErrorDataSchema,
errorToMessage: (data) => data.message
});
// src/get-response-metadata.ts
function getResponseMetadata({
id,
model,
created
}) {
return {
id: id != null ? id : void 0,
modelId: model != null ? model : void 0,
timestamp: created != null ? new Date(created * 1e3) : void 0
};
}
// src/mistral-prepare-tools.ts
import {
UnsupportedFunctionalityError as UnsupportedFunctionalityError2
} from "@ai-toolkit/provider";
function prepareTools(mode) {
var _a;
const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
const toolWarnings = [];
if (tools == null) {
return { tools: void 0, tool_choice: void 0, toolWarnings };
}
const mistralTools = [];
for (const tool of tools) {
if (tool.type === "provider-defined") {
toolWarnings.push({ type: "unsupported-tool", tool });
} else {
mistralTools.push({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters
}
});
}
}
const toolChoice = mode.toolChoice;
if (toolChoice == null) {
return { tools: mistralTools, tool_choice: void 0, toolWarnings };
}
const type = toolChoice.type;
switch (type) {
case "auto":
case "none":
return { tools: mistralTools, tool_choice: type, toolWarnings };
case "required":
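      // The standardized "required" tool choice maps to Mistral's "any".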
return { tools: mistralTools, tool_choice: "any", toolWarnings };
case "tool":
return {
tools: mistralTools.filter(
(tool) => tool.function.name === toolChoice.toolName
),
tool_choice: "any",
toolWarnings
};
default: {
const _exhaustiveCheck = type;
throw new UnsupportedFunctionalityError2({
functionality: `Unsupported tool choice type: ${_exhaustiveCheck}`
});
}
}
}
// src/mistral-chat-language-model.ts
var MistralChatLanguageModel = class {
constructor(modelId, settings, config) {
this.specificationVersion = "v1";
this.defaultObjectGenerationMode = "json";
this.supportsImageUrls = false;
this.modelId = modelId;
this.settings = settings;
this.config = config;
}
get provider() {
return this.config.provider;
}
supportsUrl(url) {
return url.protocol === "https:";
}
getArgs({
mode,
prompt,
maxTokens,
temperature,
topP,
topK,
frequencyPenalty,
presencePenalty,
stopSequences,
responseFormat,
seed,
providerMetadata
}) {
var _a, _b;
const type = mode.type;
const warnings = [];
if (topK != null) {
warnings.push({
type: "unsupported-setting",
setting: "topK"
});
}
if (frequencyPenalty != null) {
warnings.push({
type: "unsupported-setting",
setting: "frequencyPenalty"
});
}
if (presencePenalty != null) {
warnings.push({
type: "unsupported-setting",
setting: "presencePenalty"
});
}
if (stopSequences != null) {
warnings.push({
type: "unsupported-setting",
setting: "stopSequences"
});
}
if (responseFormat != null && responseFormat.type === "json" && responseFormat.schema != null) {
warnings.push({
type: "unsupported-setting",
setting: "responseFormat",
details: "JSON response format schema is not supported"
});
}
const baseArgs = {
// model id:
model: this.modelId,
// model specific settings:
safe_prompt: this.settings.safePrompt,
// standardized settings:
max_tokens: maxTokens,
temperature,
top_p: topP,
random_seed: seed,
// response format:
response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object" } : void 0,
// mistral-specific provider options:
document_image_limit: (_a = providerMetadata == null ? void 0 : providerMetadata.mistral) == null ? void 0 : _a.documentImageLimit,
document_page_limit: (_b = providerMetadata == null ? void 0 : providerMetadata.mistral) == null ? void 0 : _b.documentPageLimit,
// messages:
messages: convertToMistralChatMessages(prompt)
};
switch (type) {
case "regular": {
const { tools, tool_choice, toolWarnings } = prepareTools(mode);
return {
args: { ...baseArgs, tools, tool_choice },
warnings: [...warnings, ...toolWarnings]
};
}
case "object-json": {
return {
args: {
...baseArgs,
response_format: { type: "json_object" }
},
warnings
};
}
case "object-tool": {
return {
args: {
...baseArgs,
tool_choice: "any",
tools: [{ type: "function", function: mode.tool }]
},
warnings
};
}
default: {
const _exhaustiveCheck = type;
throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
}
}
}
async doGenerate(options) {
var _a;
const { args, warnings } = this.getArgs(options);
const {
responseHeaders,
value: response,
rawValue: rawResponse
} = await postJsonToApi({
url: `${this.config.baseURL}/chat/completions`,
headers: combineHeaders(this.config.headers(), options.headers),
body: args,
failedResponseHandler: mistralFailedResponseHandler,
successfulResponseHandler: createJsonResponseHandler(
mistralChatResponseSchema
),
abortSignal: options.abortSignal,
fetch: this.config.fetch
});
const { messages: rawPrompt, ...rawSettings } = args;
const choice = response.choices[0];
let text = extractTextContent(choice.message.content);
const lastMessage = rawPrompt[rawPrompt.length - 1];
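    // When the prompt ends with an assistant prefix message, Mistral echoes the prefix; strip it from the generated text.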
if (lastMessage.role === "assistant" && (text == null ? void 0 : text.startsWith(lastMessage.content))) {
text = text.slice(lastMessage.content.length);
}
return {
text,
toolCalls: (_a = choice.message.tool_calls) == null ? void 0 : _a.map((toolCall) => ({
toolCallType: "function",
toolCallId: toolCall.id,
toolName: toolCall.function.name,
args: toolCall.function.arguments
})),
finishReason: mapMistralFinishReason(choice.finish_reason),
usage: {
promptTokens: response.usage.prompt_tokens,
completionTokens: response.usage.completion_tokens
},
rawCall: { rawPrompt, rawSettings },
rawResponse: {
headers: responseHeaders,
body: rawResponse
},
request: { body: JSON.stringify(args) },
response: getResponseMetadata(response),
warnings
};
}
async doStream(options) {
const { args, warnings } = this.getArgs(options);
const body = { ...args, stream: true };
const { responseHeaders, value: response } = await postJsonToApi({
url: `${this.config.baseURL}/chat/completions`,
headers: combineHeaders(this.config.headers(), options.headers),
body,
failedResponseHandler: mistralFailedResponseHandler,
successfulResponseHandler: createEventSourceResponseHandler(
mistralChatChunkSchema
),
abortSignal: options.abortSignal,
fetch: this.config.fetch
});
const { messages: rawPrompt, ...rawSettings } = args;
let finishReason = "unknown";
let usage = {
promptTokens: Number.NaN,
completionTokens: Number.NaN
};
let chunkNumber = 0;
let trimLeadingSpace = false;
return {
stream: response.pipeThrough(
new TransformStream({
transform(chunk, controller) {
if (!chunk.success) {
controller.enqueue({ type: "error", error: chunk.error });
return;
}
chunkNumber++;
const value = chunk.value;
if (chunkNumber === 1) {
controller.enqueue({
type: "response-metadata",
...getResponseMetadata(value)
});
}
if (value.usage != null) {
usage = {
promptTokens: value.usage.prompt_tokens,
completionTokens: value.usage.completion_tokens
};
}
const choice = value.choices[0];
if ((choice == null ? void 0 : choice.finish_reason) != null) {
finishReason = mapMistralFinishReason(choice.finish_reason);
}
if ((choice == null ? void 0 : choice.delta) == null) {
return;
}
const delta = choice.delta;
const textContent = extractTextContent(delta.content);
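            // Suppress the echoed assistant prefix in the first content chunks (mirrors the prefix handling in doGenerate).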
if (chunkNumber <= 2) {
const lastMessage = rawPrompt[rawPrompt.length - 1];
if (lastMessage.role === "assistant" && textContent === lastMessage.content.trimEnd()) {
if (textContent.length < lastMessage.content.length) {
trimLeadingSpace = true;
}
return;
}
}
if (textContent != null) {
controller.enqueue({
type: "text-delta",
textDelta: trimLeadingSpace ? textContent.trimStart() : textContent
});
trimLeadingSpace = false;
}
if (delta.tool_calls != null) {
for (const toolCall of delta.tool_calls) {
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: toolCall.id,
toolName: toolCall.function.name,
argsTextDelta: toolCall.function.arguments
});
controller.enqueue({
type: "tool-call",
toolCallType: "function",
toolCallId: toolCall.id,
toolName: toolCall.function.name,
args: toolCall.function.arguments
});
}
}
},
flush(controller) {
controller.enqueue({ type: "finish", finishReason, usage });
}
})
),
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
request: { body: JSON.stringify(body) },
warnings
};
}
};
function extractTextContent(content) {
if (typeof content === "string") {
return content;
}
if (content == null) {
return void 0;
}
const textContent = [];
for (const chunk of content) {
const { type } = chunk;
switch (type) {
case "text":
textContent.push(chunk.text);
break;
case "image_url":
case "reference":
break;
default: {
const _exhaustiveCheck = type;
throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
}
}
}
return textContent.length ? textContent.join("") : void 0;
}
var mistralContentSchema = z2.union([
z2.string(),
z2.array(
z2.discriminatedUnion("type", [
z2.object({
type: z2.literal("text"),
text: z2.string()
}),
z2.object({
type: z2.literal("image_url"),
image_url: z2.union([
z2.string(),
z2.object({
url: z2.string(),
detail: z2.string().nullable()
})
])
}),
z2.object({
type: z2.literal("reference"),
reference_ids: z2.array(z2.number())
})
])
)
]).nullish();
var mistralChatResponseSchema = z2.object({
id: z2.string().nullish(),
created: z2.number().nullish(),
model: z2.string().nullish(),
choices: z2.array(
z2.object({
message: z2.object({
role: z2.literal("assistant"),
content: mistralContentSchema,
tool_calls: z2.array(
z2.object({
id: z2.string(),
function: z2.object({ name: z2.string(), arguments: z2.string() })
})
).nullish()
}),
index: z2.number(),
finish_reason: z2.string().nullish()
})
),
object: z2.literal("chat.completion"),
usage: z2.object({
prompt_tokens: z2.number(),
completion_tokens: z2.number()
})
});
var mistralChatChunkSchema = z2.object({
id: z2.string().nullish(),
created: z2.number().nullish(),
model: z2.string().nullish(),
choices: z2.array(
z2.object({
delta: z2.object({
role: z2.enum(["assistant"]).optional(),
content: mistralContentSchema,
tool_calls: z2.array(
z2.object({
id: z2.string(),
function: z2.object({ name: z2.string(), arguments: z2.string() })
})
).nullish()
}),
finish_reason: z2.string().nullish(),
index: z2.number()
})
),
usage: z2.object({
prompt_tokens: z2.number(),
completion_tokens: z2.number()
}).nullish()
});
// src/mistral-embedding-model.ts
import {
TooManyEmbeddingValuesForCallError
} from "@ai-toolkit/provider";
import {
combineHeaders as combineHeaders2,
createJsonResponseHandler as createJsonResponseHandler2,
postJsonToApi as postJsonToApi2
} from "@ai-toolkit/provider-utils";
import { z as z3 } from "zod";
var MistralEmbeddingModel = class {
constructor(modelId, settings, config) {
this.specificationVersion = "v1";
this.modelId = modelId;
this.settings = settings;
this.config = config;
}
get provider() {
return this.config.provider;
}
get maxEmbeddingsPerCall() {
var _a;
return (_a = this.settings.maxEmbeddingsPerCall) != null ? _a : 32;
}
get supportsParallelCalls() {
var _a;
return (_a = this.settings.supportsParallelCalls) != null ? _a : false;
}
async doEmbed({
values,
abortSignal,
headers
}) {
if (values.length > this.maxEmbeddingsPerCall) {
throw new TooManyEmbeddingValuesForCallError({
provider: this.provider,
modelId: this.modelId,
maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
values
});
}
const { responseHeaders, value: response } = await postJsonToApi2({
url: `${this.config.baseURL}/embeddings`,
headers: combineHeaders2(this.config.headers(), headers),
body: {
model: this.modelId,
input: values,
encoding_format: "float"
},
failedResponseHandler: mistralFailedResponseHandler,
successfulResponseHandler: createJsonResponseHandler2(
MistralTextEmbeddingResponseSchema
),
abortSignal,
fetch: this.config.fetch
});
return {
embeddings: response.data.map((item) => item.embedding),
usage: response.usage ? { tokens: response.usage.prompt_tokens } : void 0,
rawResponse: { headers: responseHeaders }
};
}
};
var MistralTextEmbeddingResponseSchema = z3.object({
data: z3.array(z3.object({ embedding: z3.array(z3.number()) })),
usage: z3.object({ prompt_tokens: z3.number() }).nullish()
});
// src/mistral-provider.ts
function createMistral(options = {}) {
var _a;
const baseURL = (_a = withoutTrailingSlash(options.baseURL)) != null ? _a : "https://api.mistral.ai/v1";
const getHeaders = () => ({
Authorization: `Bearer ${loadApiKey({
apiKey: options.apiKey,
environmentVariableName: "MISTRAL_API_KEY",
description: "Mistral"
})}`,
...options.headers
});
const createChatModel = (modelId, settings = {}) => new MistralChatLanguageModel(modelId, settings, {
provider: "mistral.chat",
baseURL,
headers: getHeaders,
fetch: options.fetch
});
const createEmbeddingModel = (modelId, settings = {}) => new MistralEmbeddingModel(modelId, settings, {
provider: "mistral.embedding",
baseURL,
headers: getHeaders,
fetch: options.fetch
});
const provider = function(modelId, settings) {
if (new.target) {
throw new Error(
"The Mistral model function cannot be called with the new keyword."
);
}
return createChatModel(modelId, settings);
};
provider.languageModel = createChatModel;
provider.chat = createChatModel;
provider.embedding = createEmbeddingModel;
provider.textEmbedding = createEmbeddingModel;
provider.textEmbeddingModel = createEmbeddingModel;
return provider;
}
var mistral = createMistral();
export {
createMistral,
mistral
};
//# sourceMappingURL=index.mjs.map
```
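
The classes above implement the toolkit's low-level model interfaces, so they can also be exercised directly. Here is a rough sketch, using only option names that `getArgs` and `doEmbed` destructure above; the model IDs are illustrative, and real applications would normally go through the toolkit's higher-level helpers instead:

```js
import { mistral } from "@ai-toolkit/mistral";

// Low-level chat call. The prompt shape mirrors what
// convertToMistralChatMessages expects: system content is a string,
// user content is an array of parts.
const chat = mistral("mistral-small-latest"); // illustrative model ID
const { text, finishReason, usage } = await chat.doGenerate({
  mode: { type: "regular" },
  prompt: [
    { role: "system", content: "You are a terse assistant." },
    { role: "user", content: [{ type: "text", text: "Name one prime number." }] },
  ],
  temperature: 0,
  maxTokens: 32,
});

// Low-level embedding call. Passing more than maxEmbeddingsPerCall values
// (32 by default) throws TooManyEmbeddingValuesForCallError.
const embedder = mistral.textEmbeddingModel("mistral-embed"); // illustrative model ID
const { embeddings } = await embedder.doEmbed({
  values: ["first text", "second text"],
});
```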