UNPKG

jorel

Version:

The easiest way to use LLMs, including streams, images, documents, tools and various agent scenarios.

202 lines (201 loc) 8.22 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.MistralProvider = void 0; const mistralai_1 = require("@mistralai/mistralai"); const shared_1 = require("../../shared"); const providers_1 = require("../../providers"); const tools_1 = require("../../tools"); const convert_llm_message_1 = require("./convert-llm-message"); /** Provides access to OpenAI and other compatible services */ class MistralProvider { constructor({ apiKey } = {}) { this.name = "mistral"; this.client = new mistralai_1.Mistral({ apiKey: apiKey ?? process.env.MISTRAL_API_KEY, }); } async generateResponse(model, messages, config = {}) { const start = Date.now(); const temperature = config.temperature ?? undefined; const response = await this.client.chat.complete({ model, messages: await (0, convert_llm_message_1.convertLlmMessagesToMistralMessages)(messages), temperature, responseFormat: (0, providers_1.jsonResponseToOpenAi)(config.json), maxTokens: config.maxTokens, toolChoice: (0, providers_1.toolChoiceToOpenAi)(config.toolChoice), tools: config.tools?.asLlmFunctions?.map((f) => ({ type: "function", function: { name: f.function.name, description: f.function.description, parameters: { type: f.function.parameters?.type ?? "object", properties: f.function.parameters?.properties ?? {}, required: f.function.parameters?.required ?? [], }, }, })), }); const durationMs = Date.now() - start; const inputTokens = response.usage?.promptTokens; const outputTokens = response.usage?.completionTokens; const message = response.choices ? (0, shared_1.firstEntry)(response.choices)?.message : undefined; const content = Array.isArray(message?.content) ? message.content.map((c) => (c.type === "text" ? c.text : "")).join("") : (message?.content ?? null); const toolCalls = message?.toolCalls?.map((call) => { return { id: (0, shared_1.generateUniqueId)(), request: { id: call.id ?? (0, shared_1.generateUniqueId)(), function: { name: call.function.name, arguments: typeof call.function.arguments == "string" ? tools_1.LlmToolKit.deserialize(call.function.arguments) : call.function.arguments, }, }, approvalState: config.tools?.getTool(call.function.name)?.requiresConfirmation ? "requiresApproval" : "noApprovalRequired", executionState: "pending", result: null, error: null, }; }); const provider = this.name; return { ...(0, providers_1.generateAssistantMessage)(content, toolCalls), meta: { model, provider, temperature, durationMs, inputTokens, outputTokens, }, }; } async *generateResponseStream(model, messages, config = {}) { const start = Date.now(); const temperature = config.temperature ?? undefined; const response = await this.client.chat.stream({ model, messages: await (0, convert_llm_message_1.convertLlmMessagesToMistralMessages)(messages), temperature, responseFormat: (0, providers_1.jsonResponseToOpenAi)(config.json), maxTokens: config.maxTokens, stream: true, tools: config.tools?.asLlmFunctions?.map((f) => ({ type: "function", function: { name: f.function.name, description: f.function.description, parameters: { type: f.function.parameters?.type ?? "object", properties: f.function.parameters?.properties ?? {}, required: f.function.parameters?.required ?? [], }, }, })), toolChoice: (0, providers_1.toolChoiceToOpenAi)(config.toolChoice), }); let inputTokens; let outputTokens; const _toolCalls = []; let content = ""; for await (const chunk of response) { const delta = (0, shared_1.firstEntry)(chunk.data.choices)?.delta; if (delta?.content) { content += delta.content; yield { type: "chunk", content: typeof delta.content === "string" ? delta.content : delta.content.map((c) => (c.type === "text" ? c.text : "")).join(""), }; } if (delta?.toolCalls) { for (const toolCall of delta.toolCalls) { if (toolCall.index !== undefined) { const _toolCall = _toolCalls[toolCall.index] || { id: "", function: { name: "", arguments: "" } }; if (toolCall.id) _toolCall.id += toolCall.id; if (toolCall.function) { if (toolCall.function.name) _toolCall.function.name += toolCall.function.name; if (toolCall.function.arguments) _toolCall.function.arguments += toolCall.function.arguments; } _toolCalls[toolCall.index] = _toolCall; } } } if (chunk.data.usage) { inputTokens = chunk.data.usage?.promptTokens; outputTokens = chunk.data.usage?.completionTokens; } } const durationMs = Date.now() - start; const provider = this.name; const toolCalls = _toolCalls.map((call) => { return { id: (0, shared_1.generateUniqueId)(), request: { id: call.id, function: { name: call.function.name, arguments: tools_1.LlmToolKit.deserialize(call.function.arguments), }, }, approvalState: config.tools?.getTool(call.function.name)?.requiresConfirmation ? "requiresApproval" : "noApprovalRequired", executionState: "pending", result: null, error: null, }; }); const meta = { model, provider, temperature, durationMs, inputTokens, outputTokens, }; if (_toolCalls.length > 0) { yield { type: "response", role: "assistant_with_tools", content, toolCalls, meta, }; } else { yield { type: "response", role: "assistant", content, meta, }; } } async getAvailableModels() { const models = await this.client.models.list(); return models.data?.map((model) => model.id) ?? []; } async createEmbedding(model, text) { const response = await this.client.embeddings.create({ model, inputs: text, }); if (!response || !response.data || !response.data || response.data.length === 0 || !response.data[0].embedding) { throw new Error("Failed to create embedding"); } return response.data[0].embedding; } } exports.MistralProvider = MistralProvider;