@caleblawson/rag
Version:
The Retrieval-Augmented Generation (RAG) module contains document processing and embedding utilities.
112 lines (103 loc) • 3.72 kB
text/typescript
import type { MastraLanguageModel } from '@mastra/core/agent';
import { defaultKeywordExtractPrompt, PromptTemplate } from '../prompts';
import type { KeywordExtractPrompt } from '../prompts';
import type { BaseNode } from '../schema';
import { TextNode } from '../schema';
import { BaseExtractor } from './base';
import { baseLLM } from './types';
import type { KeywordExtractArgs } from './types';
type ExtractKeyword = {
/**
* Comma-separated keywords extracted from the node. May be empty if extraction fails.
*/
excerptKeywords: string;
};
/**
* Extract keywords from a list of nodes.
*/
export class KeywordExtractor extends BaseExtractor {
llm: MastraLanguageModel;
keywords: number = 5;
promptTemplate: KeywordExtractPrompt;
/**
* Constructor for the KeywordExtractor class.
* @param {MastraLanguageModel} llm MastraLanguageModel instance.
* @param {number} keywords Number of keywords to extract.
* @param {string} [promptTemplate] Optional custom prompt template (must include {context})
* @throws {Error} If keywords is less than 1.
*/
constructor(options?: KeywordExtractArgs) {
if (options?.keywords && options.keywords < 1) throw new Error('Keywords must be greater than 0');
super();
this.llm = options?.llm ?? baseLLM;
this.keywords = options?.keywords ?? 5;
this.promptTemplate = options?.promptTemplate
? new PromptTemplate({
templateVars: ['context', 'maxKeywords'],
template: options.promptTemplate,
})
: defaultKeywordExtractPrompt;
}
/**
*
* @param node Node to extract keywords from.
* @returns Keywords extracted from the node.
*/
/**
* Extract keywords from a node. Returns an object with a comma-separated string of keywords, or an empty string if extraction fails.
* Adds error handling for malformed/empty LLM output.
*/
async extractKeywordsFromNodes(node: BaseNode): Promise<ExtractKeyword> {
const text = node.getContent();
if (!text || text.trim() === '') {
return { excerptKeywords: '' };
}
if (this.isTextNodeOnly && !(node instanceof TextNode)) {
return { excerptKeywords: '' };
}
let keywords = '';
try {
const completion = await this.llm.doGenerate({
inputFormat: 'messages',
mode: { type: 'regular' },
prompt: [
{
role: 'user',
content: [
{
type: 'text',
text: this.promptTemplate.format({
context: node.getContent(),
maxKeywords: this.keywords.toString(),
}),
},
],
},
],
});
if (typeof completion.text === 'string') {
keywords = completion.text.trim();
} else {
console.warn('Keyword extraction LLM output was not a string:', completion.text);
}
} catch (err) {
console.warn('Keyword extraction failed:', err);
}
return { excerptKeywords: keywords };
}
/**
*
* @param nodes Nodes to extract keywords from.
* @returns Keywords extracted from the nodes.
*/
/**
* Extract keywords from an array of nodes. Always returns an array (may be empty).
* @param nodes Nodes to extract keywords from.
* @returns Array of keyword extraction results.
*/
async extract(nodes: BaseNode[]): Promise<Array<ExtractKeyword>> {
if (!Array.isArray(nodes) || nodes.length === 0) return [];
const results = await Promise.all(nodes.map(node => this.extractKeywordsFromNodes(node)));
return results;
}
}