UNPKG

@genkit-ai/vertexai

Version:

Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.

936 lines 27.3 kB
import {
  FunctionCallingMode,
  FunctionDeclarationSchemaType,
  HarmBlockThreshold,
  HarmCategory,
} from "@google-cloud/vertexai";
import { ApiClient } from "@google-cloud/vertexai/build/src/resources/index.js";
import { GENKIT_CLIENT_HEADER, GenkitError, z } from "genkit";
import {
  GenerationCommonConfigSchema,
  getBasicUsageStats,
  modelRef,
} from "genkit/model";
import {
  downloadRequestMedia,
  simulateSystemPrompt,
} from "genkit/model/middleware";
import { runInNewSpan } from "genkit/tracing";
import { GoogleAuth } from "google-auth-library";
import { handleCacheIfNeeded } from "./context-caching/index.js";
import { extractCacheConfig } from "./context-caching/utils.js";

/** One safety filter: a harm category and the threshold at which to block it. */
const SafetySettingsSchema = z.object({
  category: z.nativeEnum(HarmCategory),
  threshold: z.nativeEnum(HarmBlockThreshold),
});

/** Config for grounding generation on a Vertex AI Search data store. */
const VertexRetrievalSchema = z.object({
  datastore: z
    .object({
      projectId: z.string().describe("Google Cloud Project ID.").optional(),
      location: z
        .string()
        .describe("Google Cloud region e.g. us-central1.")
        .optional(),
      dataStoreId: z
        .string()
        .describe(
          'The data store id, when project id and location are provided as separate options. Alternatively, the full path to the data store should be provided in the form: "projects/{project}/locations/{location}/collections/default_collection/dataStores/{data_store}".'
        ),
    })
    .describe("Vertex AI Search data store details"),
  disableAttribution: z
    .boolean()
    .describe(
      "Disable using the search data in detecting grounding attribution. This does not affect how the result is given to the model for generation."
    )
    .optional(),
});

/** Config for grounding generation on public web data via Google Search. */
const GoogleSearchRetrievalSchema = z.object({
  disableAttribution: z
    .boolean()
    .describe(
      "Disable using the search data in detecting grounding attribution. This does not affect how the result is given to the model for generation."
    )
    .optional(),
});

/**
 * Request config accepted by the Vertex AI Gemini models in this module:
 * Genkit's common generation options plus Vertex-specific settings.
 */
const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
  location: z
    .string()
    .describe("Google Cloud region e.g. us-central1.")
    .optional(),
  /**
   * Safety filter settings.
   * See: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters#configurable-filters
   *
   * E.g.
   *
   * ```js
   * config: {
   *   safetySettings: [
   *     {
   *       category: 'HARM_CATEGORY_HATE_SPEECH',
   *       threshold: 'BLOCK_LOW_AND_ABOVE',
   *     },
   *     {
   *       category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
   *       threshold: 'BLOCK_MEDIUM_AND_ABOVE',
   *     },
   *     {
   *       category: 'HARM_CATEGORY_HARASSMENT',
   *       threshold: 'BLOCK_ONLY_HIGH',
   *     },
   *     {
   *       category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
   *       threshold: 'BLOCK_NONE',
   *     },
   *   ],
   * }
   * ```
   */
  safetySettings: z
    .array(SafetySettingsSchema)
    .describe(
      "Adjust how likely you are to see responses that could be harmful. Content is blocked based on the probability that it is harmful."
    )
    .optional(),
  /**
   * Vertex retrieval options.
   *
   * E.g.
   *
   * ```js
   * config: {
   *   vertexRetrieval: {
   *     datastore: {
   *       projectId: 'your-cloud-project',
   *       location: 'us-central1',
   *       dataStoreId: 'your-data-store-id',
   *     },
   *     disableAttribution: true,
   *   }
   * }
   * ```
   */
  vertexRetrieval: VertexRetrievalSchema.describe(
    "Retrieve from Vertex AI Search data store for grounding generative responses."
  ).optional(),
  /**
   * Google Search retrieval options.
   *
   * ```js
   * config: {
   *   googleSearchRetrieval: {
   *     disableAttribution: true,
   *   }
   * }
   * ```
   */
  googleSearchRetrieval: GoogleSearchRetrievalSchema.describe(
    "Retrieve public web data for grounding, powered by Google Search."
  ).optional(),
  /**
   * Function calling options.
   *
   * E.g. forced tool call:
   *
   * ```js
   * config: {
   *   functionCallingConfig: {
   *     mode: 'ANY',
   *   }
   * }
   * ```
   */
  functionCallingConfig: z
    .object({
      mode: z.enum(["MODE_UNSPECIFIED", "AUTO", "ANY", "NONE"]).optional(),
      allowedFunctionNames: z.array(z.string()).optional(),
    })
    .describe(
      "Controls how the model uses the provided tools (function declarations). " +
        "With AUTO (Default) mode, the model decides whether to generate a natural language response or suggest a function call based on the prompt and context. " +
        "With ANY, the model is constrained to always predict a function call and guarantee function schema adherence. " +
        "With NONE, the model is prohibited from making function calls."
    )
    .optional(),
});

/**
 * Creates a model reference for an arbitrary Vertex AI Gemini model version.
 *
 * Capabilities (`info`) are copied from the known model whose key is the
 * longest prefix of `version` (or from GENERIC_GEMINI_MODEL when none matches).
 *
 * @param {string} version - Model version, e.g. "gemini-1.5-pro-002".
 * @param {object} [options] - Default config attached to the reference.
 * @returns A Genkit model reference named `vertexai/${version}`.
 */
function gemini(version, options = {}) {
  const nearestModel = nearestGeminiModelRef(version);
  return modelRef({
    name: `vertexai/${version}`,
    config: options,
    configSchema: GeminiConfigSchema,
    info: {
      ...nearestModel.info,
      // If exact suffix match for a known model, use its label, otherwise create a new label
      label: nearestModel.name.endsWith(version)
        ? nearestModel.info?.label
        : `Vertex AI - ${version}`,
    },
  });
}

/**
 * Returns the known model ref whose SUPPORTED_GEMINI_MODELS key is the longest
 * prefix of `version`, configured with `{ ...options, version }`. Falls back
 * to GENERIC_GEMINI_MODEL when no key matches.
 */
function nearestGeminiModelRef(version, options = {}) {
  const matchingKey = longestMatchingPrefix(
    version,
    Object.keys(SUPPORTED_GEMINI_MODELS)
  );
  if (matchingKey) {
    return SUPPORTED_GEMINI_MODELS[matchingKey].withConfig({
      ...options,
      version,
    });
  }
  return GENERIC_GEMINI_MODEL.withConfig({ ...options, version });
}

/**
 * Returns the longest entry of `potentialMatches` that `version` starts with,
 * or "" when none does.
 */
function longestMatchingPrefix(version, potentialMatches) {
  return potentialMatches
    .filter((p) => version.startsWith(p))
    .reduce(
      (longest, current) =>
        current.length > longest.length ? current : longest,
      ""
    );
}

/**
 * Builds a reference for a Gemini model that supports multi-turn chat, media,
 * tools, tool choice, the system role, and constrained output without tools.
 * A fresh `info`/`supports` object is created per call so refs share no state.
 *
 * @param {string} id - Model id without the "vertexai/" prefix.
 * @param {string} label - Human-readable label.
 * @param {string[]} [versions] - Known pinned versions of the model.
 */
function capableGeminiModelRef(id, label, versions = []) {
  return modelRef({
    name: `vertexai/${id}`,
    info: {
      label,
      versions,
      supports: {
        multiturn: true,
        media: true,
        tools: true,
        toolChoice: true,
        systemRole: true,
        constrained: "no-tools",
      },
    },
    configSchema: GeminiConfigSchema,
  });
}

// Gemini 1.0 Pro is text-only, so it is declared explicitly with media: false.
const gemini10Pro = modelRef({
  name: "vertexai/gemini-1.0-pro",
  info: {
    label: "Vertex AI - Gemini Pro",
    versions: ["gemini-1.0-pro-001", "gemini-1.0-pro-002"],
    supports: {
      multiturn: true,
      media: false,
      tools: true,
      systemRole: true,
      constrained: "no-tools",
      toolChoice: true,
    },
  },
  configSchema: GeminiConfigSchema,
});

const gemini15Pro = capableGeminiModelRef(
  "gemini-1.5-pro",
  "Vertex AI - Gemini 1.5 Pro",
  ["gemini-1.5-pro-001", "gemini-1.5-pro-002"]
);

const gemini15Flash = capableGeminiModelRef(
  "gemini-1.5-flash",
  "Vertex AI - Gemini 1.5 Flash",
  ["gemini-1.5-flash-001", "gemini-1.5-flash-002"]
);

const gemini20Flash001 = capableGeminiModelRef(
  "gemini-2.0-flash-001",
  "Vertex AI - Gemini 2.0 Flash 001"
);

const gemini20Flash = capableGeminiModelRef(
  "gemini-2.0-flash",
  "Vertex AI - Gemini 2.0 Flash"
);

const gemini20FlashLite = capableGeminiModelRef(
  "gemini-2.0-flash-lite",
  "Vertex AI - Gemini 2.0 Flash Lite"
);

const gemini20FlashLitePreview0205 = capableGeminiModelRef(
  "gemini-2.0-flash-lite-preview-02-05",
  "Vertex AI - Gemini 2.0 Flash Lite Preview 02-05"
);

const gemini20ProExp0205 = capableGeminiModelRef(
  "gemini-2.0-pro-exp-02-05",
  "Vertex AI - Gemini 2.0 Pro Experimental 02-05"
);

const gemini25ProExp0325 = capableGeminiModelRef(
  "gemini-2.5-pro-exp-03-25",
  "Vertex AI - Gemini 2.5 Pro Experimental 03-25"
);

const gemini25ProPreview0325 = capableGeminiModelRef(
  "gemini-2.5-pro-preview-03-25",
  "Vertex AI - Gemini 2.5 Pro Preview 03-25"
);

/** Fallback reference used for versions that match no known model prefix. */
const GENERIC_GEMINI_MODEL = modelRef({
  name: "vertexai/gemini",
  configSchema: GeminiConfigSchema,
  info: {
    label: "Google Gemini",
    supports: {
      multiturn: true,
      media: true,
      tools: true,
      toolChoice: true,
      systemRole: true,
    },
  },
});

/** Gemini 1.0 models, keyed by model id. */
const SUPPORTED_V1_MODELS = {
  "gemini-1.0-pro": gemini10Pro,
};

/** Gemini 1.5+ models, keyed by model id. */
const SUPPORTED_V15_MODELS = {
  "gemini-1.5-pro": gemini15Pro,
  "gemini-1.5-flash": gemini15Flash,
  "gemini-2.0-flash": gemini20Flash,
  "gemini-2.0-flash-001": gemini20Flash001,
  "gemini-2.0-flash-lite": gemini20FlashLite,
  "gemini-2.0-flash-lite-preview-02-05": gemini20FlashLitePreview0205,
  "gemini-2.0-pro-exp-02-05": gemini20ProExp0205,
  "gemini-2.5-pro-exp-03-25": gemini25ProExp0325,
  "gemini-2.5-pro-preview-03-25": gemini25ProPreview0325,
};

/** Every known Gemini model, keyed by model id. */
const SUPPORTED_GEMINI_MODELS = {
  ...SUPPORTED_V1_MODELS,
  ...SUPPORTED_V15_MODELS,
};
function toGeminiRole(role, modelInfo) { switch (role) { case "user": return "user"; case "model": return "model"; case "system": if (modelInfo && modelInfo.supports?.systemRole) { throw new Error( "system role is only supported for a single message in the first position" ); } else { throw new Error("system role is not supported"); } case "tool": return "function"; default: return "user"; } } const toGeminiTool = (tool) => { const declaration = { name: tool.name.replace(/\//g, "__"), // Gemini throws on '/' in tool name description: tool.description, parameters: convertSchemaProperty(tool.inputSchema) }; return declaration; }; const toGeminiFileDataPart = (part) => { const media = part.media; if (media.url.startsWith("gs://") || media.url.startsWith("http")) { if (!media.contentType) throw new Error( "Must supply contentType when using media from http(s):// or gs:// URLs." ); return { fileData: { mimeType: media.contentType, fileUri: media.url } }; } else if (media.url.startsWith("data:")) { const dataUrl = media.url; const b64Data = dataUrl.substring(dataUrl.indexOf(",") + 1); const contentType = media.contentType || dataUrl.substring(dataUrl.indexOf(":") + 1, dataUrl.indexOf(";")); return { inlineData: { mimeType: contentType, data: b64Data } }; } throw Error( "Could not convert genkit part to gemini tool response part: missing file data" ); }; const toGeminiToolRequestPart = (part) => { if (!part?.toolRequest?.input) { throw Error( "Could not convert genkit part to gemini tool response part: missing tool request data" ); } return { functionCall: { name: part.toolRequest.name, args: part.toolRequest.input } }; }; const toGeminiToolResponsePart = (part) => { if (!part?.toolResponse?.output) { throw Error( "Could not convert genkit part to gemini tool response part: missing tool response data" ); } return { functionResponse: { name: part.toolResponse.name, response: { name: part.toolResponse.name, content: part.toolResponse.output } } }; }; function 
toGeminiSystemInstruction(message) { return { role: "user", parts: message.content.map(toGeminiPart) }; } function toGeminiMessage(message, modelInfo) { let sortedParts = message.content; if (message.role === "tool") { sortedParts = [...message.content].sort((a, b) => { const aRef = a.toolResponse?.ref; const bRef = b.toolResponse?.ref; if (!aRef && !bRef) return 0; if (!aRef) return 1; if (!bRef) return -1; return parseInt(aRef, 10) - parseInt(bRef, 10); }); } return { role: toGeminiRole(message.role, modelInfo), parts: sortedParts.map(toGeminiPart) }; } function fromGeminiFinishReason(reason) { if (!reason) return "unknown"; switch (reason) { case "STOP": return "stop"; case "MAX_TOKENS": return "length"; case "SAFETY": // blocked for safety case "RECITATION": return "blocked"; default: return "unknown"; } } function toGeminiPart(part) { if (part.text) { return { text: part.text }; } else if (part.media) { return toGeminiFileDataPart(part); } else if (part.toolRequest) { return toGeminiToolRequestPart(part); } else if (part.toolResponse) { return toGeminiToolResponsePart(part); } else { throw new Error("unsupported type"); } } function fromGeminiInlineDataPart(part) { if (!part.inlineData || !part.inlineData.hasOwnProperty("mimeType") || !part.inlineData.hasOwnProperty("data")) { throw new Error("Invalid GeminiPart: missing required properties"); } const { mimeType, data } = part.inlineData; const dataUrl = `data:${mimeType};base64,${data}`; return { media: { url: dataUrl, contentType: mimeType } }; } function fromGeminiFileDataPart(part) { if (!part.fileData || !part.fileData.hasOwnProperty("mimeType") || !part.fileData.hasOwnProperty("url")) { throw new Error( "Invalid Gemini File Data Part: missing required properties" ); } return { media: { url: part.fileData?.fileUri, contentType: part.fileData?.mimeType } }; } function fromGeminiFunctionCallPart(part, ref) { if (!part.functionCall) { throw new Error( "Invalid Gemini Function Call Part: missing function call 
data" ); } return { toolRequest: { name: part.functionCall.name, input: part.functionCall.args, ref } }; } function fromGeminiFunctionResponsePart(part, ref) { if (!part.functionResponse) { throw new Error( "Invalid Gemini Function Call Part: missing function call data" ); } return { toolResponse: { name: part.functionResponse.name.replace(/__/g, "/"), // restore slashes output: part.functionResponse.response, ref } }; } function fromGeminiPart(part, jsonMode, ref) { if (part.text !== void 0) return { text: part.text }; if (part.inlineData) return fromGeminiInlineDataPart(part); if (part.fileData) return fromGeminiFileDataPart(part); if (part.functionCall) return fromGeminiFunctionCallPart(part, ref); if (part.functionResponse) return fromGeminiFunctionResponsePart(part, ref); throw new Error( "Part type is unsupported/corrupted. Either data is missing or type cannot be inferred from type." ); } function fromGeminiCandidate(candidate, jsonMode) { const parts = candidate.content.parts || []; const genkitCandidate = { index: candidate.index || 0, message: { role: "model", content: parts.map((part, index) => { return fromGeminiPart(part, jsonMode, index.toString()); }) }, finishReason: fromGeminiFinishReason(candidate.finishReason), finishMessage: candidate.finishMessage, custom: { safetyRatings: candidate.safetyRatings, citationMetadata: candidate.citationMetadata } }; return genkitCandidate; } function convertSchemaProperty(property) { if (!property || !property.type) { return void 0; } const baseSchema = {}; if (property.description) { baseSchema.description = property.description; } if (property.enum) { baseSchema.enum = property.enum; } if (property.nullable) { baseSchema.nullable = property.nullable; } let propertyType; if (Array.isArray(property.type)) { const types = property.type; if (types.includes("null")) { baseSchema.nullable = true; } propertyType = types.find((t) => t !== "null"); } else { propertyType = property.type; } if (propertyType === "object") { 
const nestedProperties = {}; Object.keys(property.properties).forEach((key) => { nestedProperties[key] = convertSchemaProperty(property.properties[key]); }); return { ...baseSchema, type: FunctionDeclarationSchemaType.OBJECT, properties: nestedProperties, required: property.required }; } else if (propertyType === "array") { return { ...baseSchema, type: FunctionDeclarationSchemaType.ARRAY, items: convertSchemaProperty(property.items) }; } else { const schemaType = FunctionDeclarationSchemaType[propertyType.toUpperCase()]; if (!schemaType) { throw new GenkitError({ status: "INVALID_ARGUMENT", message: `Unsupported property type ${propertyType.toUpperCase()}` }); } return { ...baseSchema, type: schemaType }; } } function cleanSchema(schema) { const out = structuredClone(schema); for (const key in out) { if (key === "$schema" || key === "additionalProperties") { delete out[key]; continue; } if (typeof out[key] === "object") { out[key] = cleanSchema(out[key]); } if (key === "type" && Array.isArray(out[key])) { out[key] = out[key].find((t) => t !== "null"); } } return out; } function defineGeminiKnownModel(ai, name, vertexClientFactory, options, debugTraces) { const modelName = `vertexai/${name}`; const model = SUPPORTED_GEMINI_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); return defineGeminiModel({ ai, modelName, version: name, modelInfo: model?.info, vertexClientFactory, options, debugTraces }); } function defineGeminiModel({ ai, modelName, version, modelInfo, vertexClientFactory, options, debugTraces }) { const middlewares = []; if (SUPPORTED_V1_MODELS[version]) { middlewares.push(simulateSystemPrompt()); } if (modelInfo?.supports?.media) { middlewares.push( downloadRequestMedia({ maxBytes: 1024 * 1024 * 20, filter: (part) => { try { const url = new URL(part.media.url); if ( // Gemini can handle these URLs ["www.youtube.com", "youtube.com", "youtu.be"].includes( url.hostname ) ) return false; } catch { } return true; } }) ); } return 
ai.defineModel( { name: modelName, ...modelInfo, configSchema: GeminiConfigSchema, use: middlewares }, async (request, sendChunk) => { const vertex = vertexClientFactory(request); const messages = [...request.messages]; if (messages.length === 0) throw new Error("No messages provided."); let systemInstruction = void 0; if (!SUPPORTED_V1_MODELS[version]) { const systemMessage = messages.find((m) => m.role === "system"); if (systemMessage) { messages.splice(messages.indexOf(systemMessage), 1); systemInstruction = toGeminiSystemInstruction(systemMessage); } } const tools = request.tools?.length ? [{ functionDeclarations: request.tools.map(toGeminiTool) }] : []; let toolConfig; if (request?.config?.functionCallingConfig) { toolConfig = { functionCallingConfig: { allowedFunctionNames: request.config.functionCallingConfig.allowedFunctionNames, mode: toFunctionModeEnum(request.config.functionCallingConfig.mode) } }; } else if (request.toolChoice) { toolConfig = { functionCallingConfig: { mode: toGeminiFunctionModeEnum(request.toolChoice) } }; } const jsonMode = (request.output?.format === "json" || !!request.output?.schema) && tools.length === 0; let chatRequest = { systemInstruction, tools, toolConfig, history: messages.slice(0, -1).map((message) => toGeminiMessage(message, modelInfo)), generationConfig: { candidateCount: request.candidates || void 0, temperature: request.config?.temperature, maxOutputTokens: request.config?.maxOutputTokens, topK: request.config?.topK, topP: request.config?.topP, responseMimeType: jsonMode ? 
"application/json" : void 0, stopSequences: request.config?.stopSequences }, safetySettings: request.config?.safetySettings }; const modelVersion = request.config?.version || version; const cacheConfigDetails = extractCacheConfig(request); const apiClient = new ApiClient( options.projectId, options.location, "v1beta1", new GoogleAuth(options.googleAuth) ); const { chatRequest: updatedChatRequest, cache } = await handleCacheIfNeeded( apiClient, request, chatRequest, modelVersion, cacheConfigDetails ); let genModel; if (jsonMode && request.output?.constrained) { updatedChatRequest.generationConfig.responseSchema = cleanSchema( request.output.schema ); } if (request.config?.googleSearchRetrieval) { updatedChatRequest.tools?.push({ googleSearchRetrieval: request.config.googleSearchRetrieval }); } if (request.config?.vertexRetrieval) { const vertexRetrieval = request.config.vertexRetrieval; const _projectId = vertexRetrieval.datastore.projectId || options.projectId; const _location = vertexRetrieval.datastore.location || options.location; const _dataStoreId = vertexRetrieval.datastore.dataStoreId; const datastore = `projects/${_projectId}/locations/${_location}/collections/default_collection/dataStores/${_dataStoreId}`; updatedChatRequest.tools?.push({ retrieval: { vertexAiSearch: { datastore }, disableAttribution: vertexRetrieval.disableAttribution } }); } const msg = toGeminiMessage(messages[messages.length - 1], modelInfo); if (cache) { genModel = vertex.preview.getGenerativeModelFromCachedContent( cache, { model: modelVersion }, { apiClient: GENKIT_CLIENT_HEADER } ); } else { genModel = vertex.preview.getGenerativeModel( { model: modelVersion }, { apiClient: GENKIT_CLIENT_HEADER } ); } const callGemini = async () => { let response; if (sendChunk) { const result = await genModel.startChat(updatedChatRequest).sendMessageStream(msg.parts); for await (const item of result.stream) { item.candidates?.forEach( (candidate) => { const c = fromGeminiCandidate(candidate, 
jsonMode); sendChunk({ index: c.index, content: c.message.content }); } ); } response = await result.response; } else { const result = await genModel.startChat(updatedChatRequest).sendMessage(msg.parts); response = result.response; } if (!response.candidates?.length) { throw new GenkitError({ status: "FAILED_PRECONDITION", message: "No valid candidates returned." }); } const candidateData = response.candidates.map( (c) => fromGeminiCandidate(c, jsonMode) ); return { candidates: candidateData, custom: response, usage: { ...getBasicUsageStats(request.messages, candidateData), inputTokens: response.usageMetadata?.promptTokenCount, outputTokens: response.usageMetadata?.candidatesTokenCount, totalTokens: response.usageMetadata?.totalTokenCount } }; }; return debugTraces ? await runInNewSpan( ai.registry, { metadata: { name: sendChunk ? "sendMessageStream" : "sendMessage" } }, async (metadata) => { metadata.input = { sdk: "@google-cloud/vertexai", cache, model: genModel.getModelName(), chatOptions: updatedChatRequest, parts: msg.parts, options }; const response = await callGemini(); metadata.output = response.custom; return response; } ) : await callGemini(); } ); } function toFunctionModeEnum(enumMode) { if (enumMode === void 0) { return void 0; } switch (enumMode) { case "MODE_UNSPECIFIED": { return FunctionCallingMode.MODE_UNSPECIFIED; } case "ANY": { return FunctionCallingMode.ANY; } case "AUTO": { return FunctionCallingMode.AUTO; } case "NONE": { return FunctionCallingMode.NONE; } default: throw new Error(`unsupported function calling mode: ${enumMode}`); } } function toGeminiFunctionModeEnum(genkitMode) { if (genkitMode === void 0) { return void 0; } switch (genkitMode) { case "required": { return FunctionCallingMode.ANY; } case "auto": { return FunctionCallingMode.AUTO; } case "none": { return FunctionCallingMode.NONE; } default: throw new Error(`unsupported function calling mode: ${genkitMode}`); } } export { GENERIC_GEMINI_MODEL, GeminiConfigSchema, 
SUPPORTED_GEMINI_MODELS, SUPPORTED_V15_MODELS, SUPPORTED_V1_MODELS, cleanSchema, defineGeminiKnownModel, defineGeminiModel, fromGeminiCandidate, gemini, gemini10Pro, gemini15Flash, gemini15Pro, gemini20Flash, gemini20Flash001, gemini20FlashLite, gemini20FlashLitePreview0205, gemini20ProExp0205, gemini25ProExp0325, gemini25ProPreview0325, toGeminiMessage, toGeminiSystemInstruction, toGeminiTool }; //# sourceMappingURL=gemini.mjs.map