@mastra/rag
The Retrieval-Augmented Generation (RAG) module contains document processing and embedding utilities. The MDocument class below wraps a document's text as chunks and exposes chunking strategies and metadata extraction over them.
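Before the source, a minimal usage sketch (run inside an async context; the input text and metadata are illustrative, and the package root is assumed to re-export MDocument):

import { MDocument } from '@mastra/rag';

const doc = MDocument.fromText('Some long plain-text content...', { source: 'example' });
const chunks = await doc.chunk(); // plain text falls back to the 'recursive' strategy
console.log(chunks.length, doc.getText()[0]);

The full source of the class follows.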
import { TitleExtractor, SummaryExtractor, QuestionsAnsweredExtractor, KeywordExtractor } from './extractors';
import type { BaseNode } from './schema';
import { Document as Chunk, NodeRelationship, ObjectType } from './schema';
import { CharacterTransformer, RecursiveCharacterTransformer } from './transformers/character';
import { HTMLHeaderTransformer, HTMLSectionTransformer } from './transformers/html';
import { RecursiveJsonTransformer } from './transformers/json';
import { LatexTransformer } from './transformers/latex';
import { MarkdownHeaderTransformer, MarkdownTransformer } from './transformers/markdown';
import { SemanticMarkdownTransformer } from './transformers/semantic-markdown';
import { SentenceTransformer } from './transformers/sentence';
import { TokenTransformer } from './transformers/token';
import type {
  ChunkParams,
  ChunkStrategy,
  ExtractParams,
  HTMLChunkOptions,
  RecursiveChunkOptions,
  CharacterChunkOptions,
  TokenChunkOptions,
  MarkdownChunkOptions,
  SemanticMarkdownChunkOptions,
  JsonChunkOptions,
  LatexChunkOptions,
  SentenceChunkOptions,
  StrategyOptions,
} from './types';
import { validateChunkParams } from './validation';

/**
 * Wraps a set of text chunks (plus a source type) and exposes chunking
 * strategies and metadata extraction over them.
 */
export class MDocument {
  private chunks: Chunk[];
  private type: string; // e.g., 'text', 'html', 'markdown', 'json'

  constructor({ docs, type }: { docs: { text: string; metadata?: Record<string, any> }[]; type: string }) {
    this.chunks = docs.map(d => {
      return new Chunk({ text: d.text, metadata: d.metadata });
    });
    this.type = type;
  }

  /**
   * Runs the requested extractors (title, summary, questions, keywords) over the
   * current chunks and merges the extracted metadata back into each chunk.
   * Passing a boolean uses that extractor's default options.
   */
  async extractMetadata({ title, summary, questions, keywords }: ExtractParams): Promise<MDocument> {
    const transformations = [];

    if (typeof summary !== 'undefined') {
      transformations.push(new SummaryExtractor(typeof summary === 'boolean' ? {} : summary));
    }

    if (typeof questions !== 'undefined') {
      transformations.push(new QuestionsAnsweredExtractor(typeof questions === 'boolean' ? {} : questions));
    }

    if (typeof keywords !== 'undefined') {
      transformations.push(new KeywordExtractor(typeof keywords === 'boolean' ? {} : keywords));
    }

    if (typeof title !== 'undefined') {
      transformations.push(new TitleExtractor(typeof title === 'boolean' ? {} : title));
      // Link chunks that carry a docId back to their source document before title extraction.
      this.chunks = this.chunks.map(doc =>
        doc?.metadata?.docId
          ? new Chunk({
              ...doc,
              relationships: {
                [NodeRelationship.SOURCE]: {
                  nodeId: doc.metadata.docId,
                  nodeType: ObjectType.DOCUMENT,
                  metadata: doc.metadata,
                },
              },
            })
          : doc,
      );
    }

    let nodes: BaseNode[] = this.chunks;
    for (const extractor of transformations) {
      nodes = await extractor.processNodes(nodes);
    }

    this.chunks = this.chunks.map((doc, i) => {
      return new Chunk({
        text: doc.text,
        metadata: {
          ...doc.metadata,
          ...(nodes?.[i]?.metadata || {}),
        },
      });
    });

    return this;
  }

  static fromText(text: string, metadata?: Record<string, any>): MDocument {
    return new MDocument({
      docs: [
        {
          text,
          metadata,
        },
      ],
      type: 'text',
    });
  }

  static fromHTML(html: string, metadata?: Record<string, any>): MDocument {
    return new MDocument({
      docs: [
        {
          text: html,
          metadata,
        },
      ],
      type: 'html',
    });
  }

  static fromMarkdown(markdown: string, metadata?: Record<string, any>): MDocument {
    return new MDocument({
      docs: [
        {
          text: markdown,
          metadata,
        },
      ],
      type: 'markdown',
    });
  }

  static fromJSON(jsonString: string, metadata?: Record<string, any>): MDocument {
    return new MDocument({
      docs: [
        {
          text: jsonString,
          metadata,
        },
      ],
      type: 'json',
    });
  }

  /** Chooses a chunking strategy from the document type when none is passed to chunk(). */
  private defaultStrategy(): ChunkStrategy {
    switch (this.type) {
      case 'html':
        return 'html';
      case 'markdown':
        return 'markdown';
      case 'json':
        return 'json';
      case 'latex':
        return 'latex';
      default:
        return 'recursive';
    }
  }

  private _strategyMap?: { [S in ChunkStrategy]: (options?: StrategyOptions[S]) => Promise<void> };

  /** Lazily builds the mapping from strategy name to the corresponding chunk method. */
  private get strategyMap() {
    if (!this._strategyMap) {
      this._strategyMap = {
        recursive: options => this.chunkRecursive(options),
        character: options => this.chunkCharacter(options),
        token: options => this.chunkToken(options),
        markdown: options => this.chunkMarkdown(options),
        html: options => this.chunkHTML(options),
        json: options => this.chunkJSON(options),
        latex: options => this.chunkLatex(options),
        sentence: options => this.chunkSentence(options),
        'semantic-markdown': options => this.chunkSemanticMarkdown(options),
      };
    }
    return this._strategyMap;
  }

  private async chunkBy<K extends ChunkStrategy>(strategy: K, options?: StrategyOptions[K]): Promise<void> {
    const chunkingFunc = this.strategyMap[strategy];
    if (chunkingFunc) {
      await chunkingFunc(options);
    } else {
      throw new Error(`Unknown strategy: ${strategy}`);
    }
  }

  async chunkRecursive(options?: RecursiveChunkOptions): Promise<void> {
    if (options?.language) {
      const rt = RecursiveCharacterTransformer.fromLanguage(options.language, options);
      const textSplit = rt.transformDocuments(this.chunks);
      this.chunks = textSplit;
      return;
    }

    const rt = new RecursiveCharacterTransformer(options);
    const textSplit = rt.transformDocuments(this.chunks);
    this.chunks = textSplit;
  }

  async chunkCharacter(options?: CharacterChunkOptions): Promise<void> {
    const rt = new CharacterTransformer({
      ...options,
      separator: options?.separator,
      isSeparatorRegex: options?.isSeparatorRegex,
    });
    const textSplit = rt.transformDocuments(this.chunks);
    this.chunks = textSplit;
  }

  async chunkHTML(options?: HTMLChunkOptions): Promise<void> {
    if (options?.headers?.length) {
      const rt = new HTMLHeaderTransformer(options as HTMLChunkOptions & { headers: [string, string][] });
      const textSplit = rt.transformDocuments(this.chunks);
      this.chunks = textSplit;
      return;
    }

    if (options?.sections?.length) {
      const rt = new HTMLSectionTransformer(options as HTMLChunkOptions & { sections: [string, string][] });
      const textSplit = rt.transformDocuments(this.chunks);
      this.chunks = textSplit;
      return;
    }

    throw new Error('HTML chunking requires either headers or sections to be specified');
  }

  async chunkJSON(options?: JsonChunkOptions): Promise<void> {
    if (!options?.maxSize) {
      throw new Error('JSON chunking requires maxSize to be specified');
    }

    const rt = new RecursiveJsonTransformer({
      maxSize: options?.maxSize,
      minSize: options?.minSize,
    });
    const textSplit = rt.transformDocuments({
      documents: this.chunks,
      ensureAscii: options?.ensureAscii,
      convertLists: options?.convertLists,
    });
    this.chunks = textSplit;
  }

  async chunkLatex(options?: LatexChunkOptions): Promise<void> {
    const rt = new LatexTransformer(options);
    const textSplit = rt.transformDocuments(this.chunks);
    this.chunks = textSplit;
  }

  async chunkToken(options?: TokenChunkOptions): Promise<void> {
    const rt = TokenTransformer.fromTikToken({
      options,
      encodingName: options?.encodingName,
      modelName: options?.modelName,
    });
    const textSplit = rt.transformDocuments(this.chunks);
    this.chunks = textSplit;
  }

  async chunkMarkdown(options?: MarkdownChunkOptions): Promise<void> {
    if (options?.headers) {
      const rt = new MarkdownHeaderTransformer(options.headers, options?.returnEachLine, options?.stripHeaders);
      const textSplit = rt.transformDocuments(this.chunks);
      this.chunks = textSplit;
      return;
    }

    const rt = new MarkdownTransformer(options);
    const textSplit = rt.transformDocuments(this.chunks);
    this.chunks = textSplit;
  }

  async chunkSentence(options?: SentenceChunkOptions): Promise<void> {
    if (!options?.maxSize) {
      throw new Error('Sentence chunking requires maxSize to be specified');
    }

    const rt = new SentenceTransformer({
      minSize: options?.minSize,
      maxSize: options?.maxSize,
      targetSize: options?.targetSize,
      overlap: options?.overlap,
      sentenceEnders: options?.sentenceEnders,
      fallbackToWords: options?.fallbackToWords,
      fallbackToCharacters: options?.fallbackToCharacters,
      keepSeparator: options?.keepSeparator,
      lengthFunction: options?.lengthFunction,
      addStartIndex: options?.addStartIndex,
      stripWhitespace: options?.stripWhitespace,
    });
    const textSplit = rt.transformDocuments(this.chunks);
    this.chunks = textSplit;
  }

  async chunkSemanticMarkdown(options?: SemanticMarkdownChunkOptions): Promise<void> {
    const rt = SemanticMarkdownTransformer.fromTikToken({
      options,
      encodingName: options?.encodingName,
      modelName: options?.modelName,
    });
    const textSplit = rt.transformDocuments(this.chunks);
    this.chunks = textSplit;
  }

  /**
   * Splits the document with the given strategy (or the type-based default),
   * optionally runs metadata extraction, and returns the resulting chunks.
   */
  async chunk(params?: ChunkParams): Promise<Chunk[]> {
    const { strategy: passedStrategy, extract, ...chunkOptions } = params || {};

    // Determine the default strategy based on type if not specified
    const strategy = passedStrategy || this.defaultStrategy();
    validateChunkParams(strategy, chunkOptions);

    // Apply the appropriate chunking strategy
    await this.chunkBy(strategy, chunkOptions);

    if (extract) {
      await this.extractMetadata(extract);
    }

    return this.chunks;
  }

  getDocs(): Chunk[] {
    return this.chunks;
  }

  getText(): string[] {
    return this.chunks.map(doc => doc.text);
  }

  getMetadata(): Record<string, any>[] {
    return this.chunks.map(doc => doc.metadata);
  }
}
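
A slightly fuller sketch of chunking plus metadata extraction (illustrative only; the markdown string and extractor flags are placeholder values, and the extractors' model configuration comes from ExtractParams in './types', which is not shown in this file):

import { MDocument } from '@mastra/rag';

const doc = MDocument.fromMarkdown('# Guide\n\n## Setup\n\nInstall the package...', { source: 'guide.md' });

// Markdown input defaults to the 'markdown' strategy; extract runs after chunking.
const chunks = await doc.chunk({
  strategy: 'markdown',
  extract: { title: true, keywords: true }, // boolean flags use each extractor's defaults
});

console.log(chunks.length, doc.getMetadata()[0]);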