UNPKG

@genkit-ai/vertexai

Version:

Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.

362 lines 10.4 kB
import { MistralGoogleCloud } from "@mistralai/mistralai-gcp";
import {
  ChatCompletionChoiceFinishReason,
  ToolTypes,
} from "@mistralai/mistralai-gcp/models/components/index.js";
import { GenerationCommonConfigSchema, z } from "genkit";
import { GenerationCommonConfigDescriptions, modelRef } from "genkit/model";
import { model as pluginModel } from "genkit/plugin";
import { getGenkitClientHeader } from "../../common/index.mjs";
import { checkModelName } from "./utils.mjs";

/**
 * Configuration accepted by Mistral models on Vertex AI Model Garden.
 *
 * Extends Genkit's common generation config with an optional `location`
 * (the region used to select a client in `defineModel`) and a documented
 * `topP`. Unknown keys are preserved because of `.passthrough()`.
 */
const MistralConfigSchema = GenerationCommonConfigSchema.extend({
  // TODO: Update this with all the parameters in
  // https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post.
  location: z.string().optional(),
  topP: z
    .number()
    .describe(
      GenerationCommonConfigDescriptions.topP + " The default value is 1."
    )
    .optional(),
}).passthrough();

/**
 * Builds a Genkit model reference in the `vertex-model-garden/` namespace.
 *
 * @param {string} name - Model name without the namespace prefix.
 * @param {object} [info] - Capability info. When nullish, defaults to
 *   multiturn, tools and system role supported, no media, text output.
 * @param {object} [configSchema=MistralConfigSchema] - Config schema to attach.
 * @returns {object} The reference produced by `modelRef`.
 */
function commonRef(name, info, configSchema = MistralConfigSchema) {
  return modelRef({
    name: `vertex-model-garden/${name}`,
    configSchema,
    info: info ?? {
      supports: {
        multiturn: true,
        media: false,
        tools: true,
        systemRole: true,
        output: ["text"],
      },
    },
  });
}

/** Generic reference whose `info` is copied by `model()` for every version. */
const GENERIC_MODEL = commonRef("mistral");

/** Models registered up front by `listKnownModels`, keyed by model name. */
const KNOWN_MODELS = {
  "mistral-medium-3": commonRef("mistral-medium-3"),
  "mistral-ocr-2505": commonRef("mistral-ocr-2505"),
  "mistral-small-2503": commonRef("mistral-small-2503"),
  "codestral-2": commonRef("codestral-2"),
};

/**
 * Heuristic check for Mistral-family model names (e.g. "mistral-…",
 * "codestral-…").
 *
 * @param {string | undefined} value - Candidate model name.
 * @returns {boolean} True when `value` contains the substring "tral-".
 */
function isMistralModelName(value) {
  return !!value?.includes("tral-");
}

/**
 * Creates a model reference for an arbitrary Mistral model version.
 *
 * @param {string} version - Model name, passed through `checkModelName`.
 * @param {object} [options={}] - Default config stored on the reference.
 * @returns {object} A reference using `MistralConfigSchema` and a shallow
 *   copy of `GENERIC_MODEL.info`.
 */
function model(version, options = {}) {
  const name = checkModelName(version);
  return modelRef({
    name: `vertex-model-garden/${name}`,
    config: options,
    configSchema: MistralConfigSchema,
    info: { ...GENERIC_MODEL.info },
  });
}

/**
 * Dynamic action discovery is not implemented for Mistral.
 *
 * @param {object} clientOptions - Unused.
 * @returns {Array} Always an empty array.
 */
function listActions(clientOptions) {
  return [];
}

/**
 * Defines a Genkit model action for every entry in `KNOWN_MODELS`.
 *
 * @param {object} clientOptions - Forwarded to `defineModel`.
 * @param {object} pluginOptions - Forwarded to `defineModel`.
 * @returns {Array} The defined model actions.
 */
function listKnownModels(clientOptions, pluginOptions) {
  return Object.keys(KNOWN_MODELS).map((name) =>
    defineModel(name, clientOptions, pluginOptions)
  );
}

/**
 * Defines a Genkit model action backed by the Mistral Vertex AI client.
 *
 * The region comes from `request.config.location`, falling back to
 * `clientOptions.location`. When streaming is requested, each chunk is
 * forwarded through `sendChunk`. The final response is then assembled from
 * the streamed chunks themselves. Previously a second, non-streaming
 * completion was issued, which doubled cost and could return output that
 * differed from what was streamed.
 *
 * @param {string} name - Model name (without namespace prefix).
 * @param {object} clientOptions - Reads `projectId` and `location`.
 * @param {object} pluginOptions - Currently unused.
 * @returns {object} The model action created by `pluginModel`.
 */
function defineModel(name, clientOptions, pluginOptions) {
  const ref = model(name);
  const getClient = createClientFactory(clientOptions.projectId);
  return pluginModel(
    {
      name: ref.name,
      ...ref.info,
      configSchema: ref.configSchema,
    },
    async (request, { streamingRequested, sendChunk }) => {
      const client = getClient(
        request.config?.location || clientOptions.location
      );
      // NOTE(review): presumably strips the "vertex-model-garden/" prefix
      // from ref.name -- confirm in ./utils.mjs.
      const modelVersion = checkModelName(ref.name);
      const mistralRequest = toMistralRequest(modelVersion, request);
      const mistralOptions = {
        fetchOptions: {
          headers: {
            "X-Goog-Api-Client": getGenkitClientHeader(),
          },
        },
      };

      if (!streamingRequested) {
        const response = await client.chat.complete(
          mistralRequest,
          mistralOptions
        );
        return fromMistralResponse(request, response);
      }

      const stream = await client.chat.stream(mistralRequest, mistralOptions);
      const chunks = [];
      for await (const event of stream) {
        chunks.push(event.data);
        const parts = fromMistralCompletionChunk(event.data);
        if (parts.length > 0) {
          sendChunk({ content: parts });
        }
      }
      return fromMistralResponse(request, aggregateStreamChunks(chunks));
    }
  );
}

/**
 * Folds streamed completion chunks into an object shaped like a
 * non-streaming chat completion response, suitable for `fromMistralResponse`.
 *
 * Assumes each tool call arrives complete in a single delta. The per-chunk
 * `JSON.parse` in `fromMistralCompletionChunk` relies on the same thing.
 *
 * @param {Array<object>} chunks - `event.data` values from `client.chat.stream`.
 * @returns {object} `{ id, model, created, usage, choices: [{ index, finishReason, message }] }`.
 */
function aggregateStreamChunks(chunks) {
  let text = "";
  const toolCalls = [];
  let finishReason;
  let usage;

  for (const chunk of chunks) {
    // NOTE(review): usage may only be reported on the last chunk -- confirm
    // against the Mistral streaming API. Keep the last value seen.
    if (chunk.usage) {
      usage = chunk.usage;
    }
    const choice = chunk.choices?.[0];
    if (!choice) {
      continue;
    }
    if (choice.finishReason) {
      finishReason = choice.finishReason;
    }
    const delta = choice.delta;
    if (typeof delta?.content === "string") {
      text += delta.content;
    }
    for (const toolCall of delta?.toolCalls ?? []) {
      // Mirrors fromMistralCompletionChunk, which skips calls with no function.
      if (toolCall.function) {
        toolCalls.push(toolCall);
      }
    }
  }

  const first = chunks[0];
  return {
    id: first?.id,
    model: first?.model,
    created: first?.created,
    usage,
    choices: [
      {
        index: 0,
        finishReason,
        message: {
          role: "assistant",
          content: text,
          ...(toolCalls.length > 0 && { toolCalls }),
        },
      },
    ],
  };
}

/**
 * Returns a function that lazily creates and caches one Mistral client per
 * region.
 *
 * @param {string} projectId - Google Cloud project passed to every client.
 * @returns {(region: string) => MistralGoogleCloud} Cached client lookup.
 *   Throws if `region` is falsy or if client construction fails. The
 *   construction error is attached as `cause`.
 */
function createClientFactory(projectId) {
  const clients = new Map();
  return (region) => {
    if (!region) {
      throw new Error("Region is required to create Mistral client");
    }
    try {
      if (!clients.has(region)) {
        clients.set(region, new MistralGoogleCloud({ region, projectId }));
      }
      return clients.get(region);
    } catch (error) {
      throw new Error(
        `Failed to create/retrieve Mistral client for region ${region}: ${error}`,
        { cause: error }
      );
    }
  };
}

/**
 * Maps a Genkit message role to the corresponding Mistral role.
 *
 * @param {string} role - One of "model", "user", "tool", "system".
 * @returns {string} The Mistral role.
 * @throws {Error} For any other role.
 */
function toMistralRole(role) {
  switch (role) {
    case "model":
      return "assistant";
    case "user":
      return "user";
    case "tool":
      return "tool";
    case "system":
      return "system";
    default:
      throw new Error(`Unknown role ${role}`);
  }
}

/**
 * Converts a Genkit tool request into Mistral `{ name, arguments }` form.
 *
 * @param {object} toolRequest - Genkit tool request (`name`, `input`).
 * @returns {{name: string, arguments: string}} Arguments are kept as-is
 *   when already a string, otherwise JSON-serialized.
 * @throws {Error} When the tool request has no name.
 */
function toMistralToolRequest(toolRequest) {
  if (!toolRequest.name) {
    throw new Error("Tool name is required");
  }
  return {
    name: toolRequest.name,
    // Mistral expects arguments as either a string or object
    arguments:
      typeof toolRequest.input === "string"
        ? toolRequest.input
        : JSON.stringify(toolRequest.input),
  };
}

/**
 * Converts one Genkit message into one or more Mistral messages.
 *
 * A message with tool requests becomes one assistant message with all of its
 * tool calls. A message with tool responses becomes one "tool" message per
 * response. Anything else becomes a single message whose content is the
 * concatenated text parts.
 *
 * @param {object} msg - Genkit message (`role`, `content` parts).
 * @returns {Array<object>} Mistral messages.
 */
function toMistralMessages(msg) {
  const toolRequests = msg.content
    .filter((part) => part.toolRequest)
    .map((part) => part.toolRequest);
  if (toolRequests.length > 0) {
    return [
      {
        role: "assistant",
        content: null,
        toolCalls: toolRequests.map((toolRequest) => {
          const functionCall = toMistralToolRequest(toolRequest);
          return {
            id: toolRequest.ref,
            type: ToolTypes.Function,
            function: {
              name: functionCall.name,
              arguments: functionCall.arguments,
            },
          };
        }),
      },
    ];
  }

  const toolResponses = msg.content
    .filter((part) => part.toolResponse)
    .map((part) => part.toolResponse);
  if (toolResponses.length > 0) {
    return toolResponses.map((toolResponse) => ({
      role: "tool",
      name: toolResponse.name,
      content: JSON.stringify(toolResponse.output),
      toolCallId: toolResponse.ref, // This must match the id from tool_calls
    }));
  }

  return [
    {
      role: toMistralRole(msg.role),
      content: msg.content.map((part) => part.text || "").join(""),
    },
  ];
}

/**
 * Builds a Mistral chat completion request from a Genkit generate request.
 *
 * When the config omits them, `maxTokens` defaults to 1024 and `temperature`
 * to 0.7. `topP` is forwarded whenever it is a number, including 0.
 *
 * @param {string} model2 - Mistral model id.
 * @param {object} input - Genkit request (`messages`, `config`, `tools`).
 * @returns {object} Mistral request body.
 * @throws {Error} If tool calls and tool responses do not line up (see
 *   `validateToolSequence`), or on an unknown role.
 */
function toMistralRequest(model2, input) {
  const messages = input.messages.flatMap(toMistralMessages);
  validateToolSequence(messages);

  const config = input.config;
  return {
    model: model2,
    messages,
    maxTokens: config?.maxOutputTokens ?? 1024,
    temperature: config?.temperature ?? 0.7,
    ...(typeof config?.topP === "number" && { topP: config.topP }),
    ...(config?.stopSequences && { stop: config.stopSequences }),
    ...(input.tools && {
      tools: input.tools.map((tool) => ({
        type: "function",
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.inputSchema || {},
        },
      })),
    }),
  };
}

/** Wraps a Mistral text string as a Genkit text part. */
function fromMistralTextPart(content) {
  return {
    text: content,
  };
}

/**
 * Parses Mistral tool-call arguments.
 *
 * @param {string | object} args - JSON string or already-parsed object.
 * @returns {*} Parsed arguments. Non-string values are returned unchanged.
 * @throws {SyntaxError} When a string argument is not valid JSON.
 */
function parseToolArguments(args) {
  return typeof args === "string" ? JSON.parse(args) : args;
}

/**
 * Converts a Mistral tool call into a Genkit tool request part.
 *
 * @param {object} toolCall - Mistral tool call (`id`, `function`).
 * @returns {object} `{ toolRequest: { ref, name, input } }`.
 * @throws {Error} When the tool call has no `function`.
 */
function fromMistralToolCall(toolCall) {
  if (!toolCall.function) {
    throw new Error("Tool call must include a function definition");
  }
  return {
    toolRequest: {
      ref: toolCall.id,
      name: toolCall.function.name,
      input: parseToolArguments(toolCall.function.arguments),
    },
  };
}

/**
 * Converts a Mistral assistant message into Genkit parts.
 *
 * String content yields one text part. For array content, only chunks of
 * type "text" are kept. Tool calls are appended after the text parts.
 *
 * @param {object} message - Mistral message.
 * @returns {Array<object>} Genkit parts.
 */
function fromMistralMessage(message) {
  const parts = [];
  if (typeof message.content === "string") {
    parts.push(fromMistralTextPart(message.content));
  } else if (Array.isArray(message.content)) {
    for (const chunk of message.content) {
      if (chunk.type === "text") {
        parts.push(fromMistralTextPart(chunk.text));
      }
    }
  }
  for (const toolCall of message.toolCalls ?? []) {
    parts.push(fromMistralToolCall(toolCall));
  }
  return parts;
}

/**
 * Maps a Mistral finish reason to a Genkit finish reason.
 *
 * @param {string | undefined} reason - Mistral finish reason.
 * @returns {string} "stop", "length", or "other". Tool calls map to "stop";
 *   errors and unrecognized or missing values map to "other".
 */
function fromMistralFinishReason(reason) {
  switch (reason) {
    case ChatCompletionChoiceFinishReason.Stop:
      return "stop";
    case ChatCompletionChoiceFinishReason.Length:
    case ChatCompletionChoiceFinishReason.ModelLength:
      return "length";
    case ChatCompletionChoiceFinishReason.Error:
      return "other"; // Map generic errors to "other"
    case ChatCompletionChoiceFinishReason.ToolCalls:
      return "stop"; // Assuming tool calls signify a "stop" in processing
    default:
      return "other";
  }
}

/**
 * Converts a Mistral chat completion response into a Genkit response.
 *
 * Only the first choice is used. Token counts are `undefined` when the
 * response carries no `usage`; previously this threw a TypeError.
 *
 * @param {object} _input - Original Genkit request (unused).
 * @param {object} response - Mistral chat completion response.
 * @returns {object} Genkit response with `message`, `finishReason`, `usage`,
 *   `custom` metadata and the `raw` response.
 */
function fromMistralResponse(_input, response) {
  const firstChoice = response.choices?.[0];
  const contentParts = firstChoice?.message
    ? fromMistralMessage(firstChoice.message)
    : [];
  return {
    message: {
      role: "model",
      content: contentParts,
    },
    finishReason: fromMistralFinishReason(firstChoice?.finishReason),
    usage: {
      inputTokens: response.usage?.promptTokens,
      outputTokens: response.usage?.completionTokens,
    },
    custom: {
      id: response.id,
      model: response.model,
      created: response.created,
    },
    raw: response, // Include the raw response for debugging or additional context
  };
}

/**
 * Checks that every tool response answers a tool call present in the request.
 *
 * @param {Array<object>} messages - Mistral messages.
 * @throws {Error} When the number of tool calls and tool responses differ,
 *   or when a response's `toolCallId` matches no call id.
 */
function validateToolSequence(messages) {
  const toolCalls = messages
    .filter((m) => m.role === "assistant" && m.toolCalls)
    .flatMap((m) => m.toolCalls);
  const toolResponses = messages.filter((m) => m.role === "tool");

  if (toolCalls.length !== toolResponses.length) {
    throw new Error(
      `Mismatch between tool calls (${toolCalls.length}) and responses (${toolResponses.length})`
    );
  }

  const callIds = new Set(toolCalls.map((call) => call.id));
  for (const response of toolResponses) {
    if (!callIds.has(response.toolCallId)) {
      throw new Error(
        `Tool response with ID ${response.toolCallId} has no matching call`
      );
    }
  }
}

/**
 * Converts one streamed Mistral completion chunk into Genkit parts.
 *
 * @param {object} chunk - Streamed completion chunk (`event.data`).
 * @returns {Array<object>} Text and tool-request parts from the first
 *   choice's delta. Empty when the chunk has no delta. Tool calls without a
 *   `function` are skipped.
 */
function fromMistralCompletionChunk(chunk) {
  const delta = chunk.choices?.[0]?.delta;
  if (!delta) return [];

  const parts = [];
  if (typeof delta.content === "string") {
    parts.push({ text: delta.content });
  }
  for (const toolCall of delta.toolCalls ?? []) {
    if (!toolCall.function) continue;
    parts.push({
      toolRequest: {
        ref: toolCall.id,
        name: toolCall.function.name,
        input: parseToolArguments(toolCall.function.arguments),
      },
    });
  }
  return parts;
}

export {
  GENERIC_MODEL,
  KNOWN_MODELS,
  MistralConfigSchema,
  defineModel,
  fromMistralCompletionChunk,
  fromMistralFinishReason,
  fromMistralResponse,
  isMistralModelName,
  listActions,
  listKnownModels,
  model,
  toMistralRequest,
};
//# sourceMappingURL=mistral.mjs.map