UNPKG

@llamaindex/core

Version:
92 lines (89 loc) 2.86 kB
import { randomUUID } from '@llamaindex/env';
import { Settings } from '../../global/dist/index.js';
import { PromptMixin } from '../../prompts/dist/index.js';
import { ObjectType, BaseNode } from '../../schema/dist/index.js';

/**
 * Base class for retrievers.
 *
 * `retrieve` normalizes the query, emits "retrieve-start" / "retrieve-end"
 * events on `Settings.callbackManager`, delegates the actual lookup to
 * `this._retrieve` (not defined here — subclasses are expected to provide it),
 * and then expands index nodes whose `indexId` is registered in `objectMap`.
 */
class BaseRetriever extends PromptMixin {
  constructor() {
    super();
    /**
     * Objects keyed by an index node's `indexId`. When a retrieved node of type
     * `ObjectType.INDEX` has an entry here, that entry is converted into
     * node/score results by `_retrieveFromObject`.
     */
    this.objectMap = new Map();
  }

  /** Retrievers expose no prompts, so there is nothing to update. */
  _updatePrompts() {}

  /** @returns {{}} An empty prompt dictionary. */
  _getPrompts() {
    return {};
  }

  /** @returns {{}} An empty prompt-module dictionary. */
  _getPromptModules() {
    return {};
  }

  /**
   * Retrieve nodes for a query.
   *
   * @param {string|object} params - A query string (wrapped as `{ query }`) or
   *   an already-built query bundle, which is passed through unchanged.
   * @returns {Promise<Array<{node: BaseNode, score?: number}>>} The nodes from
   *   `_retrieve`, with registered index nodes expanded.
   */
  async retrieve(params) {
    const cb = Settings.callbackManager;
    const queryBundle = typeof params === "string" ? { query: params } : params;
    // A single id ties the start and end events of this call together.
    const id = randomUUID();
    cb.dispatchEvent("retrieve-start", { id, query: queryBundle });
    const rawNodes = await this._retrieve(queryBundle);
    const response = await this._handleRecursiveRetrieval(queryBundle, rawNodes);
    cb.dispatchEvent("retrieve-end", { id, query: queryBundle, nodes: response });
    return response;
  }

  /**
   * Replace index nodes that have a registered object in `objectMap` with the
   * results of `_retrieveFromObject`. All other nodes are kept as
   * `{ node, score }`. A missing (undefined) score defaults to 1.0.
   *
   * @param {object} params - The query bundle, forwarded to `_retrieveFromObject`.
   * @param {Array<{node: BaseNode, score?: number}>} nodes - Raw retrieval results.
   * @returns {Promise<Array<{node: BaseNode, score: number}>>} The expanded results.
   */
  async _handleRecursiveRetrieval(params, nodes) {
    const retrievedNodes = [];
    for (const { node, score = 1.0 } of nodes) {
      // Only INDEX nodes are looked up; any other node passes straight through.
      const object =
        node.type === ObjectType.INDEX ? this.objectMap.get(node.indexId) : undefined;
      if (object !== undefined) {
        retrievedNodes.push(...this._retrieveFromObject(object, params, score));
      } else {
        retrievedNodes.push({ node, score });
      }
    }
    // Previously this returned `nodes`, which discarded all of the expansion above.
    return retrievedNodes;
  }

  /**
   * Convert an object registered in `objectMap` into node/score results.
   *
   * @param {unknown} object - A `BaseNode`, or an object with a `BaseNode`-valued
   *   `node` property (and optionally a numeric `score`).
   * @param {object} queryBundle - The current query. Unused for the currently
   *   supported object kinds; kept for future retrievable types.
   * @param {number} score - Score to use when the object carries no numeric score.
   * @returns {Array<{node: BaseNode, score: number}>} A one-element result list.
   * @throws {TypeError} If the object is not a supported retrievable kind.
   */
  _retrieveFromObject(object, queryBundle, score) {
    // Rejects null, undefined and every non-object primitive/function.
    if (object === null || typeof object !== "object") {
      throw new TypeError("Object is not retrievable");
    }
    // A node-with-score wrapper: prefer its own numeric score.
    if ("node" in object && object.node instanceof BaseNode) {
      const ownScore =
        "score" in object && typeof object.score === "number" ? object.score : score;
      return [{ node: object.node, score: ownScore }];
    }
    if (object instanceof BaseNode) {
      return [{ node: object, score }];
    }
    // TODO: support other retrievable types (BaseQueryEngine, BaseRetriever, QueryComponent).
    throw new TypeError("Object is not retrievable");
  }
}

export { BaseRetriever };