UNPKG

genkitx-cloud-sql-pg

Version:

Genkit AI framework plugin for Cloud SQL for PostgreSQL.

293 lines 10.9 kB
import { z } from "genkit";
import { genkitPlugin } from "genkit/plugin";
import {
  CommonRetrieverOptionsSchema,
  Document,
  indexerRef,
  retrieverRef,
} from "genkit/retriever";
import { v4 as uuidv4 } from "uuid";
import { Column, PostgresEngine } from "./engine.js";
import {
  DistanceStrategy,
  ExactNearestNeighbor,
  HNSWIndex,
  HNSWQueryOptions,
  IVFFlatIndex,
  IVFFlatQueryOptions,
} from "./indexes.js";

/** Per-call options accepted by a Postgres retriever. */
const PostgresRetrieverOptionsSchema = CommonRetrieverOptionsSchema.extend({
  k: z.number().max(1e3),
  filter: z.string().optional(),
});

/** Per-call options accepted by a Postgres indexer. */
const PostgresIndexerOptionsSchema = z.object({
  batchSize: z.number().default(100),
});

/** Rows inserted per batch when no usable `batchSize` is supplied. */
const DEFAULT_BATCH_SIZE = 100;

/** Neighbour count used when the retriever is called without `k`. */
const DEFAULT_K = 4;

/**
 * Quotes a PostgreSQL identifier, doubling any embedded double quote so the
 * name cannot terminate the quoted identifier early.
 * @param {string} name - Schema, table, or column name.
 * @returns {string} The name wrapped in double quotes.
 */
function quoteIdent(name) {
  return `"${String(name).replaceAll('"', '""')}"`;
}

/**
 * Reads the columns of `schemaName.tableName` from information_schema.
 * Table and schema names are passed as query bindings instead of being
 * spliced into the SQL text.
 * @param {object} engine - Engine whose `pool.raw(sql, bindings)` runs SQL
 *   (assumes a knex instance, which is what `withSchema`/`table` imply).
 * @param {string} schemaName - Schema to inspect.
 * @param {string} tableName - Table to inspect.
 * @returns {Promise<Record<string, string>>} Map of column_name -> data_type.
 */
async function fetchColumnTypes(engine, schemaName, tableName) {
  const { rows } = await engine.pool.raw(
    "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ? AND table_schema = ?",
    [tableName, schemaName],
  );
  const columns = {};
  for (const row of rows) {
    columns[row.column_name] = row.data_type;
  }
  return columns;
}

/**
 * Returns a reference to the retriever registered for `params.tableName`.
 * NOTE(review): the label fallback only applies when `tableName` is nullish,
 * in which case it reads "Postgres - undefined"; a separate display-name
 * field may have been intended — confirm.
 * @param {{ tableName: string }} params
 */
const postgresRetrieverRef = (params) => {
  return retrieverRef({
    name: `postgres/${params.tableName}`,
    info: { label: params.tableName ?? `Postgres - ${params.tableName}` },
    configSchema: PostgresRetrieverOptionsSchema,
  });
};

/**
 * Returns a reference to the indexer registered for `params.tableName`.
 * NOTE(review): same label-fallback caveat as `postgresRetrieverRef`.
 * @param {{ tableName: string }} params
 */
const postgresIndexerRef = (params) => {
  return indexerRef({
    name: `postgres/${params.tableName}`,
    info: { label: params.tableName ?? `Postgres - ${params.tableName}` },
    configSchema: PostgresIndexerOptionsSchema.optional(),
  });
};

/**
 * Genkit plugin that registers one retriever and one indexer per table config.
 * @param {object[]} params - One configuration object per table.
 * @returns The Genkit plugin.
 */
function postgres(params) {
  return genkitPlugin("postgres", async (ai) => {
    // configurePostgresRetriever is async; awaiting it surfaces configuration
    // errors from plugin initialisation instead of leaving unhandled rejections.
    await Promise.all(
      params.map((config) => configurePostgresRetriever(ai, config)),
    );
    for (const config of params) {
      configurePostgresIndexer(ai, config);
    }
  });
}
var index_default = postgres;

/**
 * Registers a retriever named `postgres/<tableName>`. It embeds the query with
 * `params.embedder` and returns the `k` rows nearest to it as Documents.
 *
 * @param {object} ai - Genkit instance (uses `defineRetriever` and `embed`).
 * @param {object} params - Table, column, engine, and embedder configuration.
 * @returns {Promise<object>} The action returned by `ai.defineRetriever`.
 * @throws {Error} If `params.engine` is missing, or if both `metadataColumns`
 *   and `ignoreMetadataColumns` are set.
 */
async function configurePostgresRetriever(ai, params) {
  const schemaName = params.schemaName ?? "public";
  const contentColumn = params.contentColumn ?? "content";
  const embeddingColumn = params.embeddingColumn ?? "embedding";
  const distanceStrategy =
    params.distanceStrategy ?? DistanceStrategy.COSINE_DISTANCE;
  if (!params.engine) {
    throw new Error("Engine is required");
  }
  if (
    params.metadataColumns !== void 0 &&
    params.ignoreMetadataColumns !== void 0
  ) {
    throw new Error(
      "Can not use both metadata_columns and ignore_metadata_columns.",
    );
  }

  /**
   * Validates the configured columns against the live table and resolves
   * which metadata columns to read. The result is returned rather than
   * written back into `params`, so every call sees the caller's original
   * configuration.
   * @returns {Promise<{ metadataColumns: string[], metadataJsonColumn: string }>}
   */
  async function resolveColumns() {
    const columns = await fetchColumnTypes(
      params.engine,
      schemaName,
      params.tableName,
    );
    const hasColumn = (name) => Object.hasOwn(columns, name);

    if (params.idColumn && !hasColumn(params.idColumn)) {
      throw new Error(`Id column: ${params.idColumn}, does not exist.`);
    }
    if (contentColumn && !hasColumn(contentColumn)) {
      throw new Error(`Content column: ${contentColumn}, does not exist.`);
    }
    const contentType = contentColumn ? columns[contentColumn] : "";
    if (contentType !== "text" && !contentType.includes("char")) {
      throw new Error(
        `Content column: ${contentColumn}, is type: ${contentType}. It must be a type of character string.`,
      );
    }
    if (embeddingColumn && !hasColumn(embeddingColumn)) {
      throw new Error(`Embedding column: ${embeddingColumn}, does not exist.`);
    }
    // information_schema reports extension types such as pgvector's `vector`
    // as USER-DEFINED.
    if (embeddingColumn && columns[embeddingColumn] !== "USER-DEFINED") {
      throw new Error(
        `Embedding column: ${embeddingColumn} is not of type Vector.`,
      );
    }

    // A configured JSON metadata column that is absent from the table is
    // silently ignored rather than treated as an error.
    const metadataJsonColumn =
      params.metadataJsonColumn && hasColumn(params.metadataJsonColumn)
        ? params.metadataJsonColumn
        : "";

    for (const column of params.metadataColumns ?? []) {
      if (column && !hasColumn(column)) {
        throw new Error(`Metadata column: ${column}, does not exist.`);
      }
    }

    let metadataColumns = params.metadataColumns ?? [];
    const ignoredColumns = params.ignoreMetadataColumns ?? [];
    if (ignoredColumns.length > 0) {
      // Every column that is not ignored and is not id/content/embedding
      // becomes a metadata column.
      const excluded = new Set([
        ...ignoredColumns,
        params.idColumn,
        contentColumn,
        embeddingColumn,
      ]);
      metadataColumns = Object.keys(columns).filter(
        (name) => !excluded.has(name),
      );
    }
    return { metadataColumns, metadataJsonColumn };
  }

  // Validation costs a database round trip, so it runs once and the result is
  // reused. A failed attempt is forgotten so the next retrieval retries it.
  let resolvedColumns;
  function getResolvedColumns() {
    resolvedColumns ??= resolveColumns().catch((error) => {
      resolvedColumns = undefined;
      throw error;
    });
    return resolvedColumns;
  }

  /**
   * Runs the nearest-neighbour query and returns the raw result rows.
   * `filter` is inserted verbatim as a SQL WHERE clause, so it must come from
   * trusted code, never from end users.
   */
  async function queryCollection(
    embedding,
    k,
    filter,
    { metadataColumns, metadataJsonColumn },
  ) {
    const limit = k ?? DEFAULT_K;
    const vectorLiteral = `'[${embedding}]'`;
    const quotedEmbedding = quoteIdent(embeddingColumn);
    // Build the select list as an array so unset optional columns (id,
    // metadata, JSON metadata) never leave an empty entry between commas.
    const selectList = [
      params.idColumn,
      contentColumn,
      embeddingColumn,
      ...metadataColumns,
      metadataJsonColumn,
    ]
      .filter(Boolean)
      .map(quoteIdent);
    selectList.push(
      `${distanceStrategy.searchFunction}(${quotedEmbedding}, ${vectorLiteral}) as distance`,
    );
    const whereClause = filter ? `WHERE ${filter}` : "";
    const query = `SELECT ${selectList.join(", ")} FROM ${quoteIdent(schemaName)}.${quoteIdent(params.tableName)} ${whereClause} ORDER BY ${quotedEmbedding} ${distanceStrategy.operator} ${vectorLiteral} LIMIT ${limit};`;

    const { pool } = params.engine;
    if (!params.indexQueryOptions) {
      const { rows } = await pool.raw(query);
      return rows;
    }
    // SET LOCAL only lasts until the end of the current transaction, and a
    // pool may run separate statements on separate connections. The option
    // and the query therefore share one transaction (assumes `pool` is knex).
    return pool.transaction(async (trx) => {
      await trx.raw(`SET LOCAL ${params.indexQueryOptions.to_string()}`);
      const { rows } = await trx.raw(query);
      return rows;
    });
  }

  /** Converts one result row into a Document carrying text content and metadata. */
  function toDocument(row, { metadataColumns, metadataJsonColumn }) {
    const jsonMetadata = metadataJsonColumn ? row[metadataJsonColumn] : undefined;
    const metadata =
      jsonMetadata && typeof jsonMetadata === "object"
        ? { ...jsonMetadata }
        : {};
    for (const column of metadataColumns) {
      metadata[column] = row[column];
    }
    // Document content is a list of parts, not a bare string.
    return new Document({ content: [{ text: row[contentColumn] }], metadata });
  }

  return ai.defineRetriever(
    {
      name: `postgres/${params.tableName}`,
      configSchema: PostgresRetrieverOptionsSchema,
    },
    async (content, options) => {
      console.log(`Retrieving data for table: ${params.tableName}`);
      const resolved = await getResolvedColumns();
      const queryEmbeddings = await ai.embed({
        embedder: params.embedder,
        content,
        options: params.embedderOptions,
      });
      const { embedding } = queryEmbeddings[0];
      const rows = await queryCollection(
        embedding,
        options?.k,
        options?.filter,
        resolved,
      );
      const documents = rows.map((row) => toDocument(row, resolved));
      return { documents };
    },
  );
}

/**
 * Registers an indexer named `postgres/<tableName>`. It embeds documents in
 * batches and inserts one row per document.
 *
 * @param {object} ai - Genkit instance (uses `defineIndexer`, `embed`, and
 *   `embedMany` when present).
 * @param {object} params - Table, column, engine, and embedder configuration.
 * @returns The action returned by `ai.defineIndexer`.
 * @throws {Error} If `params.engine` is missing, or if both `metadataColumns`
 *   and `ignoreMetadataColumns` are set.
 */
function configurePostgresIndexer(ai, params) {
  const schemaName = params.schemaName ?? "public";
  const contentColumn = params.contentColumn ?? "content";
  const embeddingColumn = params.embeddingColumn ?? "embedding";
  const idColumn = params.idColumn ?? "id";
  const metadataJsonColumn = params.metadataJsonColumn ?? "metadata";
  if (!params.engine) {
    throw new Error("Engine is required");
  }
  if (params.metadataColumns && params.ignoreMetadataColumns) {
    throw new Error(
      "Cannot use both metadataColumns and ignoreMetadataColumns",
    );
  }

  /** Throws if a configured column is missing or the embedding column is not a vector type. */
  async function checkColumns() {
    const columns = await fetchColumnTypes(
      params.engine,
      schemaName,
      params.tableName,
    );
    const hasColumn = (name) => Object.hasOwn(columns, name);
    if (!hasColumn(idColumn)) {
      throw new Error(`Id column: ${idColumn}, does not exist.`);
    }
    if (!hasColumn(contentColumn)) {
      throw new Error(`Content column: ${contentColumn}, does not exist.`);
    }
    if (!hasColumn(embeddingColumn)) {
      throw new Error(`Embedding column: ${embeddingColumn}, does not exist.`);
    }
    if (columns[embeddingColumn] !== "USER-DEFINED") {
      throw new Error(
        `Embedding column: ${embeddingColumn} is not of type Vector.`,
      );
    }
    for (const column of params.metadataColumns ?? []) {
      if (column && !hasColumn(column)) {
        throw new Error(`Metadata column: ${column}, does not exist.`);
      }
    }
  }

  /** Flattens a document's content into the single string that is embedded and stored. */
  function documentText(doc) {
    if (!Array.isArray(doc.content)) {
      return doc.content;
    }
    // Parts without a `text` field are skipped instead of contributing the
    // literal string "undefined".
    return doc.content
      .map((part) => part.text)
      .filter((text) => text !== undefined)
      .join(" ");
  }

  /**
   * Embeds `texts` with the configured embedder. Results are indexed like
   * `texts`, and each carries an `embedding` field.
   */
  async function embedTexts(texts) {
    try {
      if (ai.embedMany) {
        return await ai.embedMany({
          embedder: params.embedder,
          content: texts,
          options: params.embedderOptions,
        });
      }
      return await Promise.all(
        texts.map((text) =>
          ai
            .embed({
              embedder: params.embedder,
              content: text,
              options: params.embedderOptions,
            })
            .then((res) => res[0]),
        ),
      );
    } catch (error) {
      throw new Error("Embedding failed", { cause: error });
    }
  }

  /** Builds the row object inserted for one document. */
  function toRow(doc, text, embedding) {
    const row = {
      [idColumn]: doc.metadata?.[idColumn] || uuidv4(),
      [contentColumn]: text,
      // "[x,y,...]" is pgvector's text input format.
      [embeddingColumn]: JSON.stringify(embedding),
    };
    if (metadataJsonColumn) {
      row[metadataJsonColumn] = doc.metadata || {};
    }
    for (const column of params.metadataColumns ?? []) {
      const value = doc.metadata?.[column];
      if (value !== undefined) {
        row[column] = value;
      }
    }
    return row;
  }

  return ai.defineIndexer(
    {
      name: `postgres/${params.tableName}`,
      configSchema: PostgresIndexerOptionsSchema.optional(),
    },
    async (docs, options) => {
      try {
        await checkColumns();
        const documents = Array.isArray(docs) ? docs : docs.documents || [];
        // `options` may be undefined because the config schema is optional.
        const mergedOptions =
          (Array.isArray(docs) ? options : docs.options || options) ?? {};
        const requested = mergedOptions.batchSize;
        // A missing, zero, or negative batch size would break or stall the loop.
        const batchSize =
          requested >= 1 ? Math.floor(requested) : DEFAULT_BATCH_SIZE;
        console.log(
          `Indexing ${documents.length} documents in batches of ${batchSize}`,
        );
        for (let start = 0; start < documents.length; start += batchSize) {
          const chunk = documents.slice(start, start + batchSize);
          const texts = chunk.map((doc) => documentText(doc));
          const embeddings = await embedTexts(texts);
          const rows = chunk.map((doc, i) =>
            toRow(doc, texts[i], embeddings[i].embedding),
          );
          // knex query builders are single-use, so build one per batch.
          const table = schemaName
            ? params.engine.pool.withSchema(schemaName).table(params.tableName)
            : params.engine.pool.table(params.tableName);
          await table.insert(rows);
        }
      } catch (error) {
        console.error("Error in indexer:", error);
        throw error;
      }
    },
  );
}

export {
  Column,
  DistanceStrategy,
  ExactNearestNeighbor,
  HNSWIndex,
  HNSWQueryOptions,
  IVFFlatIndex,
  IVFFlatQueryOptions,
  PostgresEngine,
  configurePostgresIndexer,
  configurePostgresRetriever,
  index_default as default,
  postgres,
  postgresIndexerRef,
  postgresRetrieverRef,
};
//# sourceMappingURL=index.mjs.map