@genkit-ai/vertexai
Version:
Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.
94 lines • 3.32 kB
JavaScript
import {
retrieverRef
} from "genkit";
import { queryPublicEndpoint } from "./query_public_endpoint.mjs";
import {
VertexAIVectorRetrieverOptionsSchema
} from "./types.mjs";
import { getProjectNumber } from "./utils.mjs";
/** Number of nearest neighbors requested when the caller gives no usable `options.k`. */
const DEFAULT_K = 10;
/**
 * Defines one Genkit retriever for every entry in
 * `params.pluginOptions.vectorSearchOptions`.
 *
 * Each retriever:
 * - embeds the query document;
 * - sends the embedding to the Vertex AI Vector Search public endpoint
 *   (`queryPublicEndpoint`);
 * - passes the neighbors of the first returned query result to the entry's
 *   `documentRetriever`, whose output becomes the retrieved documents.
 *
 * Configuration (embedder, project ID, location) is validated at query time,
 * before any embedding, access-token, or project-number request is made.
 * A misconfigured plugin therefore fails without doing remote work first.
 *
 * @param {object} ai - Genkit instance; its `defineRetriever` and `embed` are used.
 * @param {object} params - Plugin parameters:
 *   - `pluginOptions`: `projectId`, `location`, `vectorSearchOptions`;
 *   - `authClient`: must expose `getAccessToken()`;
 *   - `defaultEmbedder` (optional): used when an entry has no `embedder` of its own.
 * @returns {Array} The defined retrievers, in `vectorSearchOptions` order. The
 *   array is empty when no vector search options are configured.
 */
function vertexAiRetrievers(ai, params) {
  const vectorSearchOptions = params.pluginOptions.vectorSearchOptions;
  const defaultEmbedder = params.defaultEmbedder;
  const retrievers = [];
  if (!vectorSearchOptions || vectorSearchOptions.length === 0) {
    return retrievers;
  }
  for (const vectorSearchOption of vectorSearchOptions) {
    const { documentRetriever, indexId, publicDomainName } = vectorSearchOption;
    const embedderOptions = vectorSearchOption.embedderOptions;
    const retriever = ai.defineRetriever(
      {
        name: `vertexai/${indexId}`,
        configSchema: VertexAIVectorRetrieverOptionsSchema.optional()
      },
      async (content, options) => {
        // Validate configuration first. These checks are cheap, and every
        // remote call below would be wasted work if one of them failed.
        const embedderReference = vectorSearchOption.embedder ?? defaultEmbedder;
        if (!embedderReference) {
          throw new Error(
            "Embedder reference is required to define Vertex AI retriever"
          );
        }
        const projectId = params.pluginOptions.projectId;
        if (!projectId) {
          throw new Error(
            "Project ID is required to define Vertex AI retriever"
          );
        }
        const location = params.pluginOptions.location;
        if (!location) {
          throw new Error("Location is required to define Vertex AI retriever");
        }

        // Fetch the access token and then the project number. The token is
        // checked before getProjectNumber is called, so a missing token
        // surfaces as the error below.
        const resolveEndpointAccess = async () => {
          const accessToken = await params.authClient.getAccessToken();
          if (!accessToken) {
            throw new Error(
              "Error generating access token when defining Vertex AI retriever"
            );
          }
          const projectNumber = await getProjectNumber(projectId);
          return { accessToken, projectNumber };
        };

        // The embedding does not depend on the credentials, so compute it
        // concurrently with the endpoint-access lookups.
        const [embeddings, { accessToken, projectNumber }] = await Promise.all([
          ai.embed({
            embedder: embedderReference,
            options: embedderOptions,
            content
          }),
          resolveEndpointAccess()
        ]);
        const queryEmbedding = embeddings[0].embedding;

        const res = await queryPublicEndpoint({
          featureVector: queryEmbedding,
          // A missing k, or a k of 0, falls back to DEFAULT_K.
          neighborCount: options?.k || DEFAULT_K,
          accessToken,
          projectId,
          location,
          publicDomainName,
          projectNumber,
          indexEndpointId: vectorSearchOption.indexEndpointId,
          deployedIndexId: vectorSearchOption.deployedIndexId,
          restricts: content.metadata?.restricts,
          numericRestricts: content.metadata?.numericRestricts
        });

        // Only the first query result is used. No result, or no neighbors,
        // yields an empty document list.
        const neighbors = res.nearestNeighbors?.[0]?.neighbors;
        if (!neighbors) {
          return { documents: [] };
        }
        const documents = await documentRetriever(neighbors, options);
        return { documents };
      }
    );
    retrievers.push(retriever);
  }
  return retrievers;
}
/**
 * Builds a Genkit retriever reference for a Vertex AI Vector Search index.
 *
 * The reference is named `vertexai/<indexId>`, the same name that
 * `vertexAiRetrievers` gives the retriever it defines for that `indexId`.
 *
 * @param {object} params - `indexId` (used in the name) and an optional
 *   `displayName` used as the label.
 * @returns The value returned by `retrieverRef` for that name, label, and
 *   the optional `VertexAIVectorRetrieverOptionsSchema` config schema.
 */
const vertexAiRetrieverRef = (params) => {
  return retrieverRef({
    name: `vertexai/${params.indexId}`,
    info: {
      // Fall back to a generated label when no display name is supplied.
      label: params.displayName ?? `Vertex AI - ${params.indexId}`
    },
    configSchema: VertexAIVectorRetrieverOptionsSchema.optional()
  });
};
export {
vertexAiRetrieverRef,
vertexAiRetrievers
};
//# sourceMappingURL=retrievers.mjs.map