@astermind/astermind-premium
Version:
Astermind Premium - Premium ML Toolkit
180 lines • 6.65 kB
JavaScript
// hierarchical-elm.ts — Hierarchical ELM for tree-structured classification
// Coarse-to-fine classification with hierarchical decision making
import { ELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
 * Hierarchical ELM for tree-structured classification.
 *
 * One ELM classifies among the root categories, and one additional ELM per
 * parent in `hierarchy` classifies among that parent's children. Prediction
 * runs coarse-to-fine: the root model proposes top-K labels, then each
 * candidate's child model is consulted, recursively, until a leaf (a label
 * with no children) or a level without a trained model is reached.
 *
 * Features:
 * - Coarse-to-fine classification
 * - Tree-structured decision making of arbitrary depth
 * - Multi-level probability estimation (a path's probability is the product
 *   of its per-level probabilities)
 * - Top-K search over hierarchical paths (explores up to topK^depth paths)
 *
 * NOTE(review): a parent label literally named 'root' shares its key in
 * `this.elms` with the root-level model and overwrites it.
 */
export class HierarchicalELM {
    /**
     * @param {object} options
     * @param {Object<string, string[]>|Map<string, string[]>} options.hierarchy
     *   Parent label -> ordered child labels. Plain objects and Maps are accepted.
     * @param {string[]} options.rootCategories Ordered top-level labels.
     * @param {number} [options.hiddenUnits=256] Hidden units for every level ELM.
     * @param {string} [options.activation='relu'] Activation for every level ELM.
     * @param {number} [options.maxLen=100] Forwarded to every ELM as `maxLen`.
     * @param {boolean} [options.useTokenizer=true] Forwarded to every ELM.
     * @throws {TypeError} If `options.rootCategories` is not an array.
     */
    constructor(options) {
        this.elms = new Map();
        this.trained = false;
        // Levels whose ELM has actually received training data. Prediction
        // never consults an ELM that is not in this set.
        this.trainedLevels = new Set();
        requireLicense(); // Premium feature - requires valid license
        if (!Array.isArray(options.rootCategories)) {
            throw new TypeError('HierarchicalELM: options.rootCategories must be an array');
        }
        const source = options.hierarchy ?? {};
        const entries = source instanceof Map ? source.entries() : Object.entries(source);
        // Copy child lists so later mutation by the caller cannot desynchronise
        // the hierarchy from the ELMs built from it.
        this.hierarchy = new Map(Array.from(entries, ([parent, children]) => [parent, [...children]]));
        this.rootCategories = [...options.rootCategories];
        this.options = {
            hiddenUnits: options.hiddenUnits ?? 256,
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
        };
        // Initialize ELM for each level
        this._initializeELMs();
    }
    /**
     * Create one ELM for the root level (key 'root') and one per parent in
     * the hierarchy (keyed by the parent label).
     */
    _initializeELMs() {
        this.elms.set('root', this._createELM(this.rootCategories));
        for (const [parent, children] of this.hierarchy) {
            this.elms.set(parent, this._createELM(children));
        }
    }
    /**
     * Build an ELM over `categories` using the shared per-level options.
     * @param {string[]} categories
     */
    _createELM(categories) {
        return new ELM({
            useTokenizer: this.options.useTokenizer ? true : undefined,
            hiddenUnits: this.options.hiddenUnits,
            categories,
            maxLen: this.options.maxLen,
            activation: this.options.activation,
        });
    }
    /**
     * Train hierarchical ELM.
     *
     * @param {Array} X Input feature rows, one per sample.
     * @param {string[][]} yLabels Full hierarchical path per sample, starting
     *   with a root category, e.g. ['animal', 'mammal', 'dog'].
     *   Samples whose first label is not a root category are ignored at the
     *   root level; a parent->child step whose child is not listed under that
     *   parent is ignored at that parent's level.
     * @throws {Error} If X and yLabels are not arrays of equal length.
     */
    train(X, yLabels) {
        if (!Array.isArray(X) || !Array.isArray(yLabels) || X.length !== yLabels.length) {
            throw new Error('HierarchicalELM.train: X and yLabels must be arrays of equal length');
        }
        const levelData = this._groupByLevel(X, yLabels);
        for (const [level, data] of levelData) {
            const elm = this.elms.get(level);
            if (!elm || data.X.length === 0 || typeof elm.trainFromData !== 'function') {
                continue;
            }
            elm.setCategories?.(level === 'root' ? this.rootCategories : this.hierarchy.get(level) ?? []);
            elm.trainFromData(data.X, data.y);
            this.trainedLevels.add(level);
        }
        this.trained = true;
    }
    /**
     * Split the training set into per-level datasets: 'root' gets every sample
     * with a known root label; each parent gets the samples whose path steps
     * from that parent to one of its listed children. Targets are indices into
     * the level's category list.
     * @returns {Map<string, {X: Array, y: number[]}>}
     */
    _groupByLevel(X, yLabels) {
        const levelData = new Map();
        const root = { X: [], y: [] };
        for (let i = 0; i < X.length; i++) {
            const path = yLabels[i];
            if (!Array.isArray(path) || path.length === 0) {
                continue;
            }
            const rootIdx = this.rootCategories.indexOf(path[0]);
            // Skip unknown root labels instead of encoding them as class -1.
            if (rootIdx < 0) {
                continue;
            }
            root.X.push(X[i]);
            root.y.push(rootIdx);
        }
        levelData.set('root', root);
        for (const [parent, children] of this.hierarchy) {
            const data = { X: [], y: [] };
            for (let i = 0; i < X.length; i++) {
                const path = yLabels[i];
                if (!Array.isArray(path)) {
                    continue;
                }
                const parentIdx = path.indexOf(parent);
                if (parentIdx < 0 || parentIdx >= path.length - 1) {
                    continue;
                }
                const childIdx = children.indexOf(path[parentIdx + 1]);
                if (childIdx < 0) {
                    continue;
                }
                data.X.push(X[i]);
                data.y.push(childIdx);
            }
            if (data.X.length > 0) {
                levelData.set(parent, data);
            }
        }
        return levelData;
    }
    /**
     * Predict with hierarchical model.
     *
     * @param {Array} x A single feature vector, or an array of feature rows.
     * @param {number} [topK=3] Candidates explored per level and paths kept
     *   per input.
     * @returns {{path: string[], label: string, prob: number, levelProbs: number[]}[]}
     *   Up to topK results per input row, sorted by path probability. For a
     *   batch, the per-row results are concatenated in input order.
     * @throws {Error} If the model (root level) has not been trained.
     */
    predict(x, topK = 3) {
        if (!this.trained) {
            throw new Error('Model must be trained before prediction');
        }
        if (x.length === 0) {
            return [];
        }
        const first = x[0];
        const rows = Array.isArray(first) || ArrayBuffer.isView(first) ? x : [x];
        return rows.flatMap((row) => this._predictHierarchical(row, topK));
    }
    /**
     * Hierarchical prediction from root to leaf for a single feature vector.
     * @returns Top-K complete paths sorted by descending probability.
     */
    _predictHierarchical(x, topK) {
        if (!this.trainedLevels.has('root')) {
            throw new Error('Model must be trained before prediction');
        }
        const paths = [];
        this._expandLevel(x, topK, 'root', this.rootCategories, { path: [], levelProbs: [], prob: 1 }, paths);
        // Array.prototype.sort is stable, so ties keep exploration order.
        paths.sort((a, b) => b.prob - a.prob);
        return paths.slice(0, topK);
    }
    /**
     * Classify `x` with the ELM for `level` and, for each of its top-K labels,
     * either recurse into that label's children or record a finished path.
     * Recursion stops at labels with no children, at labels whose ELM was never
     * trained, and at labels already on the current path (cycle guard).
     *
     * @param {string} level Key into `this.elms`.
     * @param {string[]} categories Labels of this level, used when a result
     *   carries only an index.
     * @param {{path: string[], levelProbs: number[], prob: number}} prefix
     *   Accumulated state from ancestor levels.
     * @param {Array} out Receives finished path objects.
     */
    _expandLevel(x, topK, level, categories, prefix, out) {
        for (const item of this._predictLevel(this.elms.get(level), x, topK)) {
            const label = item.label ?? categories[item.index ?? 0];
            // Treat missing or non-numeric probabilities as 0 so sorting stays well-defined.
            const levelProb = Number.isFinite(item.prob) ? item.prob : 0;
            const path = [...prefix.path, label];
            const levelProbs = [...prefix.levelProbs, levelProb];
            const prob = prefix.prob * levelProb;
            const children = this.hierarchy.get(label);
            const canDescend = children !== undefined
                && children.length > 0
                && this.trainedLevels.has(label)
                && !prefix.path.includes(label);
            if (canDescend) {
                this._expandLevel(x, topK, label, children, { path, levelProbs, prob }, out);
            }
            else {
                out.push({ path, label, prob, levelProbs });
            }
        }
    }
    /**
     * Top-K predictions of one level ELM for a single vector.
     * `predictFromVector` takes a batch of rows. If it returns one result list
     * per row, the list for the single row passed here is unwrapped.
     * NOTE(review): confirm the return shape of astermind-elm's predictFromVector.
     */
    _predictLevel(elm, x, topK) {
        const raw = elm?.predictFromVector?.([x], topK) ?? [];
        const results = Array.isArray(raw[0]) ? raw[0] : raw;
        return results.slice(0, topK);
    }
    /**
     * Get hierarchy structure (a copy; child arrays are copied too).
     * @returns {Map<string, string[]>}
     */
    getHierarchy() {
        return new Map(Array.from(this.hierarchy, ([parent, children]) => [parent, [...children]]));
    }
    /**
     * Get root categories (a copy).
     * @returns {string[]}
     */
    getRootCategories() {
        return [...this.rootCategories];
    }
}
//# sourceMappingURL=hierarchical-elm.js.map