UNPKG

rivet-plugin-mistral

Version:
807 lines (804 loc) 24.7 kB
// src/mistral.ts

/**
 * Models offered in the node's "Model" dropdown, keyed by Mistral model id.
 *
 * `cost` holds display labels per currency. For token-billed models the labels are prices per
 * 1M tokens (the node body renders them as "<label>/1M prompt tokens"); "-" marks a direction
 * that is not billed, and the OCR entry is labelled per page rather than per token.
 */
var mistralModels = {
  // Premier models
  "mistral-large-latest": {
    maxTokens: 131072,
    cost: { prompt: { USD: "$2", EUR: "1.8 \u20AC" }, completion: { USD: "$6", EUR: "5.4 \u20AC" } },
    displayName: "Mistral Large 24.11",
    contextLength: 131072
  },
  "pixtral-large-latest": {
    maxTokens: 131072,
    cost: { prompt: { USD: "$2", EUR: "1.8 \u20AC" }, completion: { USD: "$6", EUR: "5.4 \u20AC" } },
    displayName: "Pixtral Large",
    contextLength: 131072
  },
  "mistral-saba-latest": {
    maxTokens: 32768,
    cost: { prompt: { USD: "$0.2", EUR: "0.2 \u20AC" }, completion: { USD: "$0.6", EUR: "0.6 \u20AC" } },
    displayName: "Mistral Saba",
    contextLength: 32768
  },
  "codestral-latest": {
    maxTokens: 262144,
    // 256k
    cost: { prompt: { USD: "$0.3", EUR: "0.3 \u20AC" }, completion: { USD: "$0.9", EUR: "0.9 \u20AC" } },
    displayName: "Codestral",
    contextLength: 262144
  },
  "ministral-8b-latest": {
    maxTokens: 131072,
    cost: { prompt: { USD: "$0.1", EUR: "0.09 \u20AC" }, completion: { USD: "$0.1", EUR: "0.09 \u20AC" } },
    displayName: "Ministral 8B 24.10",
    contextLength: 131072
  },
  "ministral-3b-latest": {
    maxTokens: 131072,
    cost: { prompt: { USD: "$0.04", EUR: "0.04 \u20AC" }, completion: { USD: "$0.04", EUR: "0.04 \u20AC" } },
    displayName: "Ministral 3B 24.10",
    contextLength: 131072
  },
  "mistral-embed": {
    maxTokens: 8192,
    cost: { prompt: { USD: "$0.1", EUR: "0.09 \u20AC" }, completion: { USD: "-", EUR: "-" } },
    displayName: "Mistral Embed",
    contextLength: 8192
  },
  "mistral-moderation-latest": {
    maxTokens: 8192,
    cost: { prompt: { USD: "$0.1", EUR: "0.09 \u20AC" }, completion: { USD: "-", EUR: "-" } },
    displayName: "Mistral Moderation 24.11",
    contextLength: 8192
  },
  "mistral-ocr-latest": {
    maxTokens: 0,
    // Not applicable for OCR
    cost: { prompt: { USD: "1000 Pages / $1", EUR: "1000 Pages / 1\u20AC" }, completion: { USD: "-", EUR: "-" } },
    displayName: "Mistral OCR",
    contextLength: 0
  },
  // Other models
  "mistral-small-latest": {
    maxTokens: 131072,
    cost: { prompt: { USD: "$0.1", EUR: "0.09 \u20AC" }, completion: { USD: "$0.3", EUR: "0.27 \u20AC" } },
    displayName: "Mistral Small",
    contextLength: 131072
  },
  "open-mistral-7b": {
    maxTokens: 32768,
    cost: { prompt: { USD: "$0.25", EUR: "0.23 \u20AC" }, completion: { USD: "$0.25", EUR: "0.23 \u20AC" } },
    displayName: "Open Mistral 7B",
    contextLength: 32768
  },
  "open-mixtral-8x7b": {
    maxTokens: 32768,
    cost: { prompt: { USD: "$0.7", EUR: "0.63 \u20AC" }, completion: { USD: "$0.7", EUR: "0.63 \u20AC" } },
    displayName: "Open Mixtral 8x7B",
    contextLength: 32768
  },
  "open-mixtral-8x22b": {
    maxTokens: 64e3,
    cost: { prompt: { USD: "$2", EUR: "1.8 \u20AC" }, completion: { USD: "$6", EUR: "5.4 \u20AC" } },
    displayName: "Open Mixtral 8x22B",
    contextLength: 64e3
  }
};

/** Dropdown options derived from `mistralModels` (value = model id, label = display name). */
var mistralModelOptions = Object.entries(mistralModels).map(([id, { displayName }]) => ({
  value: id,
  label: displayName
}));

/** Currency used when a node's data carries no `currency` value. */
var DEFAULT_CURRENCY = "USD";

/**
 * Returns the catalogue entry for `model`, or undefined for ids not listed in `mistralModels`.
 * Uses an own-property check so ids such as "toString" do not resolve to Object.prototype members.
 */
function getModelInfo(model) {
  return Object.prototype.hasOwnProperty.call(mistralModels, model) ? mistralModels[model] : void 0;
}

/**
 * Extracts the numeric amount from a per-1M-token price label such as "$2" or "1.8 €".
 * Labels that are not a plain price ("-", the OCR "1000 Pages / $1" label, or a missing value
 * for an unsupported currency) yield 0, so they never produce NaN or bogus token costs.
 */
function parsePricePerMillion(label) {
  if (typeof label !== "string") return 0;
  const match = /^\s*\$?\s*(\d+(?:\.\d+)?)\s*\u20AC?\s*$/.exec(label);
  return match ? parseFloat(match[1]) : 0;
}

/**
 * Copies the token counts out of a Mistral `usage` object (which may be absent),
 * replacing missing or non-numeric counts with 0.
 */
function normalizeUsage(usage) {
  const count = (value) => (typeof value === "number" && Number.isFinite(value) ? value : 0);
  return {
    prompt_tokens: count(usage?.prompt_tokens),
    completion_tokens: count(usage?.completion_tokens),
    total_tokens: count(usage?.total_tokens)
  };
}

/**
 * Estimates the cost of a call in hundredths of `currency` (cents / euro cents), rounded to
 * 4 decimal places. Unknown models and non-per-token price labels contribute 0.
 */
function estimateCostCents(model, usage, currency) {
  const info = getModelInfo(model);
  if (!info) return 0;
  const promptRate = parsePricePerMillion(info.cost.prompt[currency]);
  const completionRate = parsePricePerMillion(info.cost.completion[currency]);
  const promptCost = usage.prompt_tokens / 1e6 * promptRate;
  const completionCost = usage.completion_tokens / 1e6 * completionRate;
  return Number(((promptCost + completionCost) * 100).toFixed(4));
}

/** Builds the value for the node's "Token Details" output from normalized usage counts. */
function buildTokenDetails(model, usage, currency) {
  return {
    type: "object",
    value: {
      prompt: usage.prompt_tokens,
      completion: usage.completion_tokens,
      total: usage.total_tokens,
      estimatedCostCents: estimateCostCents(model, usage, currency),
      currency
    }
  };
}

/**
 * Flattens a Rivet chat-message body to the plain text sent to Mistral.
 *
 * A string is used as-is. For an array of parts, the string parts are joined with "\n" and
 * non-string parts are dropped, because this node only sends text content (calling
 * `toString()` would yield "a,b" or "[object Object]"). Any other value becomes "".
 */
function messageContentToText(message) {
  if (typeof message === "string") return message;
  if (Array.isArray(message)) {
    return message.filter((part) => typeof part === "string").join("\n");
  }
  return "";
}

/**
 * Normalizes assistant `content` returned by Mistral (a full message or a streamed delta) to a
 * string. Plain strings pass through; for an array, the `text` of entries shaped
 * `{ type: "text", text }` is concatenated (chunk shape per Mistral's API docs — confirm when
 * adding new models). Anything else yields "".
 */
function assistantContentToText(content) {
  if (typeof content === "string") return content;
  if (Array.isArray(content)) {
    return content
      .filter((chunk) => chunk && chunk.type === "text" && typeof chunk.text === "string")
      .map((chunk) => chunk.text)
      .join("");
  }
  return "";
}

/**
 * Converts a Rivet chat message into a Mistral `{ role, content }` message.
 * "system" and "assistant" keep their role; every other type (including "user") is sent as "user".
 */
function convertToMistralMessage(rivetMessage) {
  const role = rivetMessage.type === "system" || rivetMessage.type === "assistant" ? rivetMessage.type : "user";
  return { role, content: messageContentToText(rivetMessage.message) };
}

/** Converts a Mistral `{ role, content }` message back into a Rivet chat message. */
function toRivetMessage(mistralMessage) {
  if (mistralMessage.role === "system") return createSystemMessage(mistralMessage.content);
  if (mistralMessage.role === "user") return createUserMessage(mistralMessage.content);
  return createAssistantMessage(mistralMessage.content);
}

/** Creates a Rivet assistant chat message with no function calls. */
function createAssistantMessage(content) {
  return { type: "assistant", message: content, function_call: void 0, function_calls: [] };
}

/** Creates a Rivet system chat message. */
function createSystemMessage(content) {
  return { type: "system", message: content };
}

/** Creates a Rivet user chat message. */
function createUserMessage(content) {
  return { type: "user", message: content };
}

// src/nodes/mistralChatNode.ts

/**
 * Assembles the Mistral `messages` array: an optional system prompt followed by either the
 * "messages" input (when `useMessagesInput`) or the "prompt" input.
 *
 * @throws Error when the required input is missing or has an unsupported type.
 */
function buildMistralMessages(rivet, data, inputs, systemPrompt) {
  const messages = [];
  if (systemPrompt?.trim()) {
    messages.push({ role: "system", content: systemPrompt });
  }

  if (data.useMessagesInput) {
    const inputMessages = rivet.coerceType(inputs["messages"], "chat-message[]");
    if (!inputMessages) {
      throw new Error("Invalid messages input format");
    }
    // Same conversion as the prompt path, so roles are always ones Mistral accepts.
    messages.push(...inputMessages.map(convertToMistralMessage));
    return messages;
  }

  const promptInput = inputs["prompt"];
  if (!promptInput) {
    throw new Error("No prompt provided");
  }
  let userMessages;
  switch (promptInput.type) {
    case "chat-message":
      userMessages = [promptInput.value];
      break;
    case "chat-message[]":
      userMessages = promptInput.value;
      break;
    case "string":
      userMessages = [createUserMessage(promptInput.value)];
      break;
    case "string[]":
      userMessages = promptInput.value.map((v) => createUserMessage(v));
      break;
    default:
      throw new Error(`Invalid prompt format: ${promptInput.type}`);
  }
  messages.push(...userMessages.map(convertToMistralMessage));
  return messages;
}

/**
 * Writes the "response", "message" and "messages" outputs for the given assistant text.
 * "messages" is the request conversation plus the new assistant message.
 */
function setResponseOutputs(output, text, messages) {
  const assistantMessage = createAssistantMessage(text);
  output["response"] = { type: "string", value: text };
  output["message"] = { type: "chat-message", value: assistantMessage };
  output["messages"] = { type: "chat-message[]", value: [...messages.map(toRivetMessage), assistantMessage] };
}

/**
 * Consumes a streamed (SSE) chat completion and returns the node outputs.
 *
 * Each non-empty text delta updates the outputs and is reported through `onPartialOutputs`
 * (with "tokenDetails" once a usage chunk has been seen). When no usage chunk arrives, the
 * completion token count is estimated as ceil(characters / 4) and the cost is reported as 0.
 */
async function readStreamingResponse(response, { model, currency, messages, onPartialOutputs }) {
  const reader = response.body?.getReader();
  if (!reader) {
    throw new Error("No response body");
  }
  const decoder = new TextDecoder();
  const output = {};
  let buffer = "";
  let responseText = "";
  let tokenUsage;

  const handleLine = (line) => {
    // trim() also strips the "\r" left behind by CRLF line endings.
    const trimmed = line.trim();
    if (!trimmed.startsWith("data:")) return;
    const payload = trimmed.slice(5).trim();
    if (payload === "" || payload === "[DONE]") return;

    let chunk;
    try {
      chunk = JSON.parse(payload);
    } catch (e) {
      console.error("Error parsing JSON from stream:", e);
      return;
    }

    if (chunk.usage && chunk.usage.total_tokens) {
      console.log("Found token usage in streaming response:", chunk.usage);
      tokenUsage = normalizeUsage(chunk.usage);
    }

    // Usage-only chunks may carry no `choices`; optional chaining avoids a TypeError there.
    const content = assistantContentToText(chunk.choices?.[0]?.delta?.content);
    if (content) {
      responseText += content;
      setResponseOutputs(output, responseText, messages);
      if (tokenUsage) {
        output["tokenDetails"] = buildTokenDetails(model, tokenUsage, currency);
      }
      onPartialOutputs({ ...output });
    }
  };

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    // The last element may be an incomplete line; keep it until more data (or EOF) arrives.
    buffer = lines.pop() ?? "";
    lines.forEach(handleLine);
  }
  // Flush the decoder and process a final line that arrived without a trailing newline
  // (this is frequently the chunk carrying `usage`).
  buffer += decoder.decode();
  if (buffer) handleLine(buffer);

  // Always publish the text outputs, even when the stream produced no content.
  setResponseOutputs(output, responseText, messages);

  if (tokenUsage) {
    output["tokenDetails"] = buildTokenDetails(model, tokenUsage, currency);
  } else {
    console.log("Token usage not found in streaming response, using estimates");
    const estimatedCompletionTokens = Math.ceil(responseText.length / 4);
    output["tokenDetails"] = {
      type: "object",
      value: {
        note: "Token details estimated - actual counts not available in streaming mode",
        prompt: 0, // Unknown
        completion: estimatedCompletionTokens,
        total: estimatedCompletionTokens, // Incomplete total
        estimatedCostCents: 0, // Can't calculate accurately without prompt tokens
        currency
      }
    };
  }
  return output;
}

/**
 * Reads a non-streamed chat completion and returns the node outputs.
 * Missing `choices`/`usage` fields produce "" and zero counts instead of a TypeError.
 */
async function readJsonResponse(response, { model, currency, messages }) {
  const json = await response.json();
  const content = assistantContentToText(json.choices?.[0]?.message?.content);
  const usage = normalizeUsage(json.usage);
  const output = {};
  setResponseOutputs(output, content, messages);
  output["tokenDetails"] = buildTokenDetails(model, usage, currency);

  const currencyLabel = currency === "USD" ? "cents" : "euro cents";
  console.log(
    [
      "Mistral API call:",
      `Model: ${model}`,
      `Prompt tokens: ${usage.prompt_tokens}`,
      `Completion tokens: ${usage.completion_tokens}`,
      `Total tokens: ${usage.total_tokens}`,
      `Estimated cost: ${output["tokenDetails"].value.estimatedCostCents} ${currencyLabel}`
    ].join("\n")
  );
  return output;
}

/**
 * Builds the "Mistral Chat" plugin node definition.
 *
 * @param rivet The Rivet core API handed to the plugin initializer.
 */
function mistralChatNode_default(rivet) {
  const nodeImpl = {
    create() {
      return {
        id: rivet.newId(),
        type: "mistralChat",
        title: "Mistral Chat",
        data: {
          model: "mistral-large-latest",
          useModelInput: false,
          temperature: 0.5,
          useTemperatureInput: false,
          maxTokens: 4096,
          useMaxTokensInput: false,
          topP: 1,
          useTopPInput: false,
          systemPrompt: "You are a helpful assistant.",
          useSystemPromptInput: true,
          useMessagesInput: false,
          useStream: true,
          useSafePrompt: false,
          useRandomSeed: false,
          randomSeed: void 0,
          currency: "USD" // Default to USD
        },
        visualData: { x: 0, y: 0, width: 300 }
      };
    },

    getInputDefinitions(data) {
      const inputs = [];
      if (data.useModelInput) {
        inputs.push({ id: "model", title: "Model", dataType: "string", required: false });
      }
      if (data.useSystemPromptInput) {
        inputs.push({ id: "systemPrompt", title: "System Prompt", dataType: "string", required: false });
      }
      if (data.useTemperatureInput) {
        inputs.push({ dataType: "number", id: "temperature", title: "Temperature" });
      }
      if (data.useTopPInput) {
        // Port id is "top_p" (not the "topP" data key); process() reads it explicitly.
        inputs.push({ dataType: "number", id: "top_p", title: "Top P" });
      }
      if (data.useMaxTokensInput) {
        inputs.push({ dataType: "number", id: "maxTokens", title: "Max Tokens" });
      }
      if (data.useMessagesInput) {
        inputs.push({ dataType: "chat-message[]", id: "messages", title: "Messages" });
      } else {
        inputs.push({ dataType: ["chat-message", "chat-message[]", "string", "string[]"], id: "prompt", title: "Prompt" });
      }
      return inputs;
    },

    getOutputDefinitions() {
      return [
        { id: "response", title: "Response", dataType: "string" },
        { id: "message", title: "Message", dataType: "chat-message" },
        { id: "messages", title: "All Messages", dataType: "chat-message[]" },
        { id: "tokenDetails", title: "Token Details", dataType: "object" }
      ];
    },

    getEditors() {
      return [
        { type: "dropdown", label: "Model", dataKey: "model", useInputToggleDataKey: "useModelInput", options: mistralModelOptions },
        { type: "string", label: "System Prompt", dataKey: "systemPrompt", useInputToggleDataKey: "useSystemPromptInput" },
        { type: "number", label: "Temperature", dataKey: "temperature", useInputToggleDataKey: "useTemperatureInput", min: 0, max: 1.5, step: 0.1 },
        { type: "number", label: "Top P", dataKey: "topP", useInputToggleDataKey: "useTopPInput", min: 0, max: 1, step: 0.1 },
        { type: "number", label: "Max Tokens", dataKey: "maxTokens", useInputToggleDataKey: "useMaxTokensInput", min: 0, step: 1 },
        { type: "toggle", label: "Use Messages Input", dataKey: "useMessagesInput" },
        { type: "toggle", label: "Stream Responses", dataKey: "useStream" },
        { type: "toggle", label: "Use Safe Prompt", dataKey: "useSafePrompt" },
        { type: "toggle", label: "Use Random Seed", dataKey: "useRandomSeed" },
        { type: "number", label: "Random Seed", dataKey: "randomSeed", min: 0, step: 1 },
        {
          type: "dropdown",
          label: "Currency",
          dataKey: "currency",
          options: [
            { value: "USD", label: "USD ($)" },
            { value: "EUR", label: "EUR (\u20AC)" }
          ]
        }
      ];
    },

    getUIData() {
      return {
        contextMenuTitle: "Mistral Chat",
        group: "AI/Chat (Mistral)",
        infoBoxBody: "Makes a call to Mistral AI's chat completion API.\nSupports all available Mistral models and includes various parameters for fine-tuning the response.",
        infoBoxTitle: "Mistral Chat Node"
      };
    },

    /** Renders the node body: model, sampling settings, and per-1M-token prices in the chosen currency. */
    getBody(data) {
      const currency = data.currency ?? DEFAULT_CURRENCY;
      const modelInfo = getModelInfo(data.model) ?? {
        displayName: data.model,
        cost: { prompt: { USD: "-", EUR: "-" }, completion: { USD: "-", EUR: "-" } }
      };
      const promptPrice = modelInfo.cost.prompt[currency];
      const completionPrice = modelInfo.cost.completion[currency];
      return [
        `Model: ${modelInfo.displayName}`,
        `Temperature: ${data.temperature}`,
        `Max Tokens: ${data.maxTokens}`,
        `Top P: ${data.topP}`,
        `${promptPrice}/1M prompt tokens`,
        `${completionPrice}/1M completion tokens`
      ].join("\n");
    },

    /**
     * Calls the Mistral chat completions endpoint and returns the Response, Message,
     * All Messages and Token Details outputs.
     *
     * @throws Error when the API key is missing, inputs are invalid, or the API returns a non-2xx status.
     */
    async process(data, inputs, context) {
      try {
        console.log("Starting Mistral Chat node processing...");
        const apiKey = context.getPluginConfig("mistralApiKey");
        if (!apiKey) {
          throw new Error("Mistral API key not configured. Please add your API key in the plugin configuration.");
        }

        const currency = data.currency ?? DEFAULT_CURRENCY;
        const model = rivet.getInputOrData(data, inputs, "model", "string") ?? data.model;
        const temperature = rivet.getInputOrData(data, inputs, "temperature", "number") ?? data.temperature;
        const maxTokens = rivet.getInputOrData(data, inputs, "maxTokens", "number") ?? data.maxTokens;
        // getInputOrData("topP") would look for a "topP" port, but the port is "top_p";
        // read it directly so a connected Top P input actually takes effect.
        const topP = data.useTopPInput && inputs["top_p"] != null
          ? rivet.coerceType(inputs["top_p"], "number") ?? data.topP
          : data.topP;
        const systemPrompt = rivet.getInputOrData(data, inputs, "systemPrompt", "string") ?? data.systemPrompt;

        const messages = buildMistralMessages(rivet, data, inputs, systemPrompt);
        const requestBody = {
          model,
          messages,
          temperature,
          max_tokens: maxTokens,
          top_p: topP,
          stream: data.useStream,
          safe_prompt: data.useSafePrompt,
          random_seed: data.useRandomSeed ? data.randomSeed : void 0
        };

        const response = await fetch("https://api.mistral.ai/v1/chat/completions", {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
            "Authorization": `Bearer ${apiKey}`
          },
          body: JSON.stringify(requestBody)
        });
        if (!response.ok) {
          const errorText = await response.text();
          console.error("Mistral API error:", response.status, errorText);
          throw new Error(`Mistral API error: ${response.status} - ${errorText}`);
        }

        const responseContext = {
          model,
          currency,
          messages,
          // Called through `context` so the callback keeps its receiver.
          onPartialOutputs: (partial) => context.onPartialOutputs?.(partial)
        };
        return data.useStream
          ? await readStreamingResponse(response, responseContext)
          : await readJsonResponse(response, responseContext);
      } catch (error) {
        console.error("Error in Mistral Chat node:", error);
        throw error;
      }
    }
  };
  return rivet.pluginNodeDefinition(nodeImpl, "Mistral Chat");
}

// src/index.ts

/** Plugin entry point: declares the API-key config and registers the Mistral Chat node. */
var initializer = (rivet) => {
  console.log("Initializing Mistral plugin...");
  const node = mistralChatNode_default(rivet);
  console.log("Created node:", node);
  const plugin = {
    id: "rivet-plugin-mistral",
    name: "Mistral AI",
    configSpec: {
      mistralApiKey: {
        type: "secret",
        label: "Mistral API Key",
        description: "The API key for accessing Mistral AI.",
        pullEnvironmentVariable: "MISTRAL_API_KEY",
        helperText: "You may also set the MISTRAL_API_KEY environment variable."
      }
    },
    contextMenuGroups: [
      { id: "ai-chat-mistral", label: "AI/Chat (Mistral)" }
    ],
    register: (register) => {
      console.log("Registering Mistral node...");
      register(node);
    }
  };
  return plugin;
};
var src_default = initializer;
export {
  src_default as default
};