/*
 * @llamaindex/core — LlamaIndex Core Module
 * Version:
 * 92 lines (89 loc) • 2.86 kB
 * JavaScript
 */
import { randomUUID } from '@llamaindex/env';
import { Settings } from '../../global/dist/index.js';
import { PromptMixin } from '../../prompts/dist/index.js';
import { ObjectType, BaseNode } from '../../schema/dist/index.js';
/**
 * Base class for retrievers.
 *
 * `retrieve` calls `this._retrieve(queryBundle)`, which is not defined in this
 * class and is expected to be provided by subclasses. Results whose node is an
 * index node (`ObjectType.INDEX`) are expanded into the object registered for
 * that node's `indexId` in `objectMap`.
 */
class BaseRetriever extends PromptMixin {
  // Prompt hooks (presumably required by PromptMixin — verify): this base
  // class exposes no prompts and no prompt sub-modules.
  _updatePrompts() {}

  _getPrompts() {
    return {};
  }

  _getPromptModules() {
    return {};
  }

  constructor() {
    super();
    /**
     * Objects that index nodes resolve to during recursive retrieval,
     * keyed by the index node's `indexId`.
     * @type {Map}
     */
    this.objectMap = new Map();
  }

  /**
   * Retrieve nodes for a query, then expand index nodes via `objectMap`.
   *
   * Dispatches "retrieve-start" before retrieval and "retrieve-end" after it
   * on `Settings.callbackManager`, both tagged with the same random id.
   * NOTE: if `_retrieve` or the expansion step throws, no "retrieve-end"
   * event is dispatched.
   *
   * @param {string | object} params - A query string, or a query bundle
   *   object. A string is wrapped as `{ query: params }`.
   * @returns {Promise<Array<{ node: object, score: number }>>} The retrieved
   *   and expanded nodes with their scores.
   */
  async retrieve(params) {
    const cb = Settings.callbackManager;
    const queryBundle = typeof params === "string" ? { query: params } : params;
    const id = randomUUID();
    cb.dispatchEvent("retrieve-start", { id, query: queryBundle });
    const retrieved = await this._retrieve(queryBundle);
    const response = await this._handleRecursiveRetrieval(queryBundle, retrieved);
    cb.dispatchEvent("retrieve-end", { id, query: queryBundle, nodes: response });
    return response;
  }

  /**
   * Expand index nodes into the objects they reference.
   *
   * Each entry whose node has type `ObjectType.INDEX` and whose `indexId` has
   * an entry in `objectMap` is replaced by the result of
   * `_retrieveFromObject`. Every other entry is passed through unchanged.
   * A missing (undefined) score defaults to 1.0.
   *
   * @param {object} params - The query bundle, forwarded to `_retrieveFromObject`.
   * @param {Array<{ node: object, score?: number }>} nodes - Raw retrieval results.
   * @returns {Promise<Array<{ node: object, score: number }>>} The expanded results.
   * @throws {TypeError} If an object found in `objectMap` is not retrievable.
   */
  async _handleRecursiveRetrieval(params, nodes) {
    const retrievedNodes = [];
    for (const { node, score = 1.0 } of nodes) {
      // Only index nodes are looked up; everything else passes straight through.
      const object =
        node.type === ObjectType.INDEX ? this.objectMap.get(node.indexId) : undefined;
      if (object !== undefined) {
        retrievedNodes.push(...this._retrieveFromObject(object, params, score));
      } else {
        retrievedNodes.push({ node, score });
      }
    }
    // Return the expanded list (not the raw input) so index nodes are resolved.
    return retrievedNodes;
  }

  /**
   * Convert an object registered in `objectMap` into scored nodes.
   *
   * Supported shapes:
   *  - an object with a `node` property that is a `BaseNode`. If it also has
   *    its own numeric `score`, that score replaces the inherited one.
   *  - a `BaseNode` itself.
   *
   * @param {unknown} object - The mapped object.
   * @param {object} queryBundle - The query (not used by the supported shapes).
   * @param {number} score - Score inherited from the index node.
   * @returns {Array<{ node: BaseNode, score: number }>} A single-element list.
   * @throws {TypeError} If the object is null/undefined, not an object, or of
   *   an unsupported shape.
   */
  _retrieveFromObject(object, queryBundle, score) {
    if (object == null || typeof object !== "object") {
      throw new TypeError("Object is not retrievable");
    }
    if ("node" in object && object.node instanceof BaseNode) {
      const ownScore =
        "score" in object && typeof object.score === "number" ? object.score : score;
      return [{ node: object.node, score: ownScore }];
    }
    if (object instanceof BaseNode) {
      return [{ node: object, score }];
    }
    // TODO: support other retrievable types (BaseQueryEngine, BaseRetriever,
    // QueryComponent).
    throw new TypeError("Object is not retrievable");
  }
}
export { BaseRetriever };