UNPKG

@astermind/astermind-premium

Version:

Astermind Premium - Premium ML Toolkit

180 lines 6.65 kB
// hierarchical-elm.ts — Hierarchical ELM for tree-structured classification
// Coarse-to-fine classification with hierarchical decision making
import { ELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
 * Hierarchical ELM for tree-structured classification
 * Features:
 * - Coarse-to-fine classification
 * - Tree-structured decision making
 * - Multi-level probability estimation
 * - Efficient hierarchical search
 *
 * One ELM is kept per decision point: the ELM stored under `'root'` chooses
 * among `rootCategories`, and every parent label in `hierarchy` gets its own
 * ELM that chooses among that parent's children.
 */
export class HierarchicalELM {
    /**
     * @param {object} options
     * @param {Object<string, string[]>} options.hierarchy - Parent label -> list of its child labels.
     * @param {string[]} options.rootCategories - Labels selectable at the top level.
     * @param {number} [options.hiddenUnits=256] - Hidden units for every per-level ELM.
     * @param {string} [options.activation='relu'] - Activation for every per-level ELM.
     * @param {number} [options.maxLen=100] - Forwarded to every per-level ELM as `maxLen`.
     * @param {boolean} [options.useTokenizer=true] - When falsy, the per-level ELMs receive `useTokenizer: undefined`.
     */
    constructor(options) {
        this.elms = new Map();
        this.trained = false;
        requireLicense(); // Premium feature - requires valid license
        this.hierarchy = new Map(Object.entries(options.hierarchy));
        this.rootCategories = options.rootCategories;
        this.options = {
            hiddenUnits: options.hiddenUnits ?? 256,
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
        };
        // Initialize ELM for each level
        this._initializeELMs();
    }
    /**
     * Initialize ELMs for each level in hierarchy: one under the key `'root'`
     * for the root categories, and one per parent label for its children.
     */
    _initializeELMs() {
        // Root level ELM
        this.elms.set('root', this._createELM(this.rootCategories));
        // Child level ELMs.
        // NOTE(review): a hierarchy key literally named 'root' overwrites the
        // top-level ELM here because both share the same map key.
        for (const [parent, children] of this.hierarchy.entries()) {
            this.elms.set(parent, this._createELM(children));
        }
    }
    /**
     * Build one ELM for the given label set using the shared options.
     *
     * @param {string[]} categories - Labels this ELM chooses between.
     * @returns {ELM} A new, untrained ELM.
     */
    _createELM(categories) {
        return new ELM({
            useTokenizer: this.options.useTokenizer ? true : undefined,
            hiddenUnits: this.options.hiddenUnits,
            categories,
            maxLen: this.options.maxLen,
            activation: this.options.activation,
        });
    }
    /**
     * Train hierarchical ELM.
     *
     * The root ELM is trained on every sample whose first label is one of
     * `rootCategories`; samples with an unknown root label are skipped for the
     * root level. Parent P's ELM is trained on every sample whose path contains
     * P (first occurrence) immediately followed by one of P's children.
     * Levels that receive no samples are left untouched.
     *
     * @param {Array} X - Input samples, one per entry of `yLabels`; each is passed
     *   unchanged to the per-level ELM's `trainFromData`.
     * @param {string[][]} yLabels - Label path per sample, ordered coarse to fine and
     *   starting with a root category (e.g. `['animal', 'mammal', 'dog']`).
     * @throws {Error} If `X` and `yLabels` differ in length.
     */
    train(X, yLabels) {
        if (X.length !== yLabels.length) {
            throw new Error(`HierarchicalELM.train: X and yLabels must have the same length (got ${X.length} and ${yLabels.length})`);
        }
        // Per-level training sets, keyed like `this.elms` ('root' or a parent label).
        const levelData = new Map();
        // Root level: class index is the position of the path's first label.
        const rootX = [];
        const rootY = [];
        for (let i = 0; i < X.length; i++) {
            const path = yLabels[i];
            if (path.length === 0) {
                continue;
            }
            const rootIdx = this.rootCategories.indexOf(path[0]);
            // indexOf yields -1 for an unknown label; skip the sample rather than
            // hand the ELM an invalid class index (mirrors the child-level check).
            if (rootIdx >= 0) {
                rootX.push(X[i]);
                rootY.push(rootIdx);
            }
        }
        levelData.set('root', { X: rootX, y: rootY });
        // Child levels
        for (const [parent, children] of this.hierarchy.entries()) {
            const parentX = [];
            const parentY = [];
            for (let i = 0; i < X.length; i++) {
                const path = yLabels[i];
                const parentIdx = path.indexOf(parent);
                // Parent must be present and must not be the last label of the path.
                if (parentIdx < 0 || parentIdx >= path.length - 1) {
                    continue;
                }
                const childIdx = children.indexOf(path[parentIdx + 1]);
                if (childIdx >= 0) {
                    parentX.push(X[i]);
                    parentY.push(childIdx);
                }
            }
            if (parentX.length > 0) {
                levelData.set(parent, { X: parentX, y: parentY });
            }
        }
        // Train each ELM that received at least one sample
        for (const [level, data] of levelData.entries()) {
            const elm = this.elms.get(level);
            if (elm && data.X.length > 0) {
                elm.setCategories?.(level === 'root' ? this.rootCategories : this.hierarchy.get(level) || []);
                elm.trainFromData?.(data.X, data.y);
            }
        }
        this.trained = true;
    }
    /**
     * Predict with hierarchical model.
     *
     * @param {Array} x - A single sample, or an array of samples. Batch mode is
     *   detected by `Array.isArray(x[0])`, so a single sample whose first element
     *   is itself an array is treated as a batch.
     * @param {number} [topK=3] - Maximum number of results kept per sample.
     * @returns {Array<{path: string[], label: string, prob: number, levelProbs: number[]}>}
     *   Results for all samples concatenated into one flat array: up to `topK`
     *   per sample, each sample's results sorted by descending `prob`.
     * @throws {Error} If called before `train`.
     */
    predict(x, topK = 3) {
        if (!this.trained) {
            throw new Error('Model must be trained before prediction');
        }
        const XArray = Array.isArray(x[0]) ? x : [x];
        const allResults = [];
        for (const xi of XArray) {
            allResults.push(...this._predictHierarchical(xi, topK));
        }
        return allResults;
    }
    /**
     * Hierarchical prediction from root to leaf.
     *
     * Starting from the top-`topK` root predictions, descends through every
     * level that has an ELM, exploring the top-`topK` children at each node.
     * A path's `prob` is the product of its per-level probabilities. Up to
     * `topK ** depth` candidate paths are built before the final top-`topK`
     * selection.
     *
     * @param {*} x - Single input sample.
     * @param {number} topK - Predictions explored per level and results returned.
     * @returns {Array<{path: string[], label: string, prob: number, levelProbs: number[]}>}
     */
    _predictHierarchical(x, topK) {
        const rootELM = this.elms.get('root');
        const rootPred = rootELM.predictFromVector?.([x], topK) || [];
        const allPaths = [];
        for (const rootPredItem of rootPred.slice(0, topK)) {
            const rootLabel = rootPredItem.label || this.rootCategories[rootPredItem.index || 0];
            const rootProb = rootPredItem.prob || 0;
            this._expandPath(x, topK, [rootLabel], [rootProb], allPaths);
        }
        // Sort by probability and return top-K
        allPaths.sort((a, b) => b.prob - a.prob);
        return allPaths.slice(0, topK);
    }
    /**
     * Depth-first expansion of one partial label path.
     *
     * Appends a finished result to `out` when the path's last label cannot be
     * refined further. That happens when the label has no children in
     * `hierarchy`, has no ELM, already occurs earlier in the path (cycle guard),
     * or its ELM returns no predictions. Otherwise recurses into each of the
     * top-`topK` child predictions.
     *
     * @param {*} x - Single input sample.
     * @param {number} topK - Child predictions explored per node.
     * @param {string[]} path - Labels chosen so far, coarse to fine.
     * @param {number[]} levelProbs - Probability reported at each level of `path`.
     * @param {Array<object>} out - Accumulator for finished results.
     */
    _expandPath(x, topK, path, levelProbs, out) {
        const label = path[path.length - 1];
        const children = this.hierarchy.get(label);
        const childELM = this.elms.get(label);
        // A label repeated along the path would otherwise recurse forever.
        const isCycle = path.indexOf(label) !== path.length - 1;
        const canRefine = !isCycle && Boolean(childELM) && Array.isArray(children) && children.length > 0;
        const candidates = canRefine
            ? (childELM.predictFromVector?.([x], topK) || []).slice(0, topK)
            : [];
        if (candidates.length === 0) {
            // Leaf, or nothing to refine with: keep the path as it stands.
            out.push({
                path,
                label,
                prob: levelProbs.reduce((acc, p) => acc * p, 1),
                levelProbs,
            });
            return;
        }
        for (const childPredItem of candidates) {
            const childLabel = childPredItem.label || children[childPredItem.index || 0];
            const childProb = childPredItem.prob || 0;
            this._expandPath(x, topK, [...path, childLabel], [...levelProbs, childProb], out);
        }
    }
    /**
     * Get hierarchy structure.
     *
     * @returns {Map<string, string[]>} A new Map; the child arrays are shared, not copied.
     */
    getHierarchy() {
        return new Map(this.hierarchy);
    }
    /**
     * Get root categories.
     *
     * @returns {string[]} A shallow copy of the root category list.
     */
    getRootCategories() {
        return [...this.rootCategories];
    }
}
//# sourceMappingURL=hierarchical-elm.js.map