/**
 * genkitx-cloud-sql-pg
 * Version: (unspecified)
 * Genkit AI framework plugin for Cloud SQL for PostgreSQL.
 * 293 lines • 10.9 kB
 * JavaScript
 */
import { z } from "genkit";
import { genkitPlugin } from "genkit/plugin";
import {
CommonRetrieverOptionsSchema,
Document,
indexerRef,
retrieverRef
} from "genkit/retriever";
import { v4 as uuidv4 } from "uuid";
import { DistanceStrategy } from "./indexes.js";
import { Column, PostgresEngine as PostgresEngine2 } from "./engine.js";
import {
DistanceStrategy as DistanceStrategy2,
ExactNearestNeighbor,
HNSWIndex,
HNSWQueryOptions,
IVFFlatIndex,
IVFFlatQueryOptions
} from "./indexes.js";
/**
 * Per-call options accepted by the Postgres retriever.
 * - `k`: number of nearest neighbours to return (at most 1000). Optional, because
 *   the retriever falls back to 4 when it is omitted.
 * - `filter`: SQL boolean expression that is appended verbatim as a WHERE clause.
 */
const PostgresRetrieverOptionsSchema = CommonRetrieverOptionsSchema.extend({
  k: z.number().max(1000).optional(),
  filter: z.string().optional()
});
/** Options accepted by the Postgres indexer; documents are inserted `batchSize` at a time. */
const PostgresIndexerOptionsSchema = z.object({
  batchSize: z.number().default(100)
});
/**
 * Builds a reference to the Postgres retriever registered for `params.tableName`.
 *
 * @param {{tableName: string, displayName?: string}} params - Table the retriever
 *   was configured for, plus an optional human-readable label.
 * @returns {object} A Genkit retriever reference named `postgres/<tableName>`.
 */
const postgresRetrieverRef = (params) => {
  return retrieverRef({
    name: `postgres/${params.tableName}`,
    info: {
      // An explicit label wins; otherwise derive one from the table name.
      label: params.displayName ?? `Postgres - ${params.tableName}`
    },
    configSchema: PostgresRetrieverOptionsSchema
  });
};
/**
 * Builds a reference to the Postgres indexer registered for `params.tableName`.
 *
 * @param {{tableName: string, displayName?: string}} params - Table the indexer
 *   was configured for, plus an optional human-readable label.
 * @returns {object} A Genkit indexer reference named `postgres/<tableName>`.
 */
const postgresIndexerRef = (params) => {
  return indexerRef({
    name: `postgres/${params.tableName}`,
    info: {
      // An explicit label wins; otherwise derive one from the table name.
      label: params.displayName ?? `Postgres - ${params.tableName}`
    },
    configSchema: PostgresIndexerOptionsSchema.optional()
  });
};
/**
 * Creates the Genkit `postgres` plugin. For every configuration in `params` it
 * registers a retriever and an indexer named `postgres/<tableName>`.
 *
 * @param {object[]} params - One configuration object per table (see
 *   `configurePostgresRetriever` / `configurePostgresIndexer`).
 * @returns {object} The plugin produced by `genkitPlugin`.
 */
function postgres(params) {
  return genkitPlugin("postgres", async (ai) => {
    // Retriever setup is async: awaiting it makes configuration errors (e.g. a
    // missing engine) fail plugin initialisation instead of becoming unhandled
    // promise rejections.
    await Promise.all(params.map((config) => configurePostgresRetriever(ai, config)));
    for (const config of params) {
      configurePostgresIndexer(ai, config);
    }
  });
}
const index_default = postgres;
/**
 * Registers a Genkit retriever named `postgres/<tableName>` that runs a vector
 * similarity search against a PostgreSQL table through `params.engine.pool`.
 *
 * The table's columns are validated against `information_schema.columns` at the
 * start of every retrieval (this needs a live connection). The resolved column
 * configuration is kept in local variables rather than written back into
 * `params`, so repeated retrievals behave identically.
 *
 * @param {object} ai - Genkit instance; `defineRetriever` and `embed` are used.
 * @param {object} params - Retriever configuration: `tableName`, `engine`,
 *   `embedder`, optional `embedderOptions`, `schemaName` (default "public"),
 *   `idColumn` (default "id"), `contentColumn` (default "content"),
 *   `embeddingColumn` (default "embedding"), either `metadataColumns` or
 *   `ignoreMetadataColumns`, `metadataJsonColumn`, `distanceStrategy`
 *   (default cosine distance) and `indexQueryOptions`.
 * @returns {Promise<object>} Whatever `ai.defineRetriever` returns.
 * @throws {Error} When `params.engine` is missing, or when both
 *   `metadataColumns` and `ignoreMetadataColumns` are given.
 */
async function configurePostgresRetriever(ai, params) {
  const schemaName = params.schemaName ?? "public";
  // Same default as the indexer; without one the query would select "undefined".
  const idColumn = params.idColumn ?? "id";
  const contentColumn = params.contentColumn ?? "content";
  const embeddingColumn = params.embeddingColumn ?? "embedding";
  const distanceStrategy = params.distanceStrategy ?? DistanceStrategy.COSINE_DISTANCE;
  if (!params.engine) {
    throw new Error("Engine is required");
  }
  if (params.metadataColumns !== void 0 && params.ignoreMetadataColumns !== void 0) {
    throw new Error("Can not use both metadata_columns and ignore_metadata_columns.");
  }

  /**
   * Validates the configured columns against the live table definition and
   * resolves which metadata columns to read.
   * @returns {Promise<{metadataColumns: string[], metadataJsonColumn: string}>}
   */
  async function checkColumns() {
    // Bound parameters keep the table/schema names out of the SQL text.
    const { rows } = await params.engine.pool.raw(
      "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ? AND table_schema = ?",
      [params.tableName, schemaName]
    );
    const columns = new Map(rows.map((row) => [row.column_name, row.data_type]));
    if (!columns.has(idColumn)) {
      throw new Error(`Id column: ${idColumn}, does not exist.`);
    }
    if (!columns.has(contentColumn)) {
      throw new Error(`Content column: ${contentColumn}, does not exist.`);
    }
    const contentType = columns.get(contentColumn);
    if (contentType !== "text" && !contentType.includes("char")) {
      throw new Error(
        `Content column: ${contentColumn}, is type: ${contentType}. It must be a type of character string.`
      );
    }
    if (!columns.has(embeddingColumn)) {
      throw new Error(`Embedding column: ${embeddingColumn}, does not exist.`);
    }
    if (columns.get(embeddingColumn) !== "USER-DEFINED") {
      throw new Error(`Embedding column: ${embeddingColumn} is not of type Vector.`);
    }
    // A configured JSON metadata column that is absent from the table is ignored.
    const metadataJsonColumn =
      params.metadataJsonColumn && columns.has(params.metadataJsonColumn) ? params.metadataJsonColumn : "";
    const explicitColumns = (params.metadataColumns ?? []).filter(Boolean);
    for (const column of explicitColumns) {
      if (!columns.has(column)) {
        throw new Error(`Metadata column: ${column}, does not exist.`);
      }
    }
    if (!params.ignoreMetadataColumns?.length) {
      return { metadataColumns: explicitColumns, metadataJsonColumn };
    }
    // Every column except the ignored ones and id/content/embedding becomes metadata.
    const excluded = new Set([...params.ignoreMetadataColumns, idColumn, contentColumn, embeddingColumn]);
    const metadataColumns = [...columns.keys()].filter((column) => !excluded.has(column));
    return { metadataColumns, metadataJsonColumn };
  }

  /**
   * Runs the nearest-neighbour query and returns the raw result rows.
   * @param {number[]} embedding - Query embedding.
   * @param {number|undefined} k - Maximum rows to return (defaults to 4).
   * @param {string|undefined} filter - Optional SQL WHERE expression.
   * @param {{metadataColumns: string[], metadataJsonColumn: string}} columnConfig
   */
  async function queryCollection(embedding, k, filter, columnConfig) {
    const limit = k ?? 4;
    const vectorLiteral = `'[${embedding}]'`;
    const selected = [idColumn, contentColumn, embeddingColumn, ...columnConfig.metadataColumns];
    if (columnConfig.metadataJsonColumn) {
      selected.push(columnConfig.metadataJsonColumn);
    }
    // Joining a real list avoids empty entries (", ,") when there is no metadata.
    const selectList = selected.map((column) => `"${column}"`).join(", ");
    // `filter` is inserted verbatim as SQL; callers must not pass untrusted input.
    const whereClause = filter ? `WHERE ${filter}` : "";
    const query = `SELECT ${selectList}, ${distanceStrategy.searchFunction}("${embeddingColumn}", ${vectorLiteral}) as distance FROM "${schemaName}"."${params.tableName}" ${whereClause} ORDER BY "${embeddingColumn}" ${distanceStrategy.operator} ${vectorLiteral} LIMIT ${limit};`;
    if (!params.indexQueryOptions) {
      const { rows } = await params.engine.pool.raw(query);
      return rows;
    }
    // SET LOCAL only lasts until the end of the current transaction, so the
    // index option and the search must run inside the same transaction.
    return params.engine.pool.transaction(async (trx) => {
      await trx.raw(`SET LOCAL ${params.indexQueryOptions.to_string()}`);
      const { rows } = await trx.raw(query);
      return rows;
    });
  }

  return ai.defineRetriever(
    {
      name: `postgres/${params.tableName}`,
      configSchema: PostgresRetrieverOptionsSchema
    },
    async (content, options) => {
      console.log(`Retrieving data for table: ${params.tableName}`);
      const columnConfig = await checkColumns();
      const queryEmbeddings = await ai.embed({
        embedder: params.embedder,
        content,
        options: params.embedderOptions
      });
      const rows = await queryCollection(
        queryEmbeddings[0].embedding,
        options?.k,
        options?.filter,
        columnConfig
      );
      const documents = rows.map((row) => {
        const { metadataJsonColumn, metadataColumns } = columnConfig;
        const jsonMetadata = metadataJsonColumn && row[metadataJsonColumn] ? row[metadataJsonColumn] : {};
        // Copy so that adding column metadata never mutates the row's JSON value.
        const metadata = { ...jsonMetadata };
        for (const column of metadataColumns) {
          metadata[column] = row[column];
        }
        // Genkit documents carry their content as an array of parts.
        return new Document({ content: [{ text: row[contentColumn] }], metadata });
      });
      return { documents };
    }
  );
}
/**
 * Registers a Genkit indexer named `postgres/<tableName>` that embeds documents
 * with `params.embedder` and inserts them into a PostgreSQL table in batches.
 *
 * @param {object} ai - Genkit instance; `defineIndexer` and `embedMany`/`embed` are used.
 * @param {object} params - Indexer configuration: `tableName`, `engine`, `embedder`,
 *   optional `embedderOptions`, `schemaName` (default "public"), `idColumn`
 *   (default "id"), `contentColumn` (default "content"), `embeddingColumn`
 *   (default "embedding"), `metadataJsonColumn` (default "metadata") and
 *   `metadataColumns`.
 * @returns {object} Whatever `ai.defineIndexer` returns.
 * @throws {Error} When `params.engine` is missing, or when both
 *   `metadataColumns` and `ignoreMetadataColumns` are given.
 */
function configurePostgresIndexer(ai, params) {
  const schemaName = params.schemaName ?? "public";
  const contentColumn = params.contentColumn ?? "content";
  const embeddingColumn = params.embeddingColumn ?? "embedding";
  const idColumn = params.idColumn ?? "id";
  const metadataJsonColumn = params.metadataJsonColumn ?? "metadata";
  if (!params.engine) {
    throw new Error("Engine is required");
  }
  if (params.metadataColumns && params.ignoreMetadataColumns) {
    throw new Error(
      "Cannot use both metadataColumns and ignoreMetadataColumns"
    );
  }

  /** Fails fast if the table lacks a column this indexer writes to. */
  async function checkColumns() {
    // Bound parameters keep the table/schema names out of the SQL text.
    const { rows } = await params.engine.pool.raw(
      "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ? AND table_schema = ?",
      [params.tableName, schemaName]
    );
    const columns = new Map(rows.map((row) => [row.column_name, row.data_type]));
    if (!columns.has(idColumn)) {
      throw new Error(`Id column: ${idColumn}, does not exist.`);
    }
    if (!columns.has(contentColumn)) {
      throw new Error(`Content column: ${contentColumn}, does not exist.`);
    }
    if (!columns.has(embeddingColumn)) {
      throw new Error(`Embedding column: ${embeddingColumn}, does not exist.`);
    }
    if (columns.get(embeddingColumn) !== "USER-DEFINED") {
      throw new Error(
        `Embedding column: ${embeddingColumn} is not of type Vector.`
      );
    }
    for (const column of params.metadataColumns ?? []) {
      if (column && !columns.has(column)) {
        throw new Error(`Metadata column: ${column}, does not exist.`);
      }
    }
  }

  /** Embeds `texts`, preferring a single batched `embedMany` call when available. */
  async function embedTexts(texts) {
    try {
      if (ai.embedMany) {
        return await ai.embedMany({
          embedder: params.embedder,
          content: texts,
          options: params.embedderOptions
        });
      }
      return await Promise.all(
        texts.map(async (text) => {
          const result = await ai.embed({
            embedder: params.embedder,
            content: text,
            options: params.embedderOptions
          });
          return result[0];
        })
      );
    } catch (error) {
      throw new Error("Embedding failed", { cause: error });
    }
  }

  /** Flattens a document's content parts (or plain string content) into one string. */
  const toText = (doc) =>
    Array.isArray(doc.content) ? doc.content.map((part) => part.text).join(" ") : doc.content;

  /** Builds the row inserted for one document. */
  function toRow(doc, text, embedding) {
    return {
      [idColumn]: doc.metadata?.[idColumn] || uuidv4(),
      [contentColumn]: text,
      [embeddingColumn]: JSON.stringify(embedding.embedding),
      ...(metadataJsonColumn && { [metadataJsonColumn]: doc.metadata || {} }),
      ...Object.fromEntries(
        (params.metadataColumns || [])
          .filter((col) => doc.metadata?.[col] !== void 0)
          .map((col) => [col, doc.metadata?.[col]])
      )
    };
  }

  return ai.defineIndexer(
    {
      name: `postgres/${params.tableName}`,
      configSchema: PostgresIndexerOptionsSchema.optional()
    },
    async (docs, options) => {
      try {
        await checkColumns();
        const documents = Array.isArray(docs) ? docs : docs.documents || [];
        // `options` is optional in the schema, so it may be undefined here.
        const mergedOptions = (Array.isArray(docs) ? options : docs.options || options) || {};
        // A zero, negative or non-numeric size would never advance the loop below.
        const requested = mergedOptions.batchSize;
        const batchSize = typeof requested === "number" && requested > 0 ? requested : 100;
        console.log(
          `Indexing ${documents.length} documents in batches of ${batchSize}`
        );
        for (let start = 0; start < documents.length; start += batchSize) {
          const chunk = documents.slice(start, start + batchSize);
          const texts = chunk.map((doc) => toText(doc));
          const embeddings = await embedTexts(texts);
          const rows = chunk.map((doc, index) => toRow(doc, texts[index], embeddings[index]));
          // Build a fresh query for each batch.
          const table = schemaName
            ? params.engine.pool.withSchema(schemaName).table(params.tableName)
            : params.engine.pool.table(params.tableName);
          await table.insert(rows);
        }
      } catch (error) {
        console.error("Error in indexer:", error);
        throw error;
      }
    }
  );
}
export {
Column,
DistanceStrategy2 as DistanceStrategy,
ExactNearestNeighbor,
HNSWIndex,
HNSWQueryOptions,
IVFFlatIndex,
IVFFlatQueryOptions,
PostgresEngine2 as PostgresEngine,
configurePostgresIndexer,
configurePostgresRetriever,
index_default as default,
postgres,
postgresIndexerRef,
postgresRetrieverRef
};
//# sourceMappingURL=index.mjs.map