UNPKG

@genkit-ai/vertexai

Version:

Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.

373 lines 10.7 kB
import { MistralGoogleCloud } from "@mistralai/mistralai-gcp";
import {
  ChatCompletionChoiceFinishReason,
  ToolTypes,
} from "@mistralai/mistralai-gcp/models/components";
import { GENKIT_CLIENT_HEADER, GenerationCommonConfigSchema, z } from "genkit";
import { modelRef } from "genkit/model";

/**
 * Generation config accepted by the Vertex AI Model Garden Mistral models.
 * Extends Genkit's common generation config; `location` overrides the
 * Vertex AI region used for a single request.
 */
const MistralConfigSchema = GenerationCommonConfigSchema.extend({
  // TODO: Update this with all the parameters in
  // https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post.
  location: z.string().optional(),
  maxOutputTokens: z.number().optional(),
  temperature: z.number().optional(),
  // TODO: is this supported?
  // topK: z.number().optional(),
  topP: z.number().optional(),
  stopSequences: z.array(z.string()).optional(),
});

/** Genkit model reference for Mistral Large (supports tool calling). */
const mistralLarge = modelRef({
  name: "vertexai/mistral-large",
  info: {
    label: "Vertex AI Model Garden - Mistral Large",
    versions: ["mistral-large-2411", "mistral-large-2407"],
    supports: {
      multiturn: true,
      media: false,
      tools: true,
      systemRole: true,
      output: ["text"],
    },
  },
  configSchema: MistralConfigSchema,
});

/** Genkit model reference for Mistral Nemo (no tool calling). */
const mistralNemo = modelRef({
  name: "vertexai/mistral-nemo",
  info: {
    label: "Vertex AI Model Garden - Mistral Nemo",
    versions: ["mistral-nemo-2407"],
    supports: {
      multiturn: true,
      media: false,
      tools: false,
      systemRole: true,
      output: ["text"],
    },
  },
  configSchema: MistralConfigSchema,
});

/** Genkit model reference for Codestral (no tool calling). */
const codestral = modelRef({
  name: "vertexai/codestral",
  info: {
    label: "Vertex AI Model Garden - Codestral",
    versions: ["codestral-2405"],
    supports: {
      multiturn: true,
      media: false,
      tools: false,
      systemRole: true,
      output: ["text"],
    },
  },
  configSchema: MistralConfigSchema,
});

/** Maps the short model names accepted by `mistralModel` to their model references. */
const SUPPORTED_MISTRAL_MODELS = {
  "mistral-large": mistralLarge,
  "mistral-nemo": mistralNemo,
  codestral,
};

/**
 * Builds the per-call options passed to the Mistral SDK; attaches the Genkit
 * client header as `X-Goog-Api-Client`. A fresh object is returned each time
 * so the SDK never shares a mutable options object between calls.
 */
function mistralRequestOptions() {
  return {
    fetchOptions: { headers: { "X-Goog-Api-Client": GENKIT_CLIENT_HEADER } },
  };
}

/**
 * Translates a Genkit message role into the corresponding Mistral role.
 * @param {string} role - Genkit role ("model", "user", "tool" or "system").
 * @returns {string} Mistral role.
 * @throws {Error} For any other role.
 */
function toMistralRole(role) {
  switch (role) {
    case "model":
      return "assistant";
    case "user":
      return "user";
    case "tool":
      return "tool";
    case "system":
      return "system";
    default:
      throw new Error(`Unknown role ${role}`);
  }
}

/**
 * Converts a Genkit tool request into Mistral's `{ name, arguments }` form.
 * String inputs are passed through; anything else is JSON-encoded. A missing
 * input is encoded as `{}` so `arguments` is always a string.
 * @throws {Error} If the tool request has no name.
 */
function toMistralToolRequest(toolRequest) {
  if (!toolRequest.name) {
    throw new Error("Tool name is required");
  }
  return {
    name: toolRequest.name,
    // Mistral expects arguments as either a string or object
    arguments:
      typeof toolRequest.input === "string"
        ? toolRequest.input
        : JSON.stringify(toolRequest.input ?? {}),
  };
}

/** Builds one entry of an assistant message's `toolCalls` array from a Genkit tool request. */
function toMistralToolCall(toolRequest) {
  const { name, arguments: args } = toMistralToolRequest(toolRequest);
  return {
    id: toolRequest.ref,
    type: ToolTypes.Function,
    function: { name, arguments: args },
  };
}

/** Concatenates the text of all parts; parts without text contribute "". */
function joinTextParts(content) {
  return content.map((part) => part.text || "").join("");
}

/**
 * Converts one Genkit message into one or more Mistral messages.
 *
 * - A message containing tool requests becomes a single assistant message
 *   carrying every tool call (not just the first), plus any text it had.
 * - A message containing tool responses becomes one Mistral `tool` message per
 *   response, because Mistral pairs each response with a call via `toolCallId`.
 * - Anything else becomes a plain text message with the mapped role.
 */
function toMistralMessages(msg) {
  const toolRequests = msg.content
    .filter((part) => part.toolRequest)
    .map((part) => part.toolRequest);
  if (toolRequests.length > 0) {
    const text = joinTextParts(msg.content);
    return [
      {
        role: "assistant",
        content: text.length > 0 ? text : null,
        toolCalls: toolRequests.map(toMistralToolCall),
      },
    ];
  }

  const toolResponses = msg.content
    .filter((part) => part.toolResponse)
    .map((part) => part.toolResponse);
  if (toolResponses.length > 0) {
    return toolResponses.map((toolResponse) => ({
      role: "tool",
      name: toolResponse.name,
      content: JSON.stringify(toolResponse.output),
      toolCallId: toolResponse.ref, // This must match the id from tool_calls
    }));
  }

  return [{ role: toMistralRole(msg.role), content: joinTextParts(msg.content) }];
}

/** Converts a Genkit tool definition into a Mistral function tool. */
function toMistralTool(tool) {
  return {
    type: "function",
    function: {
      name: tool.name,
      description: tool.description,
      parameters: tool.inputSchema || {},
    },
  };
}

/**
 * Builds a Mistral chat completion request from a Genkit generate request.
 * Defaults: `maxTokens` 1024 and `temperature` 0.7 when not configured.
 * `topP`, `stop` and `tools` are only included when provided (and, for tools,
 * non-empty).
 * @param {string} model - Versioned Mistral model id.
 * @param {object} input - Genkit generate request.
 * @returns {object} Mistral chat completion request.
 * @throws {Error} If tool calls and tool responses do not pair up.
 */
function toMistralRequest(model, input) {
  const messages = input.messages.flatMap(toMistralMessages);
  validateToolSequence(messages);

  const config = input.config ?? {};
  const hasTools = Array.isArray(input.tools) && input.tools.length > 0;
  return {
    model,
    messages,
    maxTokens: config.maxOutputTokens ?? 1024,
    temperature: config.temperature ?? 0.7,
    // `!= null` so an explicit topP of 0 is still forwarded.
    ...(config.topP != null && { topP: config.topP }),
    ...(config.stopSequences && { stop: config.stopSequences }),
    ...(hasTools && { tools: input.tools.map(toMistralTool) }),
  };
}

/** Wraps a text string as a Genkit text part. */
function fromMistralTextPart(content) {
  return { text: content };
}

/**
 * Decodes tool-call arguments: strings are parsed as JSON (an empty or
 * whitespace-only string decodes to `{}`); non-strings are returned unchanged.
 */
function parseToolArguments(args) {
  if (typeof args !== "string") {
    return args;
  }
  return args.trim() === "" ? {} : JSON.parse(args);
}

/**
 * Converts a Mistral tool call into a Genkit tool request part.
 * @throws {Error} If the tool call has no `function`.
 */
function fromMistralToolCall(toolCall) {
  if (!toolCall.function) {
    throw new Error("Tool call must include a function definition");
  }
  return {
    toolRequest: {
      ref: toolCall.id,
      name: toolCall.function.name,
      input: parseToolArguments(toolCall.function.arguments),
    },
  };
}

/**
 * Converts a Mistral assistant message into Genkit parts: text first (string
 * content or `text` chunks of array content), then one part per tool call.
 */
function fromMistralMessage(message) {
  const parts = [];
  if (typeof message.content === "string") {
    parts.push(fromMistralTextPart(message.content));
  } else if (Array.isArray(message.content)) {
    for (const chunk of message.content) {
      if (chunk.type === "text") {
        parts.push(fromMistralTextPart(chunk.text));
      }
    }
  }
  for (const toolCall of message.toolCalls ?? []) {
    parts.push(fromMistralToolCall(toolCall));
  }
  return parts;
}

/** Maps a Mistral finish reason onto Genkit's finish reasons; unknown values map to "other". */
function fromMistralFinishReason(reason) {
  switch (reason) {
    case ChatCompletionChoiceFinishReason.Stop:
      return "stop";
    case ChatCompletionChoiceFinishReason.Length:
    case ChatCompletionChoiceFinishReason.ModelLength:
      return "length";
    case ChatCompletionChoiceFinishReason.Error:
      return "other"; // Map generic errors to "other"
    case ChatCompletionChoiceFinishReason.ToolCalls:
      return "stop"; // Assuming tool calls signify a "stop" in processing
    default:
      return "other";
  }
}

/**
 * Converts a Mistral chat completion response (or an object of the same shape
 * built from a stream) into a Genkit generate response. Only the first choice
 * is used. Token counts are left undefined when the response carries no usage.
 */
function fromMistralResponse(_input, response) {
  const firstChoice = response.choices?.[0];
  const contentParts = firstChoice?.message
    ? fromMistralMessage(firstChoice.message)
    : [];
  return {
    message: { role: "model", content: contentParts },
    finishReason: fromMistralFinishReason(firstChoice?.finishReason),
    usage: {
      inputTokens: response.usage?.promptTokens,
      outputTokens: response.usage?.completionTokens,
    },
    custom: {
      id: response.id,
      model: response.model,
      created: response.created,
    },
    raw: response, // Include the raw response for debugging or additional context
  };
}

/**
 * Folds streamed Mistral completion chunks into one object shaped like a
 * non-streaming chat completion response, so the final Genkit response can be
 * built without issuing a second request.
 *
 * Text deltas of the first choice are concatenated. Tool calls are collected
 * as delivered — like `fromMistralCompletionChunk`, this assumes each streamed
 * tool call carries its complete arguments (NOTE(review): confirm against the
 * Mistral streaming API). The last reported id/model/created/usage and the
 * last non-null finish reason win.
 */
function aggregateMistralStreamChunks(chunks) {
  const summary = {
    id: undefined,
    model: undefined,
    created: undefined,
    usage: undefined,
  };
  const toolCalls = [];
  let text = "";
  let finishReason;

  for (const chunk of chunks) {
    summary.id = chunk.id ?? summary.id;
    summary.model = chunk.model ?? summary.model;
    summary.created = chunk.created ?? summary.created;
    summary.usage = chunk.usage ?? summary.usage;

    const choice = chunk.choices?.[0];
    if (!choice) {
      continue;
    }
    finishReason = choice.finishReason ?? finishReason;
    const delta = choice.delta;
    if (typeof delta?.content === "string") {
      text += delta.content;
    }
    for (const toolCall of delta?.toolCalls ?? []) {
      if (toolCall.function) {
        toolCalls.push(toolCall);
      }
    }
  }

  return {
    ...summary,
    choices: [
      {
        index: 0,
        finishReason,
        message: {
          role: "assistant",
          content: text,
          ...(toolCalls.length > 0 && { toolCalls }),
        },
      },
    ],
  };
}

/**
 * Registers one of the supported Mistral models with Genkit.
 *
 * Non-streaming requests use `client.chat.complete`. Streaming requests use
 * `client.chat.stream`, forward each chunk through `sendChunk`, and build the
 * final response from the streamed chunks (no second API call).
 *
 * @param {object} ai - Genkit instance providing `defineModel`.
 * @param {string} modelName - Key of `SUPPORTED_MISTRAL_MODELS`.
 * @param {string} projectId - Google Cloud project passed to the Mistral client.
 * @param {string} region - Default region; `config.location` overrides it per request.
 * @throws {Error} If `modelName` is not a supported model.
 */
function mistralModel(ai, modelName, projectId, region) {
  const getClient = createClientFactory(projectId);
  // Own-property check so names like "toString" are not resolved via the prototype.
  const model = Object.prototype.hasOwnProperty.call(
    SUPPORTED_MISTRAL_MODELS,
    modelName
  )
    ? SUPPORTED_MISTRAL_MODELS[modelName]
    : undefined;
  if (!model) {
    throw new Error(`Unsupported Mistral model name ${modelName}`);
  }

  return ai.defineModel(
    {
      name: model.name,
      label: model.info?.label,
      configSchema: MistralConfigSchema,
      supports: model.info?.supports,
      versions: model.info?.versions,
    },
    async (input, sendChunk) => {
      const client = getClient(input.config?.location || region);
      const versionedModel =
        input.config?.version ?? model.info?.versions?.[0] ?? model.name;
      const mistralRequest = toMistralRequest(versionedModel, input);

      if (!sendChunk) {
        const response = await client.chat.complete(
          mistralRequest,
          mistralRequestOptions()
        );
        return fromMistralResponse(input, response);
      }

      const stream = await client.chat.stream(
        mistralRequest,
        mistralRequestOptions()
      );
      const chunks = [];
      for await (const event of stream) {
        chunks.push(event.data);
        const parts = fromMistralCompletionChunk(event.data);
        if (parts.length > 0) {
          sendChunk({ content: parts });
        }
      }
      return fromMistralResponse(input, aggregateMistralStreamChunks(chunks));
    }
  );
}

/**
 * Returns a function that lazily creates and caches one `MistralGoogleCloud`
 * client per region for the given project.
 * @throws {Error} From the returned function when `region` is empty, or when
 *   client construction fails (the original error is attached as `cause`).
 */
function createClientFactory(projectId) {
  const clients = new Map();
  return (region) => {
    if (!region) {
      throw new Error("Region is required to create Mistral client");
    }
    try {
      if (!clients.has(region)) {
        clients.set(region, new MistralGoogleCloud({ region, projectId }));
      }
      return clients.get(region);
    } catch (error) {
      throw new Error(
        `Failed to create/retrieve Mistral client for region ${region}: ${error}`,
        { cause: error }
      );
    }
  };
}

/**
 * Checks that every Mistral `tool` message answers a tool call issued by an
 * assistant message, and that calls and responses are equal in number.
 * @throws {Error} On a count mismatch or a response with an unknown call id.
 */
function validateToolSequence(messages) {
  const toolCalls = messages.flatMap((m) =>
    m.role === "assistant" && m.toolCalls ? m.toolCalls : []
  );
  const toolResponses = messages.filter((m) => m.role === "tool");
  if (toolCalls.length !== toolResponses.length) {
    throw new Error(
      `Mismatch between tool calls (${toolCalls.length}) and responses (${toolResponses.length})`
    );
  }
  const callIds = new Set(toolCalls.map((call) => call.id));
  for (const response of toolResponses) {
    if (!callIds.has(response.toolCallId)) {
      throw new Error(
        `Tool response with ID ${response.toolCallId} has no matching call`
      );
    }
  }
}

/**
 * Converts one streamed Mistral completion chunk into Genkit parts: the text
 * delta (when it is a string) followed by any tool calls that include a
 * `function`. Returns [] when the chunk has no first-choice delta.
 */
function fromMistralCompletionChunk(chunk) {
  const delta = chunk.choices?.[0]?.delta;
  if (!delta) {
    return [];
  }
  const parts = [];
  if (typeof delta.content === "string") {
    parts.push(fromMistralTextPart(delta.content));
  }
  for (const toolCall of delta.toolCalls ?? []) {
    if (toolCall.function) {
      parts.push(fromMistralToolCall(toolCall));
    }
  }
  return parts;
}

export {
  MistralConfigSchema,
  SUPPORTED_MISTRAL_MODELS,
  codestral,
  fromMistralCompletionChunk,
  fromMistralFinishReason,
  fromMistralResponse,
  mistralLarge,
  mistralModel,
  mistralNemo,
  toMistralRequest,
};
//# sourceMappingURL=mistral.mjs.map