UNPKG

genkitx-qdrant

Version:

Genkit AI framework plugin for the Qdrant vector database.

285 lines (268 loc) 8.69 kB
import { EmbedderArgument } from '@genkit-ai/ai/embedder';
import {
  CommonRetrieverOptionsSchema,
  Document,
  indexerRef,
  retrieverRef,
} from '@genkit-ai/ai/retriever';
import type { QdrantClientParams, Schemas } from '@qdrant/js-client-rest';
import { QdrantClient } from '@qdrant/js-client-rest';
import { z, type Genkit } from 'genkit';
import { genkitPlugin } from 'genkit/plugin';
import { v5 as uuidv5 } from 'uuid';

// Qdrant filters are passed through as-is; only the static type is constrained.
const FilterType: z.ZodType<Schemas['Filter']> = z.any();

/**
 * Options accepted by the Qdrant retriever: the common retriever options
 * plus the number of results (`k`, default 10) and an optional Qdrant filter.
 */
const QdrantRetrieverOptionsSchema = CommonRetrieverOptionsSchema.extend({
  k: z.number().default(10),
  filter: FilterType.optional(),
});

/**
 * Options accepted by the Qdrant indexer. The indexer takes no options.
 */
export const QdrantIndexerOptionsSchema = z.null().optional();

// Default payload keys, used when the plugin params do not override them.
const CONTENT_PAYLOAD_KEY = 'content';
const METADATA_PAYLOAD_KEY = 'metadata';
const CONTENT_TYPE_KEY = '_content_type';

/**
 * Parameters for the Qdrant plugin.
 */
interface QdrantPluginParams<E extends z.ZodTypeAny = z.ZodTypeAny> {
  /**
   * Parameters for instantiating `QdrantClient`.
   */
  clientParams: QdrantClientParams;
  /**
   * Name of the Qdrant collection.
   */
  collectionName: string;
  /**
   * Embedder to use for the retriever and indexer.
   */
  embedder: EmbedderArgument<E>;
  /**
   * Additional options for the embedder.
   */
  embedderOptions?: z.infer<E>;
  /**
   * Document content key in the Qdrant payload.
   * Default is 'content'.
   */
  contentPayloadKey?: string;
  /**
   * Document metadata key in the Qdrant payload.
   * Default is 'metadata'.
   */
  metadataPayloadKey?: string;
  /**
   * Document data type key in the Qdrant payload.
   * Default is '_content_type'.
   * This is used to store the type of content.
   */
  dataTypePayloadKey?: string;
  /**
   * Additional options when creating a collection.
   */
  collectionCreateOptions?: Schemas['CreateCollection'];
}

/**
 * qdrantRetrieverRef function creates a retriever reference for Qdrant.
 * @param collectionName The collection name for the Qdrant retriever.
 * @param displayName A display name for the retriever.
 * If not specified, the default label will be `Qdrant - <collectionName>`
 * @returns A reference to a Qdrant retriever.
 */
export const qdrantRetrieverRef = (
  collectionName: string,
  displayName: string | null = null,
) => {
  return retrieverRef({
    name: `qdrant/${collectionName}`,
    info: {
      label: displayName ?? `Qdrant - ${collectionName}`,
    },
    configSchema: QdrantRetrieverOptionsSchema,
  });
};

/**
 * qdrantIndexerRef function creates an indexer reference for Qdrant.
 * @param collectionName The collection name for the Qdrant indexer.
 * @param displayName A display name for the indexer.
 * If not specified, the default label will be `Qdrant - <collectionName>`
 * @returns A reference to a Qdrant indexer.
 */
export const qdrantIndexerRef = (
  collectionName: string,
  displayName: string | null = null,
) => {
  return indexerRef({
    name: `qdrant/${collectionName}`,
    info: {
      label: displayName ?? `Qdrant - ${collectionName}`,
    },
    configSchema: QdrantIndexerOptionsSchema,
  });
};

/**
 * Qdrant plugin that provides the Qdrant retriever
 * and indexer.
 * @param params One entry per collection; a retriever and an indexer named
 * `qdrant/<collectionName>` are registered for each entry.
 */
export function qdrant<EmbedderCustomOptions extends z.ZodTypeAny>(
  params: QdrantPluginParams<EmbedderCustomOptions>[],
) {
  return genkitPlugin('qdrant', async (ai) => {
    params.forEach((p) => configureQdrantRetriever(ai, p));
    params.forEach((p) => configureQdrantIndexer(ai, p));
  });
}

export default qdrant;

/**
 * Registers a retriever named `qdrant/<collectionName>` on the given Genkit
 * instance.
 *
 * The retriever embeds the query, searches the collection and converts each
 * returned point's payload back into a document. It throws if the collection
 * does not exist yet.
 * @param ai The Genkit instance to register the retriever on.
 * @param params The plugin parameters for one collection.
 * @returns The value returned by `ai.defineRetriever`.
 */
export function configureQdrantRetriever<
  EmbedderCustomOptions extends z.ZodTypeAny,
>(ai: Genkit, params: QdrantPluginParams<EmbedderCustomOptions>) {
  const {
    embedder,
    collectionName,
    embedderOptions,
    clientParams,
    contentPayloadKey,
    metadataPayloadKey,
  } = params;
  const client = new QdrantClient(clientParams);
  const contentKey = contentPayloadKey ?? CONTENT_PAYLOAD_KEY;
  const metadataKey = metadataPayloadKey ?? METADATA_PAYLOAD_KEY;
  const dataTypeKey = params.dataTypePayloadKey ?? CONTENT_TYPE_KEY;

  return ai.defineRetriever(
    {
      name: `qdrant/${collectionName}`,
      configSchema: QdrantRetrieverOptionsSchema,
    },
    async (content, options) => {
      // Retrieval never creates the collection; a missing one is an error.
      await ensureCollection(params, false, ai);

      const queryEmbeddings = await ai.embed({
        embedder,
        content,
        options: embedderOptions,
      });

      const results = (
        await client.query(collectionName, {
          query: queryEmbeddings[0].embedding,
          limit: options.k,
          filter: options.filter,
          // Only fetch the payload fields needed to rebuild the document.
          with_payload: [contentKey, metadataKey, dataTypeKey],
          with_vector: false,
        })
      ).points;

      const documents = results.map((result) => {
        // Missing payload fields fall back to an empty text document.
        const data = result.payload?.[contentKey] ?? '';
        const metadata = result.payload?.[metadataKey] ?? {};
        const dataType = result.payload?.[dataTypeKey] ?? 'text';
        return Document.fromData(
          data as string,
          dataType as string,
          metadata as Record<string, unknown>,
        ).toJSON();
      });

      return {
        documents,
      };
    },
  );
}

/**
 * Registers an indexer named `qdrant/<collectionName>` on the given Genkit
 * instance.
 *
 * The indexer creates the collection if it is missing, embeds every document
 * and upserts one point per embedding.
 * @param ai The Genkit instance to register the indexer on.
 * @param params The plugin parameters for one collection.
 * @returns The value returned by `ai.defineIndexer`.
 */
export function configureQdrantIndexer<
  EmbedderCustomOptions extends z.ZodTypeAny,
>(ai: Genkit, params: QdrantPluginParams<EmbedderCustomOptions>) {
  const {
    embedder,
    collectionName,
    embedderOptions,
    clientParams,
    contentPayloadKey,
    metadataPayloadKey,
  } = params;
  const client = new QdrantClient(clientParams);
  const contentKey = contentPayloadKey ?? CONTENT_PAYLOAD_KEY;
  const metadataKey = metadataPayloadKey ?? METADATA_PAYLOAD_KEY;
  const dataTypeKey = params.dataTypePayloadKey ?? CONTENT_TYPE_KEY;

  return ai.defineIndexer(
    {
      name: `qdrant/${collectionName}`,
      configSchema: QdrantIndexerOptionsSchema,
    },
    async (docs) => {
      await ensureCollection(params, true, ai);

      const embeddings = await Promise.all(
        docs.map((doc) =>
          ai.embed({
            embedder,
            content: doc,
            options: embedderOptions,
          }),
        ),
      );

      // One document may yield several embeddings; flatten them into a
      // single list of points.
      const points = embeddings.flatMap((embeddingArr, i) => {
        const doc = docs[i];
        const embeddingDocs = doc.getEmbeddingDocuments(embeddingArr);
        return embeddingArr.map((docEmbedding, j) => {
          const embeddingDoc = embeddingDocs[j] || {};
          // Name-based UUID of the serialized document: the same document
          // always maps to the same point id.
          const id = uuidv5(JSON.stringify(embeddingDoc), uuidv5.URL);
          return {
            id,
            vector: docEmbedding.embedding,
            payload: {
              [contentKey]: embeddingDoc.data,
              [metadataKey]: embeddingDoc.metadata,
              [dataTypeKey]: embeddingDoc.dataType,
            },
          };
        });
      });

      // Nothing to write (e.g. no documents were passed): skip the request
      // instead of sending an empty update to Qdrant.
      if (points.length === 0) {
        return;
      }

      await client.upsert(collectionName, { points });
    },
  );
}

/**
 * Helper function for creating a Qdrant collection.
 *
 * When `params.collectionCreateOptions` is not set, the vector size is
 * inferred by embedding a sample text and the distance is set to 'Cosine'.
 * @param params The plugin parameters for the collection to create.
 * @param ai The Genkit instance used to embed the sample text. Only needed
 * when `params.collectionCreateOptions` is not set.
 * @returns The result of `QdrantClient.createCollection`.
 * @throws Error if the vector size has to be inferred but `ai` is missing.
 */
export async function createQdrantCollection<
  EmbedderCustomOptions extends z.ZodTypeAny,
>(params: QdrantPluginParams<EmbedderCustomOptions>, ai) {
  const { embedder, embedderOptions, clientParams, collectionName } = params;
  const client = new QdrantClient(clientParams);

  let collectionCreateOptions = params.collectionCreateOptions;
  if (!collectionCreateOptions) {
    // Fail with a clear message rather than a TypeError on `ai.embed`.
    if (ai == null) {
      throw new Error(
        `Cannot infer the vector size for collection ${collectionName}: ` +
          'pass a Genkit instance or set collectionCreateOptions.',
      );
    }

    const embeddings = await ai.embed({
      embedder,
      content: 'SOME_TEXT',
      options: embedderOptions,
    });
    // Accept both an array of embeddings and a single embedding object.
    const vector = Array.isArray(embeddings)
      ? embeddings[0].embedding
      : embeddings.embedding;

    collectionCreateOptions = {
      vectors: {
        size: vector.length,
        distance: 'Cosine',
      },
    };
  }

  return await client.createCollection(collectionName, collectionCreateOptions);
}

/**
 * Helper function for deleting Qdrant collections.
 * @param params The plugin parameters naming the collection to delete.
 * @returns The result of `QdrantClient.deleteCollection`.
 */
export async function deleteQdrantCollection(params: QdrantPluginParams) {
  const client = new QdrantClient(params.clientParams);
  return await client.deleteCollection(params.collectionName);
}

/**
 * Private helper for ensuring that a Qdrant collection exists.
 * @param params The plugin parameters naming the collection.
 * @param createCollection When true, a missing collection is created;
 * when false, a missing collection is an error.
 * @param ai The Genkit instance forwarded to `createQdrantCollection`.
 * @throws Error if the collection is missing and `createCollection` is false.
 */
async function ensureCollection(
  params: QdrantPluginParams,
  createCollection = true,
  ai?,
) {
  const { clientParams, collectionName } = params;
  const client = new QdrantClient(clientParams);

  if ((await client.collectionExists(collectionName)).exists) {
    return;
  }

  if (createCollection) {
    await createQdrantCollection(params, ai);
  } else {
    throw new Error(
      `Collection ${collectionName} does not exist. Index some documents first.`,
    );
  }
}