dtamind-components
Version:
Apps integration for Dtamind. Contains nodes and credentials.
238 lines (203 loc) • 9.06 kB
text/typescript
import { MongoClient, type Document as MongoDBDocument } from 'mongodb'
import { MaxMarginalRelevanceSearchOptions, VectorStore } from '@langchain/core/vectorstores'
import type { EmbeddingsInterface } from '@langchain/core/embeddings'
import { chunkArray } from '@langchain/core/utils/chunk_array'
import { Document } from '@langchain/core/documents'
import { maximalMarginalRelevance } from '@langchain/core/utils/math'
import { AsyncCaller, AsyncCallerParams } from '@langchain/core/utils/async_caller'
import { getVersion } from '../../../src/utils'
/**
 * Options accepted by the `MongoDBAtlasVectorSearch` constructor.
 *
 * The inherited `AsyncCallerParams` are handed to the store's `AsyncCaller`
 * (and to the `VectorStore` base constructor).
 */
export interface MongoDBAtlasVectorSearchLibArgs extends AsyncCallerParams {
    /** Location of the vectors: connection string plus target database and collection. */
    readonly connectionDetails: Readonly<{
        mongoDBConnectUrl: string
        databaseName: string
        collectionName: string
    }>
    /** Name of the Atlas Vector Search index to query (optional). */
    readonly indexName?: string
    /** Field that stores each document's page content (optional). */
    readonly textKey?: string
    /** Field that stores each document's embedding vector (optional). */
    readonly embeddingKey?: string
    /** Field used to match caller-supplied ids (optional). */
    readonly primaryKey?: string
}
/**
 * Filter accepted by the search methods of `MongoDBAtlasVectorSearch`.
 *
 * Depending on whether the option keys below are used, the object is interpreted either
 * as a set of search options or as a raw `$vectorSearch` pre-filter document.
 */
type MongoDBAtlasFilter = MongoDBDocument & {
    /** Filter document placed in `$vectorSearch.filter`. */
    preFilter?: MongoDBDocument
    /** Aggregation stages appended after the vector search and scoring stages. */
    postFilterPipeline?: MongoDBDocument[]
    /** Keep the embedding field in the returned documents instead of projecting it out. */
    includeEmbeddings?: boolean
}
/**
 * LangChain `VectorStore` backed by a MongoDB Atlas collection and queried through the
 * `$vectorSearch` aggregation stage.
 *
 * Every operation creates its own `MongoClient` via `getClient` and closes it in a
 * `finally` block, so a failing MongoDB call does not leave the client open.
 */
export class MongoDBAtlasVectorSearch extends VectorStore {
    declare FilterType: MongoDBAtlasFilter

    private readonly connectionDetails: {
        readonly mongoDBConnectUrl: string
        readonly databaseName: string
        readonly collectionName: string
    }

    /** Atlas Vector Search index queried by `$vectorSearch` (default `'default'`). */
    private readonly indexName: string

    /** Document field holding the page content (default `'text'`). */
    private readonly textKey: string

    /** Document field holding the embedding vector (default `'embedding'`). */
    private readonly embeddingKey: string

    /** Field matched against caller-supplied ids by `addVectors` and `delete` (default `'_id'`). */
    private readonly primaryKey: string

    /** Wraps each per-record upsert in `addVectors`; configured from the constructor args. */
    private caller: AsyncCaller

    _vectorstoreType(): string {
        return 'mongodb_atlas'
    }

    constructor(embeddings: EmbeddingsInterface, args: MongoDBAtlasVectorSearchLibArgs) {
        super(embeddings, args)
        this.connectionDetails = args.connectionDetails
        this.indexName = args.indexName ?? 'default'
        this.textKey = args.textKey ?? 'text'
        this.embeddingKey = args.embeddingKey ?? 'embedding'
        this.primaryKey = args.primaryKey ?? '_id'
        this.caller = new AsyncCaller(args)
    }

    /**
     * Creates a new `MongoClient` for the configured URL, reporting Dtamind and its version
     * as driver metadata. The caller owns the client and must close it (see `closeConnection`).
     */
    async getClient(): Promise<MongoClient> {
        const { version } = await getVersion()
        return new MongoClient(this.connectionDetails.mongoDBConnectUrl, { driverInfo: { name: 'Dtamind', version } })
    }

    /** Closes a client obtained from `getClient`. */
    async closeConnection(client: MongoClient): Promise<void> {
        await client.close()
    }

    /** Resolves the configured database/collection on `client`. */
    private getCollection(client: MongoClient) {
        return client.db(this.connectionDetails.databaseName).collection(this.connectionDetails.collectionName)
    }

    /**
     * Returns true when `filter` uses any of the option keys (`preFilter`, `postFilterPipeline`,
     * `includeEmbeddings`). Otherwise the whole object is treated as a raw `$vectorSearch` pre-filter.
     *
     * Presence is tested with `!== undefined` rather than truthiness. Otherwise an options object
     * such as `{ includeEmbeddings: false }` would be sent to MongoDB as a field filter.
     */
    private static isStructuredFilter(filter: MongoDBAtlasFilter): boolean {
        return filter.preFilter !== undefined || filter.postFilterPipeline !== undefined || filter.includeEmbeddings !== undefined
    }

    /**
     * Stores pre-computed embeddings together with their documents.
     *
     * Each record is `{ [textKey]: pageContent, [embeddingKey]: vector, ...metadata }`.
     * Without `options.ids` the records are inserted with one `insertMany`. With ids, each
     * record is upserted by `primaryKey` through the `AsyncCaller`.
     *
     * @param vectors - One embedding per document, in the same order as `documents`.
     * @param documents - Documents whose text and metadata are stored with the vectors.
     * @param options - Optional `ids`; when given, must have the same length as `vectors`.
     * @returns `options.ids` when given. Otherwise each record's `primaryKey` field as read back
     *   from the in-memory record after `insertMany`.
     * @throws Error if `options.ids` is given with a length different from `vectors`.
     */
    async addVectors(vectors: number[][], documents: Document[], options?: { ids?: string[] }) {
        // Validate before creating a client so a bad call cannot leave a client behind.
        if (options?.ids !== undefined && options.ids.length !== vectors.length) {
            throw new Error(`If provided, "options.ids" must be an array with the same length as "vectors".`)
        }
        // The driver rejects an empty insertMany batch; nothing to store means nothing to do.
        if (vectors.length === 0) {
            return []
        }

        const records = vectors.map((embedding, idx) => ({
            [this.textKey]: documents[idx].pageContent,
            [this.embeddingKey]: embedding,
            ...documents[idx].metadata
        }))

        const client = await this.getClient()
        try {
            const collection = this.getCollection(client)
            const ids = options?.ids
            if (ids === undefined) {
                await collection.insertMany(records)
            } else {
                for (let i = 0; i < records.length; i += 1) {
                    await this.caller.call(async () => {
                        await collection.updateOne(
                            { [this.primaryKey]: ids[i] },
                            { $set: { [this.primaryKey]: ids[i], ...records[i] } },
                            { upsert: true }
                        )
                    })
                }
            }
        } finally {
            await this.closeConnection(client)
        }
        return options?.ids ?? records.map((record) => record[this.primaryKey])
    }

    /** Embeds the documents' page content and stores it via `addVectors`. */
    async addDocuments(documents: Document[], options?: { ids?: string[] }) {
        const texts = documents.map(({ pageContent }) => pageContent)
        return this.addVectors(await this.embeddings.embedDocuments(texts), documents, options)
    }

    /**
     * Runs `$vectorSearch` for `query` and returns up to `k` documents with their scores.
     *
     * `filter` is either a raw pre-filter document or an options object (see `isStructuredFilter`):
     * - `preFilter` goes into `$vectorSearch.filter`;
     * - `postFilterPipeline` stages are appended after scoring;
     * - the embedding field is projected out unless `includeEmbeddings` is truthy.
     *
     * @returns `[document, score]` pairs. The score comes from `vectorSearchScore`, the page
     *   content from `textKey`, and all other returned fields become metadata.
     */
    async similaritySearchVectorWithScore(query: number[], k: number, filter?: MongoDBAtlasFilter): Promise<[Document, number][]> {
        const preFilter: MongoDBDocument | undefined =
            filter !== undefined && MongoDBAtlasVectorSearch.isStructuredFilter(filter) ? filter.preFilter : filter
        const postFilterPipeline = filter?.postFilterPipeline ?? []
        // Drop the embedding field from results unless the caller explicitly asked for it.
        const removeEmbeddingsPipeline: MongoDBDocument[] = filter?.includeEmbeddings ? [] : [{ $project: { [this.embeddingKey]: 0 } }]

        const pipeline: MongoDBDocument[] = [
            {
                $vectorSearch: {
                    queryVector: this.fixArrayPrecision(query),
                    index: this.indexName,
                    path: this.embeddingKey,
                    limit: k,
                    numCandidates: 10 * k,
                    ...(preFilter && { filter: preFilter })
                }
            },
            { $set: { score: { $meta: 'vectorSearchScore' } } },
            ...removeEmbeddingsPipeline,
            ...postFilterPipeline
        ]

        const client = await this.getClient()
        try {
            // `return await` so the cursor is fully drained before the client is closed.
            return await this.getCollection(client)
                .aggregate(pipeline)
                .map<[Document, number]>((result) => {
                    const { score, [this.textKey]: text, ...metadata } = result
                    return [new Document({ pageContent: text, metadata }), score]
                })
                .toArray()
        } finally {
            await this.closeConnection(client)
        }
    }

    /**
     * Maximal-marginal-relevance search. Fetches `fetchK` candidates together with their
     * embeddings, then picks `k` of them with `maximalMarginalRelevance`.
     *
     * Embeddings are always requested for the re-ranking step. They are stripped from the
     * returned metadata again unless the caller's filter set `includeEmbeddings`.
     */
    async maxMarginalRelevanceSearch(query: string, options: MaxMarginalRelevanceSearchOptions<this['FilterType']>): Promise<Document[]> {
        const { k, fetchK = 20, lambda = 0.5, filter } = options
        const queryEmbedding = await this.embeddings.embedQuery(query)

        // Remember what the caller asked for; MMR itself always needs the embeddings.
        const keepEmbeddings = Boolean(filter?.includeEmbeddings)
        // A raw filter must be wrapped as `preFilter`. Spreading it next to `includeEmbeddings`
        // would turn it into an options object without a pre-filter and silently drop it.
        const searchFilter: MongoDBAtlasFilter =
            filter !== undefined && MongoDBAtlasVectorSearch.isStructuredFilter(filter)
                ? { ...filter, includeEmbeddings: true }
                : { preFilter: filter, includeEmbeddings: true }

        // similaritySearchVectorWithScore applies fixArrayPrecision to the query itself.
        const candidates = await this.similaritySearchVectorWithScore(queryEmbedding, fetchK, searchFilter)
        const candidateEmbeddings = candidates.map(([doc]) => doc.metadata[this.embeddingKey])
        const selected = maximalMarginalRelevance(queryEmbedding, candidateEmbeddings, lambda, k)

        return selected.map((idx) => {
            const [doc] = candidates[idx]
            if (!keepEmbeddings) {
                delete doc.metadata[this.embeddingKey]
            }
            return doc
        })
    }

    /**
     * Deletes the records whose `primaryKey` field matches one of `params.ids`, issuing one
     * `deleteMany` per chunk of 50 ids.
     *
     * Matches on `primaryKey`, the same field `addVectors` upserts on, so stores configured
     * with a custom key can delete the records they inserted. With the default key this is `_id`.
     */
    async delete(params: { ids: any[] }): Promise<void> {
        if (params.ids.length === 0) {
            return
        }
        const CHUNK_SIZE = 50
        const client = await this.getClient()
        try {
            const collection = this.getCollection(client)
            const chunkIds: any[][] = chunkArray(params.ids, CHUNK_SIZE)
            for (const chunk of chunkIds) {
                await collection.deleteMany({ [this.primaryKey]: { $in: chunk } })
            }
        } finally {
            await this.closeConnection(client)
        }
    }

    /**
     * Wraps `texts` in `Document`s and stores them in a new store (see `fromDocuments`).
     *
     * @param metadatas - One metadata object per text, or a single object shared by all texts.
     * @param dbConfig - Store options; `dbConfig.ids`, if set, is forwarded to `addDocuments`.
     */
    static async fromTexts(
        texts: string[],
        metadatas: object[] | object,
        embeddings: EmbeddingsInterface,
        dbConfig: MongoDBAtlasVectorSearchLibArgs & { ids?: string[] }
    ): Promise<MongoDBAtlasVectorSearch> {
        const docs: Document[] = texts.map(
            (text, i) => new Document({ pageContent: text, metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas })
        )
        return MongoDBAtlasVectorSearch.fromDocuments(docs, embeddings, dbConfig)
    }

    /** Creates a store from `dbConfig` and adds `docs` to it, using `dbConfig.ids` if given. */
    static async fromDocuments(
        docs: Document[],
        embeddings: EmbeddingsInterface,
        dbConfig: MongoDBAtlasVectorSearchLibArgs & { ids?: string[] }
    ): Promise<MongoDBAtlasVectorSearch> {
        const instance = new this(embeddings, dbConfig)
        await instance.addDocuments(docs, { ids: dbConfig.ids })
        return instance
    }

    /**
     * Returns a copy of `array` in which every integer-valued entry is nudged by 1e-15.
     *
     * NOTE(review): presumably this makes the BSON serializer encode the query vector as
     * doubles rather than integers; confirm. For large magnitudes the offset is lost to
     * floating-point rounding and the value stays an integer.
     */
    fixArrayPrecision(array: number[]) {
        return array.map((value) => {
            if (Number.isInteger(value)) {
                return value + 0.000000000000001
            }
            return value
        })
    }
}