
@caleblawson/rag


The Retrieval-Augmented Generation (RAG) module contains document processing and embedding utilities. This file defines the SummaryExtractor, which uses a language model to generate section summaries for document nodes.

import type { MastraLanguageModel } from '@mastra/core/agent';

import { PromptTemplate, defaultSummaryPrompt } from '../prompts';
import type { SummaryPrompt } from '../prompts';
import type { BaseNode } from '../schema';
import { TextNode } from '../schema';
import { BaseExtractor } from './base';
import { baseLLM, STRIP_REGEX } from './types';
import type { SummaryExtractArgs } from './types';

type ExtractSummary = {
  sectionSummary?: string;
  prevSectionSummary?: string;
  nextSectionSummary?: string;
};

/**
 * Summarize an array of nodes using a custom LLM.
 *
 * @param nodes Array of node-like objects
 * @param options Summary extraction options
 * @returns Array of summary results
 */
export class SummaryExtractor extends BaseExtractor {
  private llm: MastraLanguageModel;
  summaries: string[];
  promptTemplate: SummaryPrompt;
  private selfSummary: boolean;
  private prevSummary: boolean;
  private nextSummary: boolean;

  constructor(options?: SummaryExtractArgs) {
    const summaries = options?.summaries ?? ['self'];

    if (summaries && !summaries.some(s => ['self', 'prev', 'next'].includes(s)))
      throw new Error("Summaries must be one of 'self', 'prev', 'next'");

    super();

    this.llm = options?.llm ?? baseLLM;
    this.summaries = summaries;
    this.promptTemplate = options?.promptTemplate
      ? new PromptTemplate({
          templateVars: ['context'],
          template: options.promptTemplate,
        })
      : defaultSummaryPrompt;
    this.selfSummary = summaries?.includes('self') ?? false;
    this.prevSummary = summaries?.includes('prev') ?? false;
    this.nextSummary = summaries?.includes('next') ?? false;
  }

  /**
   * Extract summary from a node.
   * @param {BaseNode} node Node to extract summary from.
   * @returns {Promise<string>} Summary extracted from the node.
   */
  async generateNodeSummary(node: BaseNode): Promise<string> {
    const text = node.getContent();
    if (!text || text.trim() === '') {
      return '';
    }

    if (this.isTextNodeOnly && !(node instanceof TextNode)) {
      return '';
    }

    const context = node.getContent();

    const prompt = this.promptTemplate.format({
      context,
    });

    const result = await this.llm.doGenerate({
      inputFormat: 'messages',
      mode: { type: 'regular' },
      prompt: [
        {
          role: 'user',
          content: [{ type: 'text', text: prompt }],
        },
      ],
    });

    let summary = '';
    if (typeof result.text === 'string') {
      summary = result.text.trim();
    } else {
      console.warn('Summary extraction LLM output was not a string:', result.text);
    }

    return summary.replace(STRIP_REGEX, '');
  }

  /**
   * Extract summaries from a list of nodes.
   * @param {BaseNode[]} nodes Nodes to extract summaries from.
   * @returns {Promise<ExtractSummary[]>} Summaries extracted from the nodes.
   */
  async extract(nodes: BaseNode[]): Promise<ExtractSummary[]> {
    if (!nodes.every(n => n instanceof TextNode)) throw new Error('Only `TextNode` is allowed for `Summary` extractor');

    const nodeSummaries = await Promise.all(nodes.map(node => this.generateNodeSummary(node)));

    const metadataList: ExtractSummary[] = nodes.map(() => ({}));

    for (let i = 0; i < nodes.length; i++) {
      if (i > 0 && this.prevSummary && nodeSummaries[i - 1]) {
        metadataList[i]!['prevSectionSummary'] = nodeSummaries[i - 1];
      }
      if (i < nodes.length - 1 && this.nextSummary && nodeSummaries[i + 1]) {
        metadataList[i]!['nextSectionSummary'] = nodeSummaries[i + 1];
      }
      if (this.selfSummary && nodeSummaries[i]) {
        metadataList[i]!['sectionSummary'] = nodeSummaries[i];
      }
    }

    return metadataList;
  }
}
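
The extract method returns one ExtractSummary record per input node, so callers can attach sectionSummary, prevSectionSummary, and nextSectionSummary to each node's metadata. The sketch below shows one way this class could be used; the import paths, the TextNode constructor shape, and the openai('gpt-4o-mini') model are assumptions for illustration and may not match the package's actual exports.

// Usage sketch. Hypothetical import paths; adjust to the package's real exports.
import { openai } from '@ai-sdk/openai';                       // assumed AI SDK model provider
import { SummaryExtractor, TextNode } from '@caleblawson/rag'; // assumed re-exports of this class and the schema node

async function summarizeChunks(chunks: string[]) {
  // Request the node's own summary plus summaries of its neighbors.
  const extractor = new SummaryExtractor({
    llm: openai('gpt-4o-mini'),
    summaries: ['self', 'prev', 'next'],
  });

  // Wrap each text chunk in a TextNode (constructor shape assumed).
  const nodes = chunks.map(text => new TextNode({ text }));

  // summaries[i] may contain sectionSummary, prevSectionSummary, nextSectionSummary.
  const summaries = await extractor.extract(nodes);

  return nodes.map((node, i) => ({ node, metadata: summaries[i] }));
}

Because all node summaries are generated concurrently with Promise.all, the prev/next entries are simply the already-computed summaries of the neighboring nodes, not separate LLM calls.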