@genkit-ai/vertexai
Version:
Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.
388 lines • 12.1 kB
JavaScript
;
// Module-interop helpers (this looks like bundler-generated boilerplate):
// cached Object reflection primitives, followed by the three functions that
// assemble the CommonJS export object.
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;

/**
 * Installs one enumerable getter on `target` per key of `all`; the value
 * stored under each key is used as the getter function itself.
 */
var __export = (target, all) => {
  for (var name in all) {
    __defProp(target, name, { get: all[name], enumerable: true });
  }
};

/**
 * Mirrors the own properties of `from` onto `to` as live getters.
 * Keys already present on `to`, and the single key named by `except`, are
 * skipped. Each copied key keeps the enumerability of its source descriptor.
 * Returns `to`; it is returned unchanged when `from` is neither an object
 * nor a function.
 */
var __copyProps = (to, from, except, desc) => {
  const copyable = (from && typeof from === "object") || typeof from === "function";
  if (!copyable) {
    return to;
  }
  for (const key of __getOwnPropNames(from)) {
    if (__hasOwnProp.call(to, key) || key === except) {
      continue;
    }
    desc = __getOwnPropDesc(from, key);
    __defProp(to, key, {
      get: () => from[key],
      enumerable: !desc || desc.enumerable
    });
  }
  return to;
};

/**
 * Produces a fresh object tagged with a non-enumerable `__esModule: true`
 * and carrying getter mirrors of every own property of `mod`.
 */
var __toCommonJS = (mod) => {
  const tagged = __defProp({}, "__esModule", { value: true });
  return __copyProps(tagged, mod);
};
// Export table for this module. Every entry is a getter thunk, so the named
// binding is looked up when the property is read rather than when this table
// is built (the bindings themselves are declared further down the file).
var mistral_exports = {};
__export(mistral_exports, {
GENERIC_MODEL: () => GENERIC_MODEL,
KNOWN_MODELS: () => KNOWN_MODELS,
MistralConfigSchema: () => MistralConfigSchema,
defineModel: () => defineModel,
fromMistralCompletionChunk: () => fromMistralCompletionChunk,
fromMistralFinishReason: () => fromMistralFinishReason,
fromMistralResponse: () => fromMistralResponse,
isMistralModelName: () => isMistralModelName,
listActions: () => listActions,
listKnownModels: () => listKnownModels,
model: () => model,
toMistralRequest: () => toMistralRequest
});
// Publish the table as the CommonJS module object; __toCommonJS copies the
// getters onto a new object that also carries the `__esModule` marker.
module.exports = __toCommonJS(mistral_exports);
var import_mistralai_gcp = require("@mistralai/mistralai-gcp");
var import_components = require("@mistralai/mistralai-gcp/models/components/index.js");
var import_genkit = require("genkit");
var import_model = require("genkit/model");
var import_plugin = require("genkit/plugin");
var import_common = require("../../common/index.js");
var import_utils = require("./utils.js");
/**
 * Configuration schema for Mistral models: the common generation config
 * extended with an optional `location` string and an optional numeric `topP`.
 * The `topP` description is the shared one with the default value appended.
 * `.passthrough()` is applied last (per zod semantics this keeps keys that
 * are not listed in the schema instead of stripping them).
 */
const MistralConfigSchema = import_genkit.GenerationCommonConfigSchema
  .extend({
    // TODO: Update this with all the parameters in
    // https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post.
    location: import_genkit.z.string().optional(),
    topP: import_genkit.z
      .number()
      .describe(
        `${import_model.GenerationCommonConfigDescriptions.topP} The default value is 1.`
      )
      .optional()
  })
  .passthrough();
/**
 * Builds a model reference whose name is placed in the
 * `vertex-model-garden/` namespace.
 *
 * @param name Model name without the namespace prefix.
 * @param info Optional model info; when null/undefined a default capability
 *   set is used (multiturn, tools and system role supported, no media,
 *   text output only).
 * @param configSchema Config schema for the reference; defaults to
 *   MistralConfigSchema.
 * @returns Whatever `modelRef` returns for the assembled options.
 */
function commonRef(name, info, configSchema = MistralConfigSchema) {
  const { modelRef } = import_model;
  const qualifiedName = `vertex-model-garden/${name}`;
  const resolvedInfo = info ?? {
    supports: {
      multiturn: true,
      media: false,
      tools: true,
      systemRole: true,
      output: ["text"]
    }
  };
  return modelRef({
    name: qualifiedName,
    configSchema,
    info: resolvedInfo
  });
}
/** Reference registered under the generic name "mistral". */
const GENERIC_MODEL = commonRef("mistral");

/**
 * References for the model versions listed here, keyed by version name.
 * Each one is built with commonRef's default info and config schema.
 */
const KNOWN_MODELS = Object.fromEntries(
  [
    "mistral-medium-3",
    "mistral-ocr-2505",
    "mistral-small-2503",
    "codestral-2"
  ].map((version) => [version, commonRef(version)])
);
/**
 * Reports whether `value` looks like a Mistral-family model name, i.e. it is
 * a string containing the fragment "tral-" (which covers both "mistral-..."
 * and "codestral-..." names).
 *
 * Anything that is not a string (undefined, null, numbers, arrays, ...)
 * yields `false` instead of throwing, so the predicate is safe to call on
 * unvalidated input.
 *
 * @param value Candidate model name.
 * @returns `true` only for strings containing "tral-".
 */
function isMistralModelName(value) {
  return typeof value === "string" && value.includes("tral-");
}
/**
 * Creates a model reference for an arbitrary Mistral model version.
 *
 * The version is first passed through `checkModelName`; its result is
 * prefixed with `vertex-model-garden/`. The reference carries `options` as
 * its config, MistralConfigSchema as its schema and a shallow copy of
 * GENERIC_MODEL's info.
 *
 * @param version Model version string.
 * @param options Config to attach to the reference (defaults to `{}`).
 * @returns Whatever `modelRef` returns for the assembled options.
 */
function model(version, options = {}) {
  const { checkModelName } = import_utils;
  const name = checkModelName(version);
  const { modelRef } = import_model;
  const info = { ...GENERIC_MODEL.info };
  return modelRef({
    name: `vertex-model-garden/${name}`,
    config: options,
    configSchema: MistralConfigSchema,
    info
  });
}
/**
 * Lists actions for this model family. The implementation reports none:
 * it always returns a new empty array and does not read `clientOptions`.
 */
function listActions(clientOptions) {
  const noActions = [];
  return noActions;
}
/**
 * Defines one model per key of KNOWN_MODELS, in key order.
 *
 * @param clientOptions Forwarded unchanged to defineModel.
 * @param pluginOptions Forwarded unchanged to defineModel.
 * @returns Array of the values returned by defineModel.
 */
function listKnownModels(clientOptions, pluginOptions) {
  const defined = [];
  for (const name of Object.keys(KNOWN_MODELS)) {
    defined.push(defineModel(name, clientOptions, pluginOptions));
  }
  return defined;
}
/**
 * Defines a model action for the named Mistral model.
 *
 * The action resolves a client for `request.config.location` (falling back
 * to `clientOptions.location`), converts the request with toMistralRequest
 * and then either:
 *  - calls `client.chat.complete` once (non-streaming), or
 *  - calls `client.chat.stream` once, forwards every non-empty chunk through
 *    `sendChunk`, and builds the final response from those same chunks.
 *
 * The streaming path deliberately does NOT issue a follow-up
 * `chat.complete` call: a second request would be an independent generation
 * (extra latency and usage) whose text need not match what was streamed.
 *
 * @param name Model version name; passed to `model()` to obtain the ref.
 * @param clientOptions Object providing `projectId` and a default `location`.
 * @param pluginOptions Accepted for signature compatibility; not read here.
 * @returns Whatever the plugin `model` helper returns for the action.
 */
function defineModel(name, clientOptions, pluginOptions) {
  const ref = model(name);
  const getClient = createClientFactory(clientOptions.projectId);
  return (0, import_plugin.model)(
    {
      name: ref.name,
      ...ref.info,
      configSchema: ref.configSchema
    },
    async (request, { streamingRequested, sendChunk }) => {
      const client = getClient(
        request.config?.location || clientOptions.location
      );
      const modelVersion = (0, import_utils.checkModelName)(ref.name);
      const mistralRequest = toMistralRequest(modelVersion, request);
      const mistralOptions = {
        fetchOptions: {
          headers: {
            "X-Goog-Api-Client": (0, import_common.getGenkitClientHeader)()
          }
        }
      };
      if (!streamingRequested) {
        const response = await client.chat.complete(
          mistralRequest,
          mistralOptions
        );
        return fromMistralResponse(request, response);
      }
      const stream = await client.chat.stream(mistralRequest, mistralOptions);
      const streamedResponse = await collectMistralStream(stream, sendChunk);
      return fromMistralResponse(request, streamedResponse);
    }
  );
}

/**
 * Drains a Mistral completion stream: forwards each chunk's parts through
 * `sendChunk` and folds the chunks into one response-shaped object
 * (`id`, `model`, `created`, `usage`, and a single choice holding the
 * concatenated text, the collected tool calls and the last finish reason).
 *
 * Assumes each event exposes its chunk as `event.data` and that chunks carry
 * `choices[0].delta` / `choices[0].finishReason` / `usage` as in the SDK's
 * completion-chunk type — TODO confirm against the installed SDK version.
 * `usage` falls back to `{}` when no chunk reported it, so consumers reading
 * `usage.promptTokens` get `undefined` rather than a TypeError.
 */
async function collectMistralStream(stream, sendChunk) {
  const textPieces = [];
  const toolCalls = [];
  let id;
  let modelName;
  let created;
  let usage;
  let finishReason;
  for await (const event of stream) {
    const chunk = event.data;
    const parts = fromMistralCompletionChunk(chunk);
    if (parts.length > 0) {
      sendChunk({
        content: parts
      });
    }
    // Identity fields: keep the first value seen.
    id = id ?? chunk.id;
    modelName = modelName ?? chunk.model;
    created = created ?? chunk.created;
    // Usage: keep the latest report.
    if (chunk.usage) {
      usage = chunk.usage;
    }
    const choice = chunk.choices?.[0];
    if (!choice) {
      continue;
    }
    if (choice.finishReason) {
      finishReason = choice.finishReason;
    }
    const delta = choice.delta;
    if (!delta) {
      continue;
    }
    if (typeof delta.content === "string") {
      textPieces.push(delta.content);
    } else if (Array.isArray(delta.content)) {
      for (const piece of delta.content) {
        if (piece.type === "text") {
          textPieces.push(piece.text);
        }
      }
    }
    if (delta.toolCalls) {
      // Entries without a function are skipped, matching how chunks are
      // converted for streaming.
      toolCalls.push(...delta.toolCalls.filter((call) => call.function));
    }
  }
  const text = textPieces.join("");
  return {
    id,
    model: modelName,
    created,
    usage: usage ?? {},
    choices: [
      {
        index: 0,
        message: {
          role: "assistant",
          content: text.length > 0 ? text : null,
          toolCalls
        },
        finishReason
      }
    ]
  };
}
/**
 * Returns a function that hands out one MistralGoogleCloud client per
 * region for the given project, creating each client on first use and
 * reusing it afterwards.
 *
 * The cache is a Map rather than a plain object so that region strings which
 * collide with inherited object properties (e.g. "constructor") cannot be
 * mistaken for cached clients.
 *
 * @param projectId Project id passed to every client that is created.
 * @returns `(region) => client`. Throws when `region` is falsy, and wraps
 *   any client construction failure in an Error whose `cause` is the
 *   original error.
 */
function createClientFactory(projectId) {
  const clients = new Map();
  return (region) => {
    if (!region) {
      throw new Error("Region is required to create Mistral client");
    }
    const cached = clients.get(region);
    if (cached) {
      return cached;
    }
    try {
      const client = new import_mistralai_gcp.MistralGoogleCloud({
        region,
        projectId
      });
      clients.set(region, client);
      return client;
    } catch (error) {
      throw new Error(
        `Failed to create/retrieve Mistral client for region ${region}: ${error}`,
        { cause: error }
      );
    }
  };
}
/**
 * Maps a message role to the role name used in Mistral requests.
 * "model" becomes "assistant"; "user", "tool" and "system" map to
 * themselves.
 *
 * @param role Source role string.
 * @returns The corresponding Mistral role string.
 * @throws Error for any other role.
 */
function toMistralRole(role) {
  switch (role) {
    case "model":
      return "assistant";
    case "user":
      return "user";
    case "tool":
      return "tool";
    case "system":
      return "system";
    default:
      throw new Error(`Unknown role ${role}`);
  }
}
/**
 * Converts a tool request into the `{ name, arguments }` pair used for a
 * Mistral function call.
 *
 * `arguments` is always produced as a string: string input is passed through
 * unchanged, anything else is JSON-encoded. A missing (null/undefined) input
 * is encoded as "{}" so that a tool invoked without arguments still yields a
 * defined `arguments` value.
 *
 * @param toolRequest Object with `name` and optional `input`.
 * @returns `{ name, arguments }`.
 * @throws Error when `toolRequest.name` is falsy.
 */
function toMistralToolRequest(toolRequest) {
  if (!toolRequest.name) {
    throw new Error("Tool name is required");
  }
  const input = toolRequest.input ?? {};
  return {
    name: toolRequest.name,
    arguments: typeof input === "string" ? input : JSON.stringify(input)
  };
}
/**
 * Builds a Mistral chat request from a generate request.
 *
 * Messages are converted with toMistralMessages (a single source message can
 * expand to several Mistral messages), then checked with
 * validateToolSequence. `maxTokens` defaults to 1024 and `temperature` to
 * 0.7 when not configured. `topP` is included whenever it is set, including
 * an explicit 0. `stop` and `tools` are included only when provided.
 *
 * @param model2 Model identifier placed in the request's `model` field.
 * @param input Object with `messages` and optional `config` / `tools`.
 * @returns The request object.
 * @throws Whatever toMistralRole, toMistralToolRequest or
 *   validateToolSequence throw.
 */
function toMistralRequest(model2, input) {
  const messages = input.messages.flatMap((msg) => toMistralMessages(msg));
  validateToolSequence(messages);
  const config = input.config;
  const request = {
    model: model2,
    messages,
    maxTokens: config?.maxOutputTokens ?? 1024,
    temperature: config?.temperature ?? 0.7,
    // `!= null` rather than truthiness so an explicit topP of 0 is kept.
    ...(config?.topP != null && { topP: config.topP }),
    ...(config?.stopSequences && { stop: config.stopSequences }),
    ...(input.tools && {
      tools: input.tools.map((tool) => ({
        type: "function",
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.inputSchema || {}
        }
      }))
    })
  };
  return request;
}

/**
 * Converts one source message into the Mistral message(s) that represent it:
 *  - all parts have non-empty text: one message with the concatenated text;
 *  - one or more tool requests: one assistant message carrying ALL of them
 *    as `toolCalls` (previously only the first was kept);
 *  - one or more tool responses: one "tool" message PER response, each with
 *    the `toolCallId` of the call it answers (previously only the first was
 *    kept);
 *  - otherwise: one message with whatever text the parts contain.
 */
function toMistralMessages(msg) {
  const parts = msg.content;
  const joinedText = () => parts.map((part) => part.text || "").join("");
  if (parts.every((part) => part.text)) {
    return [{ role: toMistralRole(msg.role), content: joinedText() }];
  }
  const requestParts = parts.filter((part) => part.toolRequest);
  if (requestParts.length > 0) {
    return [
      {
        role: "assistant",
        content: null,
        toolCalls: requestParts.map(({ toolRequest }) => {
          const functionCall = toMistralToolRequest(toolRequest);
          return {
            id: toolRequest.ref,
            type: import_components.ToolTypes.Function,
            function: {
              name: functionCall.name,
              arguments: functionCall.arguments
            }
          };
        })
      }
    ];
  }
  const responseParts = parts.filter((part) => part.toolResponse);
  if (responseParts.length > 0) {
    return responseParts.map(({ toolResponse }) => ({
      role: "tool",
      name: toolResponse.name,
      content: JSON.stringify(toolResponse.output),
      // This must match the id from the corresponding tool call.
      toolCallId: toolResponse.ref
    }));
  }
  return [{ role: toMistralRole(msg.role), content: joinedText() }];
}
/**
 * Wraps a piece of content as a text part.
 *
 * @param content Value stored under the part's `text` key (not validated).
 * @returns `{ text: content }`.
 */
function fromMistralTextPart(content) {
  const part = { text: content };
  return part;
}
/**
 * Converts a Mistral tool call into a tool-request part.
 *
 * String arguments are parsed as JSON; any other arguments value is passed
 * through as-is.
 *
 * @param toolCall Object with `id` and `function: { name, arguments }`.
 * @returns `{ toolRequest: { ref, name, input } }`.
 * @throws Error when `toolCall.function` is missing; JSON.parse errors
 *   propagate for string arguments that are not valid JSON.
 */
function fromMistralToolCall(toolCall) {
  const fn = toolCall.function;
  if (!fn) {
    throw new Error("Tool call must include a function definition");
  }
  const rawArguments = fn.arguments;
  let input = rawArguments;
  if (typeof rawArguments === "string") {
    input = JSON.parse(rawArguments);
  }
  return {
    toolRequest: {
      ref: toolCall.id,
      name: fn.name,
      input
    }
  };
}
/**
 * Converts a Mistral message into an ordered list of parts: text first,
 * then tool requests.
 *
 * String content becomes one text part; array content contributes one text
 * part per chunk whose `type` is "text" (other chunk types are ignored).
 * Each entry of `message.toolCalls`, when present, is converted with
 * fromMistralToolCall.
 *
 * @param message Object with optional `content` and `toolCalls`.
 * @returns Array of parts (possibly empty).
 */
function fromMistralMessage(message) {
  const { content, toolCalls } = message;
  const parts = [];
  if (typeof content === "string") {
    parts.push(fromMistralTextPart(content));
  } else if (Array.isArray(content)) {
    const textParts = content
      .filter((chunk) => chunk.type === "text")
      .map((chunk) => fromMistralTextPart(chunk.text));
    parts.push(...textParts);
  }
  if (toolCalls) {
    parts.push(...toolCalls.map((toolCall) => fromMistralToolCall(toolCall)));
  }
  return parts;
}
/**
 * Maps a Mistral finish reason onto this module's finish-reason strings.
 *
 *  - Stop       -> "stop"
 *  - Length / ModelLength -> "length"
 *  - Error      -> "other" (generic errors are not distinguished further)
 *  - ToolCalls  -> "stop" (a tool call is treated as a normal stop)
 *  - anything else, including undefined -> "other"
 *
 * @param reason Finish reason value from a completion choice.
 * @returns One of "stop", "length" or "other".
 */
function fromMistralFinishReason(reason) {
  const FinishReason = import_components.ChatCompletionChoiceFinishReason;
  if (reason === FinishReason.Stop) {
    return "stop";
  }
  if (reason === FinishReason.Length || reason === FinishReason.ModelLength) {
    return "length";
  }
  if (reason === FinishReason.Error) {
    return "other";
  }
  if (reason === FinishReason.ToolCalls) {
    return "stop";
  }
  return "other";
}
/**
 * Converts a Mistral chat completion response into a generate response.
 *
 * Only the first choice is used. Its message (if any) is converted with
 * fromMistralMessage and its finish reason with fromMistralFinishReason.
 * Token usage is read from `response.usage`, and `id` / `model` / `created`
 * are copied into `custom`. The untouched response is attached as `raw`
 * for debugging or additional context.
 *
 * @param _input The originating request; not read.
 * @param response Mistral response object; `response.usage` must be present.
 * @returns `{ message, finishReason, usage, custom, raw }`.
 */
function fromMistralResponse(_input, response) {
  const choice = response.choices?.[0];
  let content = [];
  if (choice?.message) {
    content = fromMistralMessage(choice.message);
  }
  const finishReason = fromMistralFinishReason(choice?.finishReason);
  const usage = {
    inputTokens: response.usage.promptTokens,
    outputTokens: response.usage.completionTokens
  };
  const custom = {
    id: response.id,
    model: response.model,
    created: response.created
  };
  return {
    message: { role: "model", content },
    finishReason,
    usage,
    custom,
    raw: response
  };
}
/**
 * Checks that tool calls and tool responses in a message list line up.
 *
 * Collects every entry of `toolCalls` on assistant messages and every
 * message with role "tool", then requires (1) equal counts and (2) that each
 * tool message's `toolCallId` equals the `id` of some collected call.
 *
 * @param messages Array of Mistral-format messages.
 * @throws Error on a count mismatch or on a tool response whose id matches
 *   no call. Returns nothing on success.
 */
function validateToolSequence(messages) {
  const requestedCalls = [];
  messages.forEach((message) => {
    if (message.role === "assistant" && message.toolCalls) {
      requestedCalls.push(...message.toolCalls);
    }
  });
  const toolReplies = messages.filter((message) => message.role === "tool");
  if (requestedCalls.length !== toolReplies.length) {
    throw new Error(
      `Mismatch between tool calls (${requestedCalls.length}) and responses (${toolReplies.length})`
    );
  }
  for (const reply of toolReplies) {
    const answersKnownCall = requestedCalls.some(
      (call) => call.id === reply.toolCallId
    );
    if (!answersKnownCall) {
      throw new Error(
        `Tool response with ID ${reply.toolCallId} has no matching call`
      );
    }
  }
}
/**
 * Converts one streamed completion chunk into parts.
 *
 * Only the first choice's `delta` is read. String content yields one text
 * part; array content yields one text part per chunk whose `type` is "text"
 * (the same rule fromMistralMessage applies to full messages). Tool calls
 * that carry a `function` yield tool-request parts; tool calls without one
 * are skipped.
 *
 * @param chunk Streamed chunk object.
 * @returns Array of parts; empty when the chunk has no first-choice delta.
 * @throws JSON.parse errors propagate when a tool call's string arguments
 *   are not valid JSON.
 */
function fromMistralCompletionChunk(chunk) {
  const delta = chunk.choices?.[0]?.delta;
  if (!delta) {
    return [];
  }
  const parts = [];
  if (typeof delta.content === "string") {
    parts.push({ text: delta.content });
  } else if (Array.isArray(delta.content)) {
    for (const piece of delta.content) {
      if (piece.type === "text") {
        parts.push({ text: piece.text });
      }
    }
  }
  for (const toolCall of delta.toolCalls ?? []) {
    if (!toolCall.function) {
      continue;
    }
    const { name, arguments: rawArguments } = toolCall.function;
    parts.push({
      toolRequest: {
        ref: toolCall.id,
        name,
        input: typeof rawArguments === "string" ? JSON.parse(rawArguments) : rawArguments
      }
    });
  }
  return parts;
}
// Annotate the CommonJS export names for ESM import in node:
// `0 && (...)` short-circuits, so this assignment is never executed; the
// statement exists only so that the export names appear literally in the
// source text alongside `module.exports`.
0 && (module.exports = {
GENERIC_MODEL,
KNOWN_MODELS,
MistralConfigSchema,
defineModel,
fromMistralCompletionChunk,
fromMistralFinishReason,
fromMistralResponse,
isMistralModelName,
listActions,
listKnownModels,
model,
toMistralRequest
});
//# sourceMappingURL=mistral.js.map