/*
 * Package: @genkit-ai/vertexai (version not recorded)
 * Genkit AI framework plugin for Google Cloud Vertex AI APIs, including the Gemini APIs, Imagen, and more.
 * Source listing: 373 lines, 10.7 kB, JavaScript.
 */
import { MistralGoogleCloud } from "@mistralai/mistralai-gcp";
import {
ChatCompletionChoiceFinishReason,
ToolTypes
} from "@mistralai/mistralai-gcp/models/components";
import {
GENKIT_CLIENT_HEADER,
GenerationCommonConfigSchema,
z
} from "genkit";
import { modelRef } from "genkit/model";
/**
 * Genkit config schema accepted by every Mistral model defined in this module.
 *
 * Extends GenerationCommonConfigSchema. How the fields are consumed in this file:
 * - location: Vertex AI region used to pick the Mistral client for a request;
 *   when unset, the region passed to mistralModel is used.
 * - maxOutputTokens: forwarded as Mistral `maxTokens` (the request builder
 *   falls back to 1024 when unset).
 * - temperature: forwarded as-is (the request builder falls back to 0.7).
 * - topP: forwarded as Mistral `topP`.
 * - stopSequences: forwarded as Mistral `stop`.
 * NOTE(review): mistralModel also reads `config.version`, which is not declared
 * here — presumably inherited from GenerationCommonConfigSchema; confirm.
 */
const MistralConfigSchema = GenerationCommonConfigSchema.extend({
// TODO: Add the remaining request parameters documented at
// https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post.
location: z.string().optional(),
maxOutputTokens: z.number().optional(),
temperature: z.number().optional(),
// TODO: confirm whether Mistral on Vertex AI supports topK before enabling it.
// topK: z.number().optional(),
topP: z.number().optional(),
stopSequences: z.array(z.string()).optional()
});
/**
 * Builds the Genkit model reference for one Mistral model served through
 * Vertex AI Model Garden. All Mistral refs in this module share the same
 * capability flags except tool support, and all use MistralConfigSchema.
 *
 * @param {string} id - Model id without the "vertexai/" prefix.
 * @param {string} displayName - Suffix appended to the "Vertex AI Model Garden - " label.
 * @param {string[]} versions - Concrete model versions; mistralModel calls the
 *   first entry when the request does not pick a version.
 * @param {boolean} tools - Whether the model advertises tool-calling support.
 * @returns The value produced by `modelRef`.
 */
function mistralModelRef(id, displayName, versions, tools) {
  const supports = {
    multiturn: true,
    media: false,
    tools,
    systemRole: true,
    output: ["text"]
  };
  return modelRef({
    name: `vertexai/${id}`,
    info: {
      label: `Vertex AI Model Garden - ${displayName}`,
      versions,
      supports
    },
    configSchema: MistralConfigSchema
  });
}

const mistralLarge = mistralModelRef(
  "mistral-large",
  "Mistral Large",
  ["mistral-large-2411", "mistral-large-2407"],
  true
);
const mistralNemo = mistralModelRef(
  "mistral-nemo",
  "Mistral Nemo",
  ["mistral-nemo-2407"],
  false
);
const codestral = mistralModelRef(
  "codestral",
  "Codestral",
  ["codestral-2405"],
  false
);

// Registry keyed by the short model name accepted by mistralModel().
const SUPPORTED_MISTRAL_MODELS = {
  "mistral-large": mistralLarge,
  "mistral-nemo": mistralNemo,
  codestral: codestral
};
/**
 * Maps a Genkit message role onto the equivalent Mistral chat role.
 *
 * @param {string} role - Genkit role: "model", "user", "tool" or "system".
 * @returns {string} The Mistral role ("model" becomes "assistant"; the rest are unchanged).
 * @throws {Error} If the role has no Mistral equivalent.
 */
function toMistralRole(role) {
  switch (role) {
    case "model":
      return "assistant";
    case "user":
    case "tool":
    case "system":
      // These roles share the same name in both APIs.
      return role;
    default:
      throw new Error(`Unknown role ${role}`);
  }
}
/**
 * Converts a Genkit tool request into the `{ name, arguments }` pair used in
 * a Mistral tool call.
 *
 * Arguments are always sent as a JSON string: string input is passed through
 * unchanged, anything else is serialized. A missing (`undefined`/`null`) input
 * becomes "{}" — JSON.stringify(undefined) would otherwise return `undefined`
 * and the tool call would go out without arguments.
 *
 * @param {{name?: string, input?: *}} toolRequest - Genkit tool request.
 * @returns {{name: string, arguments: string}} Mistral function call.
 * @throws {Error} If the tool request has no name.
 */
function toMistralToolRequest(toolRequest) {
  if (!toolRequest.name) {
    throw new Error("Tool name is required");
  }
  const { name, input } = toolRequest;
  const args = typeof input === "string" ? input : JSON.stringify(input ?? {});
  return {
    name,
    arguments: args
  };
}
/**
 * Converts a Genkit generate request into a Mistral chat completion request.
 *
 * Message conversion (see toMistralMessages):
 * - messages made only of text parts become `{ role, content }` messages;
 * - a message with tool requests becomes ONE assistant message whose
 *   `toolCalls` lists EVERY tool request part (parallel tool calls);
 * - a message with tool responses becomes one Mistral `tool` message PER
 *   tool response part, linked to its call through `toolCallId`;
 * - anything else falls back to the concatenated text of its parts.
 *
 * Config mapping: maxOutputTokens -> maxTokens (default 1024),
 * temperature (default 0.7), topP (sent whenever it is set, including 0),
 * stopSequences -> stop. Genkit tools become Mistral "function" tools.
 *
 * @param {string} model - Concrete Mistral model version to call.
 * @param {object} input - Genkit generate request (`messages`, `config`, `tools`).
 * @returns {object} Request body for the Mistral chat client.
 * @throws {Error} On unknown roles, nameless tool requests, or tool calls and
 *   tool responses that do not pair up.
 */
function toMistralRequest(model, input) {
  const messages = input.messages.flatMap((msg) => toMistralMessages(msg));
  validateToolSequence(messages);
  const config = input.config;
  return {
    model,
    messages,
    maxTokens: config?.maxOutputTokens ?? 1024,
    temperature: config?.temperature ?? 0.7,
    // `!= null` so an explicit topP of 0 is forwarded instead of dropped.
    ...(config?.topP != null && { topP: config.topP }),
    ...(config?.stopSequences && { stop: config.stopSequences }),
    ...(input.tools && {
      tools: input.tools.map((tool) => ({
        type: "function",
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.inputSchema || {}
        }
      }))
    })
  };
}

// Converts one Genkit message into the Mistral message(s) representing it;
// returns an array because a single Genkit tool message may hold several responses.
function toMistralMessages(msg) {
  const parts = msg.content;
  const asTextMessage = () => [
    {
      role: toMistralRole(msg.role),
      content: parts.map((part) => part.text || "").join("")
    }
  ];

  if (parts.every((part) => part.text)) {
    return asTextMessage();
  }

  const toolRequests = parts
    .filter((part) => part.toolRequest)
    .map((part) => part.toolRequest);
  if (toolRequests.length > 0) {
    return [
      {
        role: "assistant",
        content: null,
        toolCalls: toolRequests.map((toolRequest) => {
          const functionCall = toMistralToolRequest(toolRequest);
          return {
            // Tool responses must echo this id in their toolCallId.
            id: toolRequest.ref,
            type: ToolTypes.Function,
            function: {
              name: functionCall.name,
              arguments: functionCall.arguments
            }
          };
        })
      }
    ];
  }

  const toolResponses = parts
    .filter((part) => part.toolResponse)
    .map((part) => part.toolResponse);
  if (toolResponses.length > 0) {
    return toolResponses.map((toolResponse) => ({
      role: "tool",
      name: toolResponse.name,
      // A tool that returned nothing still needs string content ("null"),
      // since JSON.stringify(undefined) yields undefined.
      content: JSON.stringify(toolResponse.output ?? null),
      toolCallId: toolResponse.ref
    }));
  }

  return asTextMessage();
}
/**
 * Wraps a piece of Mistral text output in a Genkit text part.
 *
 * @param {string} content - Text emitted by the model.
 * @returns {{text: string}} Genkit text part.
 */
function fromMistralTextPart(content) {
  const part = { text: content };
  return part;
}
/**
 * Converts a Mistral tool call into a Genkit tool request part.
 *
 * String arguments are parsed as JSON; an empty/whitespace-only string means
 * "no arguments" and becomes `{}` (JSON.parse("") would throw). Non-string
 * arguments are passed through unchanged.
 *
 * @param {object} toolCall - Mistral tool call with `id` and `function: { name, arguments }`.
 * @returns {{toolRequest: {ref: *, name: string, input: *}}} Genkit tool request part.
 * @throws {Error} If the tool call has no `function`.
 * @throws {SyntaxError} If string arguments are not valid JSON (message names the tool).
 */
function fromMistralToolCall(toolCall) {
  const fn = toolCall.function;
  if (!fn) {
    throw new Error("Tool call must include a function definition");
  }
  let input = fn.arguments;
  if (typeof input === "string") {
    if (input.trim() === "") {
      input = {};
    } else {
      try {
        input = JSON.parse(input);
      } catch (error) {
        throw new SyntaxError(
          `Invalid JSON arguments for tool "${fn.name}": ${error.message}`,
          { cause: error }
        );
      }
    }
  }
  return {
    toolRequest: {
      ref: toolCall.id,
      name: fn.name,
      input
    }
  };
}
/**
 * Converts a Mistral assistant message into Genkit parts.
 *
 * Text comes first: either the whole string `content`, or each chunk of type
 * "text" when `content` is an array (other chunk types are ignored). Then one
 * tool request part is appended per entry in `toolCalls`.
 *
 * @param {object} message - Mistral assistant message.
 * @returns {object[]} Genkit parts in the order described above.
 */
function fromMistralMessage(message) {
  const { content, toolCalls } = message;

  let textParts = [];
  if (typeof content === "string") {
    textParts = [fromMistralTextPart(content)];
  } else if (Array.isArray(content)) {
    textParts = content
      .filter((chunk) => chunk.type === "text")
      .map((chunk) => fromMistralTextPart(chunk.text));
  }

  const toolParts = toolCalls
    ? toolCalls.map((toolCall) => fromMistralToolCall(toolCall))
    : [];

  return [...textParts, ...toolParts];
}
/**
 * Maps a Mistral finish reason onto Genkit's finish-reason vocabulary.
 *
 * "stop" and "tool calls" both count as a normal "stop"; "length" and
 * "model length" become "length"; errors and any other or missing reason
 * become "other".
 *
 * @param {string|undefined|null} reason - Mistral finish reason.
 * @returns {string} Genkit finish reason.
 */
function fromMistralFinishReason(reason) {
  const { Stop, Length, ModelLength, ToolCalls } = ChatCompletionChoiceFinishReason;
  if (reason === Stop || reason === ToolCalls) {
    return "stop";
  }
  if (reason === Length || reason === ModelLength) {
    return "length";
  }
  return "other";
}
/**
 * Converts a Mistral chat completion response into Genkit response data.
 *
 * Only the first choice is used. A response with no choices yields an empty
 * model message, and a response without `usage` reports undefined token
 * counts instead of throwing.
 *
 * @param {object} _input - Originating Genkit request (unused).
 * @param {object} response - Mistral chat completion response.
 * @returns {object} Genkit response: message, finishReason, usage, custom
 *   (id/model/created) and the untouched `raw` response.
 */
function fromMistralResponse(_input, response) {
  const firstChoice = response.choices?.[0];
  const content = firstChoice?.message ? fromMistralMessage(firstChoice.message) : [];
  const usage = response.usage;
  return {
    message: {
      role: "model",
      content
    },
    finishReason: fromMistralFinishReason(firstChoice?.finishReason),
    usage: {
      inputTokens: usage?.promptTokens,
      outputTokens: usage?.completionTokens
    },
    custom: {
      id: response.id,
      model: response.model,
      created: response.created
    },
    // Full Mistral payload, kept for debugging and advanced consumers.
    raw: response
  };
}
/**
 * Registers a Genkit model backed by a Mistral model hosted in Vertex AI Model Garden.
 *
 * Requests go to the region in `config.location` when set, otherwise to
 * `region`. The model version is `config.version` when set, otherwise the
 * first version listed on the model ref. When a chunk callback is supplied the
 * response is streamed, and the final response is assembled from the streamed
 * chunks (no second API call).
 *
 * @param {object} ai - Genkit instance whose `defineModel` registers the model.
 * @param {string} modelName - Key of SUPPORTED_MISTRAL_MODELS (e.g. "mistral-large").
 * @param {string} projectId - Google Cloud project passed to the Mistral client.
 * @param {string} region - Default Vertex AI region.
 * @returns The model returned by `ai.defineModel`.
 * @throws {Error} If `modelName` is not a supported Mistral model.
 */
function mistralModel(ai, modelName, projectId, region) {
  const getClient = createClientFactory(projectId);
  // Own-property check so names like "toString" don't resolve via Object.prototype.
  const model = Object.prototype.hasOwnProperty.call(SUPPORTED_MISTRAL_MODELS, modelName)
    ? SUPPORTED_MISTRAL_MODELS[modelName]
    : undefined;
  if (!model) {
    throw new Error(`Unsupported Mistral model name ${modelName}`);
  }
  return ai.defineModel(
    {
      name: model.name,
      label: model.info?.label,
      configSchema: MistralConfigSchema,
      supports: model.info?.supports,
      versions: model.info?.versions
    },
    async (input, sendChunk) => {
      const client = getClient(input.config?.location || region);
      const versionedModel =
        input.config?.version ?? model.info?.versions?.[0] ?? model.name;
      const mistralRequest = toMistralRequest(versionedModel, input);
      const requestOptions = {
        fetchOptions: {
          headers: {
            "X-Goog-Api-Client": GENKIT_CLIENT_HEADER
          }
        }
      };

      if (!sendChunk) {
        const response = await client.chat.complete(mistralRequest, requestOptions);
        return fromMistralResponse(input, response);
      }

      // Build the final response from the stream itself: re-requesting a full
      // completion afterwards doubled the cost and could return a different
      // answer than the one that was streamed.
      const stream = await client.chat.stream(mistralRequest, requestOptions);
      const aggregate = createMistralStreamAggregate();
      for await (const event of stream) {
        mergeMistralStreamChunk(aggregate, event.data);
        const parts = fromMistralCompletionChunk(event.data);
        if (parts.length > 0) {
          sendChunk({
            content: parts
          });
        }
      }
      return fromMistralResponse(input, finalizeMistralStreamAggregate(aggregate));
    }
  );
}

// Creates an empty accumulator for the pieces of a streamed Mistral completion.
function createMistralStreamAggregate() {
  return {
    id: undefined,
    model: undefined,
    created: undefined,
    usage: undefined,
    text: "",
    toolCalls: [],
    finishReason: undefined
  };
}

// Folds one streamed completion chunk (an event's `data`) into the accumulator.
function mergeMistralStreamChunk(aggregate, chunk) {
  if (!chunk) {
    return;
  }
  if (chunk.id !== undefined) {
    aggregate.id = chunk.id;
  }
  if (chunk.model !== undefined) {
    aggregate.model = chunk.model;
  }
  if (chunk.created !== undefined) {
    aggregate.created = chunk.created;
  }
  if (chunk.usage) {
    aggregate.usage = chunk.usage;
  }
  const choice = chunk.choices?.[0];
  const delta = choice?.delta;
  if (typeof delta?.content === "string") {
    aggregate.text += delta.content;
  } else if (Array.isArray(delta?.content)) {
    for (const piece of delta.content) {
      if (piece.type === "text") {
        aggregate.text += piece.text;
      }
    }
  }
  if (delta?.toolCalls) {
    aggregate.toolCalls.push(...delta.toolCalls);
  }
  if (choice?.finishReason) {
    aggregate.finishReason = choice.finishReason;
  }
}

// Shapes the accumulator like a non-streaming chat completion response so it
// can go through fromMistralResponse.
function finalizeMistralStreamAggregate(aggregate) {
  const message = { role: "assistant" };
  if (aggregate.text !== "") {
    message.content = aggregate.text;
  }
  if (aggregate.toolCalls.length > 0) {
    message.toolCalls = aggregate.toolCalls;
  }
  return {
    id: aggregate.id,
    model: aggregate.model,
    created: aggregate.created,
    // Unknown usage is reported as undefined counts rather than zeros.
    usage: aggregate.usage ?? {},
    choices: [
      {
        index: 0,
        message,
        finishReason: aggregate.finishReason
      }
    ]
  };
}
/**
 * Returns a function that lazily creates, then reuses, one MistralGoogleCloud
 * client per region for the given project.
 *
 * @param {string} projectId - Google Cloud project id passed to every client.
 * @returns {(region: string) => MistralGoogleCloud} Region -> cached client.
 *   The returned function throws if `region` is empty, or if client
 *   construction fails (the original error is kept as `cause`).
 */
function createClientFactory(projectId) {
  // A Map avoids collisions with Object.prototype keys (e.g. region "constructor").
  const clients = new Map();
  return (region) => {
    if (!region) {
      throw new Error("Region is required to create Mistral client");
    }
    let client = clients.get(region);
    if (!client) {
      try {
        client = new MistralGoogleCloud({
          region,
          projectId
        });
      } catch (error) {
        throw new Error(
          `Failed to create/retrieve Mistral client for region ${region}: ${error}`,
          { cause: error }
        );
      }
      clients.set(region, client);
    }
    return client;
  };
}
/**
 * Checks that the Mistral message list pairs tool calls with tool responses:
 * the number of `tool` messages must equal the total number of tool calls
 * across assistant messages, and every `tool` message's `toolCallId` must match
 * the `id` of some tool call.
 *
 * @param {object[]} messages - Mistral chat messages.
 * @throws {Error} On a count mismatch, or on the first response with no matching call.
 */
function validateToolSequence(messages) {
  const calls = [];
  const responses = [];
  for (const message of messages) {
    if (message.role === "assistant" && message.toolCalls) {
      calls.push(...message.toolCalls);
    }
    if (message.role === "tool") {
      responses.push(message);
    }
  }

  if (calls.length !== responses.length) {
    throw new Error(
      `Mismatch between tool calls (${calls.length}) and responses (${responses.length})`
    );
  }

  for (const response of responses) {
    const matched = calls.find((call) => call.id === response.toolCallId);
    if (!matched) {
      throw new Error(
        `Tool response with ID ${response.toolCallId} has no matching call`
      );
    }
  }
}
/**
 * Converts one streamed Mistral completion chunk into Genkit parts.
 *
 * Only the first choice's `delta` is read. Text comes from string content or
 * from the "text" chunks of array content (matching fromMistralMessage).
 * Tool calls without a `function` are skipped. String arguments are
 * parsed as JSON, and an empty string becomes `{}`.
 *
 * @param {object} chunk - Mistral completion chunk (a stream event's `data`).
 * @returns {object[]} Genkit parts; empty when the chunk has no delta.
 * @throws {SyntaxError} If tool-call arguments are a non-empty, invalid JSON string.
 */
function fromMistralCompletionChunk(chunk) {
  const delta = chunk?.choices?.[0]?.delta;
  if (!delta) {
    return [];
  }

  const parseArguments = (fn) => {
    const raw = fn.arguments;
    if (typeof raw !== "string") {
      return raw;
    }
    if (raw.trim() === "") {
      return {};
    }
    try {
      return JSON.parse(raw);
    } catch (error) {
      throw new SyntaxError(
        `Invalid JSON arguments for tool "${fn.name}": ${error.message}`,
        { cause: error }
      );
    }
  };

  const parts = [];
  if (typeof delta.content === "string") {
    parts.push({ text: delta.content });
  } else if (Array.isArray(delta.content)) {
    for (const piece of delta.content) {
      if (piece.type === "text") {
        parts.push({ text: piece.text });
      }
    }
  }
  if (delta.toolCalls) {
    for (const toolCall of delta.toolCalls) {
      const fn = toolCall.function;
      if (!fn) {
        continue;
      }
      parts.push({
        toolRequest: {
          ref: toolCall.id,
          name: fn.name,
          input: parseArguments(fn)
        }
      });
    }
  }
  return parts;
}
// Public API of this module: the config schema, the Mistral model refs and
// their registry, the mistralModel factory, and the request/response
// conversion helpers.
export {
MistralConfigSchema,
SUPPORTED_MISTRAL_MODELS,
codestral,
fromMistralCompletionChunk,
fromMistralFinishReason,
fromMistralResponse,
mistralLarge,
mistralModel,
mistralNemo,
toMistralRequest
};
//# sourceMappingURL=mistral.mjs.map