genkitx-astra-db
Version:
An Astra DB indexer and retriever for Genkit
235 lines (213 loc) • 7.31 kB
text/typescript
// Copyright DataStax, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { EmbedderArgument, Embedding } from "genkit/embedder";
import {
CommonRetrieverOptionsSchema,
indexerRef,
retrieverRef,
} from "genkit/retriever";
import { GenkitPlugin, genkitPlugin } from "genkit/plugin";
import { Genkit, GenkitError, z } from "genkit";
import { Md5 } from "ts-md5";
import { DataAPIClient, Filter, SomeDoc, Sort } from "@datastax/astra-db-ts";
/**
 * Connection settings for the Astra DB Data API client.
 *
 * When a caller omits these, the plugin falls back to environment variables
 * (see `getDefaultConfig`).
 */
interface AstraDBClientOptions {
  /** Astra DB application token passed to `DataAPIClient`. */
  applicationToken: string;
  /** Data API endpoint of the target database. */
  apiEndpoint: string;
  /** Keyspace to use; callers fall back to `default_keyspace` when unset. */
  keyspace?: string;
}
/** Prefix for every action this plugin registers (`astradb/<collection>`). */
const PLUGIN_NAME = "astradb" as const;
/** Keyspace used when the caller does not specify one in `clientParams`. */
const DEFAULT_KEYSPACE = "default_keyspace" as const;
/**
 * Builds the per-call config schema for an Astra DB retriever.
 *
 * The schema is Genkit's common retriever options (the retriever reads `k`
 * from them) plus an optional Data API `filter`, typed against the
 * collection's document shape.
 *
 * Note: `z.custom` without a check function accepts any value, so the filter
 * is typed at compile time but not validated at runtime.
 */
const createAstraDBRetrieverOptionsSchema = <Schema extends SomeDoc>() => {
  const filterSchema = z.custom<Filter<Schema>>().optional();
  return CommonRetrieverOptionsSchema.extend({ filter: filterSchema });
};

/** The Astra DB indexer takes no per-call options. */
const AstraDBIndexerOptionsSchema = z.object({});
/**
 * Creates a reference to an Astra DB retriever.
 *
 * The reference is named `astradb/<collectionName>`, which is the same name
 * `configureAstraDBRetriever` registers for that collection.
 *
 * @param params.collectionName Astra DB collection the retriever queries.
 * @param params.displayName Optional UI label; defaults to
 *   `Astra DB - <collectionName>`.
 * @returns A retriever reference whose config schema accepts `k` and an
 *   optional `filter` typed by `Schema`.
 */
export const astraDBRetrieverRef = <Schema extends SomeDoc>(params: {
  collectionName: string;
  displayName?: string;
}) => {
  const { collectionName, displayName } = params;
  const label = displayName ?? `Astra DB - ${collectionName}`;

  return retrieverRef({
    name: `${PLUGIN_NAME}/${collectionName}`,
    info: { label },
    configSchema: createAstraDBRetrieverOptionsSchema<Schema>(),
  });
};
/**
 * Creates a reference to an Astra DB indexer.
 *
 * The reference is named `astradb/<collectionName>`, which is the same name
 * `configureAstraDBIndexer` registers for that collection.
 *
 * @param params.collectionName Astra DB collection the indexer writes to.
 * @param params.displayName Optional UI label; defaults to
 *   `Astra DB - <collectionName>`.
 * @returns An indexer reference with an empty options schema.
 */
export const astraDBIndexerRef = (params: {
  collectionName: string;
  displayName?: string;
}) => {
  const { collectionName, displayName } = params;
  const label = displayName ?? `Astra DB - ${collectionName}`;

  return indexerRef({
    name: `${PLUGIN_NAME}/${collectionName}`,
    info: { label },
    configSchema: AstraDBIndexerOptionsSchema,
  });
};
/**
 * Genkit plugin that registers one Astra DB retriever and one Astra DB
 * indexer for every entry in `params`.
 *
 * Every retriever is registered first, then every indexer, in the order of
 * `params`.
 *
 * @param params One entry per collection: optional client settings, the
 *   collection name, and an optional client-side embedder and its options.
 * @returns The Genkit plugin named `astradb`.
 */
export function astraDB<EmbedderCustomOptions extends z.ZodTypeAny>(
  params: {
    clientParams?: AstraDBClientOptions;
    collectionName: string;
    embedder?: EmbedderArgument<EmbedderCustomOptions>;
    embedderOptions?: z.infer<EmbedderCustomOptions>;
  }[]
): GenkitPlugin {
  return genkitPlugin(PLUGIN_NAME, async (ai: Genkit) => {
    for (const collectionConfig of params) {
      configureAstraDBRetriever(ai, collectionConfig);
    }
    for (const collectionConfig of params) {
      configureAstraDBIndexer(ai, collectionConfig);
    }
  });
}
/**
 * Registers a retriever named `astradb/<collectionName>` that runs a vector
 * search against an Astra DB collection.
 *
 * How the query vector is chosen:
 * - If `embedder` is set, the query is embedded client-side and the first
 *   embedding is used as a `$vector` sort.
 * - Otherwise, or if the embedder returns no embeddings, the query text is
 *   sent as a `$vectorize` sort.
 *
 * Each matching record's `text` and `metadata` fields become a document with
 * a single text part. These are the same fields `configureAstraDBIndexer`
 * writes.
 *
 * @param ai Genkit instance to register the retriever on.
 * @param params.clientParams Connection settings. When omitted they come from
 *   environment variables; the keyspace then defaults to `default_keyspace`.
 * @param params.collectionName Collection to query.
 * @param params.embedder Optional client-side embedder for the query.
 * @param params.embedderOptions Options forwarded to `embedder`.
 * @returns The retriever action from `ai.defineRetriever`.
 * @throws GenkitError If `clientParams` is omitted and the required
 *   environment variables are missing.
 *
 * Per-call options: `filter` defaults to `{}` and `k` (result limit)
 * defaults to 5.
 */
export function configureAstraDBRetriever<
  Schema extends SomeDoc,
  EmbedderCustomOptions extends z.ZodTypeAny
>(
  ai: Genkit,
  params: {
    clientParams?: AstraDBClientOptions;
    collectionName: string;
    embedder?: EmbedderArgument<EmbedderCustomOptions>;
    embedderOptions?: z.infer<EmbedderCustomOptions>;
  }
) {
  const { clientParams, collectionName, embedder, embedderOptions } = params;
  const { applicationToken, apiEndpoint } = clientParams ?? getDefaultConfig();
  const keyspace = clientParams?.keyspace ?? DEFAULT_KEYSPACE;

  // The client is built once, when the retriever is configured, and reused
  // for every query.
  const collection = new DataAPIClient(applicationToken)
    .db(apiEndpoint, { keyspace })
    .collection<Schema>(collectionName);

  return ai.defineRetriever(
    {
      name: `${PLUGIN_NAME}/${collectionName}`,
      configSchema: createAstraDBRetrieverOptionsSchema<Schema>().optional(),
    },
    async (content, options) => {
      const queryEmbeddings: Embedding[] = embedder
        ? await ai.embed({ embedder, content, options: embedderOptions })
        : [];

      // Prefer the client-side vector; otherwise let Astra DB embed the text.
      const sort: Sort =
        queryEmbeddings.length === 0
          ? { $vectorize: content.text }
          : { $vector: queryEmbeddings[0].embedding };

      const matches = await collection
        .find(options?.filter ?? {}, { sort, limit: options?.k ?? 5 })
        .toArray();

      return {
        documents: matches.map(({ text, metadata }) => ({
          content: [{ text }],
          metadata,
        })),
      };
    }
  );
}
/**
 * Registers an indexer named `astradb/<collectionName>` that writes Genkit
 * documents into an Astra DB collection.
 *
 * How documents are stored:
 * - With `embedder`: each document is embedded client-side. One record is
 *   written per returned embedding, with the vector in `$vector`, the
 *   embedding document's `data` in `text`, its `dataType` in `contentType`,
 *   and its `metadata`.
 * - Without `embedder`: one record per document, with the document text in
 *   both `text` and `$vectorize`. Presumably the collection must be set up to
 *   embed `$vectorize` server-side; confirm against collection configuration.
 *
 * Record `_id`s are MD5 hashes of the JSON-serialized (embedding) document,
 * so they depend only on content. NOTE(review): records are written with
 * `insertMany`, so re-indexing identical content reuses the same `_id`.
 * Confirm whether the Data API should reject these duplicates or whether an
 * upsert is intended.
 *
 * @param ai Genkit instance to register the indexer on.
 * @param params.clientParams Connection settings. When omitted they come from
 *   environment variables; the keyspace then defaults to `default_keyspace`.
 * @param params.collectionName Collection to write to.
 * @param params.embedder Optional client-side embedder.
 * @param params.embedderOptions Options forwarded to `embedder`.
 * @returns The indexer action from `ai.defineIndexer`.
 * @throws GenkitError If `clientParams` is omitted and the required
 *   environment variables are missing.
 */
export function configureAstraDBIndexer<
  EmbedderCustomOptions extends z.ZodTypeAny
>(
  ai: Genkit,
  params: {
    clientParams?: AstraDBClientOptions;
    collectionName: string;
    embedder?: EmbedderArgument<EmbedderCustomOptions>;
    embedderOptions?: z.infer<EmbedderCustomOptions>;
  }
) {
  // Destructure directly; copying `params` with a spread first served no purpose.
  const { clientParams, collectionName, embedder, embedderOptions } = params;
  const { applicationToken, apiEndpoint } =
    clientParams ?? getDefaultConfig();
  const keyspace = clientParams?.keyspace ?? DEFAULT_KEYSPACE;
  const client = new DataAPIClient(applicationToken);
  const db = client.db(apiEndpoint, { keyspace });
  const collection = db.collection(collectionName);
  return ai.defineIndexer(
    {
      name: `${PLUGIN_NAME}/${collectionName}`,
      configSchema: AstraDBIndexerOptionsSchema,
    },
    async (docs) => {
      // Nothing to index: skip the embedder calls and the Data API round trip.
      if (docs.length === 0) {
        return;
      }
      let documents;
      if (embedder) {
        // Embed all documents concurrently; result i belongs to docs[i].
        const embeddings = await Promise.all(
          docs.map((doc) =>
            ai.embed({
              embedder,
              content: doc,
              options: embedderOptions,
            })
          )
        );
        documents = embeddings.flatMap((docEmbeddings: Embedding[], i) => {
          // Create one doc per docEmbedding so we can store them 1:1.
          // They should be unique because the embedding metadata is
          // added to the new docs.
          const embeddingDocs = docs[i].getEmbeddingDocuments(docEmbeddings);
          return docEmbeddings.map((docEmbedding, j) => {
            const embeddingDoc = embeddingDocs[j];
            return {
              _id: Md5.hashStr(JSON.stringify(embeddingDoc)),
              text: embeddingDoc.data,
              $vector: docEmbedding.embedding,
              metadata: embeddingDoc.metadata,
              contentType: embeddingDoc.dataType,
            };
          });
        });
      } else {
        // No client-side embedder: send the text as `$vectorize`.
        documents = docs.map((doc) => ({
          _id: Md5.hashStr(JSON.stringify(doc)),
          text: doc.text,
          $vectorize: doc.text,
          metadata: doc.metadata,
        }));
      }
      await collection.insertMany(documents);
    }
  );
}
/**
 * Reads Astra DB connection settings from environment variables.
 *
 * `ASTRA_DB_APPLICATION_TOKEN` is checked first, then
 * `ASTRA_DB_API_ENDPOINT`. An empty string counts as missing. The returned
 * options never include `keyspace`.
 *
 * @returns The application token and API endpoint.
 * @throws GenkitError With status `INVALID_ARGUMENT` for the first missing
 *   variable.
 */
function getDefaultConfig(): AstraDBClientOptions {
  const {
    ASTRA_DB_APPLICATION_TOKEN: applicationToken,
    ASTRA_DB_API_ENDPOINT: apiEndpoint,
  } = process.env;

  const invalidArgument = (message: string) =>
    new GenkitError({
      status: "INVALID_ARGUMENT",
      message,
      source: PLUGIN_NAME,
    });

  if (!applicationToken) {
    throw invalidArgument(
      "Please pass in the API key or set ASTRA_DB_APPLICATION_TOKEN environment variable.\n" +
        "For more details see https://firebase.google.com/docs/genkit/plugins/astraDB"
    );
  }
  if (!apiEndpoint) {
    throw invalidArgument(
      "Please pass in the Astra DB API endpoint or set ASTRA_DB_API_ENDPOINT environment variable.\n" +
        "For more details see https://firebase.google.com/docs/genkit/plugins/astraDB"
    );
  }

  return { applicationToken, apiEndpoint };
}