jorel
Version:
The easiest way to use LLMs, including streams, images, documents, tools and various agent scenarios.
202 lines (201 loc) • 8.22 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.MistralProvider = void 0;
const mistralai_1 = require("@mistralai/mistralai");
const shared_1 = require("../../shared");
const providers_1 = require("../../providers");
const tools_1 = require("../../tools");
const convert_llm_message_1 = require("./convert-llm-message");
/**
 * Converts the provider-agnostic tool kit into the `tools` array expected by the
 * Mistral chat API. Missing parameter schema parts default to an empty object schema.
 * @param {object | undefined} tools - Tool kit exposing `asLlmFunctions`.
 * @returns {Array<object> | undefined} Mistral tool definitions, or `undefined` when no tools are configured.
 */
function toMistralTools(tools) {
    return tools?.asLlmFunctions?.map((f) => ({
        type: "function",
        function: {
            name: f.function.name,
            description: f.function.description,
            parameters: {
                type: f.function.parameters?.type ?? "object",
                properties: f.function.parameters?.properties ?? {},
                required: f.function.parameters?.required ?? [],
            },
        },
    }));
}
/**
 * Flattens Mistral message content to plain text.
 * Mistral returns either a string or an array of content chunks. Only chunks
 * with `type === "text"` contribute text; other chunk types are dropped.
 * @param {string | Array<{type: string, text?: string}> | null | undefined} content
 * @returns {string | null | undefined} The joined text for array content, otherwise `content` unchanged.
 */
function mistralContentToText(content) {
    if (Array.isArray(content)) {
        return content.map((c) => (c.type === "text" ? c.text : "")).join("");
    }
    return content;
}
/**
 * Merges streamed tool-call deltas into `accumulated` (mutated in place).
 * Deltas are grouped by their `index`. A delta without an index is treated as a
 * new tool call instead of being dropped. Argument fragments are concatenated
 * as strings. Object arguments (allowed by Mistral's `FunctionCall` type) are
 * JSON-serialized so they are not coerced to "[object Object]".
 * @param {Array<{id: string, function: {name: string, arguments: string}}>} accumulated
 * @param {Array<object>} deltas - `delta.toolCalls` from one stream chunk.
 */
function mergeToolCallDeltas(accumulated, deltas) {
    for (const delta of deltas) {
        const index = delta.index ?? accumulated.length;
        const entry = accumulated[index] ?? { id: "", function: { name: "", arguments: "" } };
        if (delta.id) {
            entry.id += delta.id;
        }
        if (delta.function) {
            if (delta.function.name) {
                entry.function.name += delta.function.name;
            }
            const args = delta.function.arguments;
            if (typeof args === "string") {
                entry.function.arguments += args;
            }
            else if (args != null) {
                entry.function.arguments += JSON.stringify(args);
            }
        }
        accumulated[index] = entry;
    }
}
/**
 * Converts a Mistral tool call into the library's pending tool-call record.
 * String arguments are deserialized via `LlmToolKit.deserialize`; object arguments are passed through.
 * An absent or empty provider id is replaced by a generated one.
 * @param {{id?: string, function: {name: string, arguments: unknown}}} call
 * @param {object | undefined} tools - Tool kit used to look up `requiresConfirmation`.
 */
function toLlmToolCall(call, tools) {
    const args = call.function.arguments;
    return {
        id: (0, shared_1.generateUniqueId)(),
        request: {
            id: call.id || (0, shared_1.generateUniqueId)(),
            function: {
                name: call.function.name,
                arguments: typeof args === "string" ? tools_1.LlmToolKit.deserialize(args) : args,
            },
        },
        approvalState: tools?.getTool(call.function.name)?.requiresConfirmation
            ? "requiresApproval"
            : "noApprovalRequired",
        executionState: "pending",
        result: null,
        error: null,
    };
}
/**
 * Provides access to Mistral AI models (chat completion, streaming chat, model
 * listing and embeddings) through the `@mistralai/mistralai` client.
 */
class MistralProvider {
    /**
     * @param {{apiKey?: string}} [options] - `apiKey` falls back to `process.env.MISTRAL_API_KEY`.
     */
    constructor({ apiKey } = {}) {
        this.name = "mistral";
        this.client = new mistralai_1.Mistral({
            apiKey: apiKey ?? process.env.MISTRAL_API_KEY,
        });
    }
    /**
     * Requests a single (non-streamed) chat completion.
     * @param {string} model - Mistral model id.
     * @param {Array<object>} messages - Library messages, converted via `convertLlmMessagesToMistralMessages`.
     * @param {object} [config] - Reads `temperature`, `json`, `maxTokens`, `toolChoice` and `tools`.
     * @returns {Promise<object>} The assistant message built by `generateAssistantMessage`, plus a `meta` block with timing and token usage.
     */
    async generateResponse(model, messages, config = {}) {
        const start = Date.now();
        const temperature = config.temperature;
        const response = await this.client.chat.complete({
            model,
            messages: await (0, convert_llm_message_1.convertLlmMessagesToMistralMessages)(messages),
            temperature,
            responseFormat: (0, providers_1.jsonResponseToOpenAi)(config.json),
            maxTokens: config.maxTokens,
            toolChoice: (0, providers_1.toolChoiceToOpenAi)(config.toolChoice),
            tools: toMistralTools(config.tools),
        });
        const durationMs = Date.now() - start;
        const inputTokens = response.usage?.promptTokens;
        const outputTokens = response.usage?.completionTokens;
        const message = response.choices ? (0, shared_1.firstEntry)(response.choices)?.message : undefined;
        const content = mistralContentToText(message?.content) ?? null;
        const toolCalls = message?.toolCalls?.map((call) => toLlmToolCall(call, config.tools));
        const provider = this.name;
        return {
            ...(0, providers_1.generateAssistantMessage)(content, toolCalls),
            meta: {
                model,
                provider,
                temperature,
                durationMs,
                inputTokens,
                outputTokens,
            },
        };
    }
    /**
     * Streams a chat completion.
     * Yields `{type: "chunk", content}` for every non-empty text delta, then one final
     * `{type: "response", ...}` object. The final object has role `assistant_with_tools`
     * when tool calls were streamed, and `assistant` otherwise.
     * @param {string} model - Mistral model id.
     * @param {Array<object>} messages - Library messages, converted via `convertLlmMessagesToMistralMessages`.
     * @param {object} [config] - Reads `temperature`, `json`, `maxTokens`, `toolChoice` and `tools`.
     */
    async *generateResponseStream(model, messages, config = {}) {
        const start = Date.now();
        const temperature = config.temperature;
        const response = await this.client.chat.stream({
            model,
            messages: await (0, convert_llm_message_1.convertLlmMessagesToMistralMessages)(messages),
            temperature,
            responseFormat: (0, providers_1.jsonResponseToOpenAi)(config.json),
            maxTokens: config.maxTokens,
            stream: true,
            tools: toMistralTools(config.tools),
            toolChoice: (0, providers_1.toolChoiceToOpenAi)(config.toolChoice),
        });
        let inputTokens;
        let outputTokens;
        const partialToolCalls = [];
        let content = "";
        for await (const chunk of response) {
            const delta = (0, shared_1.firstEntry)(chunk.data.choices)?.delta;
            const text = mistralContentToText(delta?.content);
            if (text) {
                // Accumulate the flattened text, not the raw delta: array content would otherwise become "[object Object]".
                content += text;
                yield {
                    type: "chunk",
                    content: text,
                };
            }
            if (delta?.toolCalls) {
                mergeToolCallDeltas(partialToolCalls, delta.toolCalls);
            }
            if (chunk.data.usage) {
                inputTokens = chunk.data.usage.promptTokens;
                outputTokens = chunk.data.usage.completionTokens;
            }
        }
        const durationMs = Date.now() - start;
        const provider = this.name;
        // filter() skips holes, so gaps in the streamed indices do not leak into the result.
        const toolCalls = partialToolCalls.filter(Boolean).map((call) => toLlmToolCall(call, config.tools));
        const meta = {
            model,
            provider,
            temperature,
            durationMs,
            inputTokens,
            outputTokens,
        };
        if (toolCalls.length > 0) {
            yield {
                type: "response",
                role: "assistant_with_tools",
                content,
                toolCalls,
                meta,
            };
        }
        else {
            yield {
                type: "response",
                role: "assistant",
                content,
                meta,
            };
        }
    }
    /**
     * Lists the ids of models available to the configured API key.
     * @returns {Promise<string[]>} Model ids, or an empty array when the response has no data.
     */
    async getAvailableModels() {
        const models = await this.client.models.list();
        return models.data?.map((model) => model.id) ?? [];
    }
    /**
     * Creates an embedding for `text` with the given model.
     * @param {string} model - Mistral embedding model id.
     * @param {string} text - Input passed as `inputs` to the embeddings API.
     * @returns {Promise<number[]>} The first embedding returned.
     * @throws {Error} "Failed to create embedding" when the response contains no embedding.
     */
    async createEmbedding(model, text) {
        const response = await this.client.embeddings.create({
            model,
            inputs: text,
        });
        const embedding = response?.data?.[0]?.embedding;
        if (!embedding) {
            throw new Error("Failed to create embedding");
        }
        return embedding;
    }
}
// Public CommonJS export of the provider class (declared as `void 0` at the top of the module).
exports.MistralProvider = MistralProvider;