UNPKG

@genkit-ai/vertexai

Version:

Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.

92 lines 3.3 kB
import { retrieverRef } from "genkit";
import { queryPublicEndpoint } from "./query_public_endpoint";
import { VertexAIVectorRetrieverOptionsSchema } from "./types";
import { getProjectNumber } from "./utils";

/** Number of neighbors requested when the caller does not supply `options.k`. */
const DEFAULT_K = 10;

/**
 * Defines one Genkit retriever per configured Vertex AI Vector Search option.
 *
 * Each retriever is registered under the name `vertexai/<indexId>`. When
 * invoked it:
 *   1. embeds the query `content` with the option's embedder (or the plugin's
 *      default embedder),
 *   2. queries the index's public endpoint for nearest neighbors,
 *   3. passes the neighbors of the first query result to the option's
 *      `documentRetriever`, which produces the returned documents.
 *
 * @param {object} ai - Genkit instance; its `defineRetriever` and `embed`
 *   methods are used.
 * @param {object} params - Plugin parameters: `pluginOptions` (read for
 *   `vectorSearchOptions`, `projectId`, `location`), an optional
 *   `defaultEmbedder`, and an `authClient` exposing `getAccessToken()`.
 * @returns {Array} The defined retrievers. The array is empty when no vector
 *   search options are configured.
 */
function vertexAiRetrievers(ai, params) {
  const { vectorSearchOptions } = params.pluginOptions;
  const { defaultEmbedder } = params;
  const retrievers = [];
  if (!vectorSearchOptions || vectorSearchOptions.length === 0) {
    return retrievers;
  }

  for (const vectorSearchOption of vectorSearchOptions) {
    const { documentRetriever, indexId, publicDomainName, embedderOptions } =
      vectorSearchOption;

    const retriever = ai.defineRetriever(
      {
        name: `vertexai/${indexId}`,
        configSchema: VertexAIVectorRetrieverOptionsSchema.optional(),
      },
      async (content, options) => {
        // Validate all required configuration up front, before spending an
        // embedding call and an auth round-trip. The plugin options are read
        // at call time, not captured at definition time.
        const embedderReference = vectorSearchOption.embedder ?? defaultEmbedder;
        if (!embedderReference) {
          throw new Error(
            "Embedder reference is required to define Vertex AI retriever"
          );
        }
        const { projectId, location } = params.pluginOptions;
        if (!projectId) {
          throw new Error(
            "Project ID is required to define Vertex AI retriever"
          );
        }
        if (!location) {
          throw new Error("Location is required to define Vertex AI retriever");
        }

        const embeddings = await ai.embed({
          embedder: embedderReference,
          options: embedderOptions,
          content,
        });
        // Only the first embedding is used as the query vector. Fail with a
        // clear message instead of a TypeError when the embedder returns nothing.
        const queryEmbedding = embeddings?.[0]?.embedding;
        if (!queryEmbedding) {
          throw new Error(
            "Embedder returned no embedding for the Vertex AI retriever query"
          );
        }

        const accessToken = await params.authClient.getAccessToken();
        if (!accessToken) {
          throw new Error(
            "Error generating access token when defining Vertex AI retriever"
          );
        }
        const projectNumber = await getProjectNumber(projectId);

        const res = await queryPublicEndpoint({
          featureVector: queryEmbedding,
          // `||` rather than `??` on purpose: k = 0 also falls back to DEFAULT_K.
          neighborCount: options?.k || DEFAULT_K,
          accessToken,
          projectId,
          location,
          publicDomainName,
          projectNumber,
          indexEndpointId: vectorSearchOption.indexEndpointId,
          deployedIndexId: vectorSearchOption.deployedIndexId,
          // Optional filters may be supplied on the query document's metadata.
          restricts: content.metadata?.restricts,
          numericRestricts: content.metadata?.numericRestricts,
        });

        // Only the first query's neighbors are used. A missing result yields
        // no documents.
        const neighbors = res.nearestNeighbors?.[0]?.neighbors;
        if (!neighbors) {
          return { documents: [] };
        }
        const documents = await documentRetriever(neighbors, options);
        return { documents };
      }
    );
    retrievers.push(retriever);
  }
  return retrievers;
}

/**
 * Creates a reference to a retriever defined by `vertexAiRetrievers`.
 *
 * @param {object} params - `indexId` of the Vector Search index, plus an
 *   optional `displayName` used as the retriever's label.
 * @returns A retriever reference named `vertexai/<indexId>`. Its label is
 *   `displayName`, or `Vertex AI - <indexId>` when no display name is given.
 */
const vertexAiRetrieverRef = (params) =>
  retrieverRef({
    name: `vertexai/${params.indexId}`,
    info: {
      label: params.displayName ?? `Vertex AI - ${params.indexId}`,
    },
    configSchema: VertexAIVectorRetrieverOptionsSchema.optional(),
  });

export { vertexAiRetrieverRef, vertexAiRetrievers };
//# sourceMappingURL=retrievers.mjs.map