UNPKG

@astermind/astermind-premium

Version:

Astermind Premium - Premium ML Toolkit

155 lines 5.47 kB
// transfer-learning-elm.ts — Transfer Learning ELM
// Pre-trained ELM adaptation, domain adaptation, and few-shot learning
import { ELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';

/**
 * Transfer Learning ELM
 * Features:
 * - Pre-trained model adaptation
 * - Domain adaptation
 * - Few-shot learning
 * - Fine-tuning capabilities
 */
export class TransferLearningELM {
    /**
     * Build the wrapped ELM and, when a source model is supplied, copy its
     * weights into the new model.
     *
     * @param {object} options
     * @param {Array} options.categories Target labels (required).
     * @param {object} [options.sourceModel] Pre-trained source; its `model`
     *   property is read for `W` and `b`.
     * @param {boolean} [options.freezeBase=false]
     * @param {number} [options.fineTuneLayers=1] Stored in `this.options`;
     *   not read anywhere else in this class.
     * @param {number} [options.hiddenUnits=256]
     * @param {string} [options.activation='relu']
     * @param {number} [options.maxLen=100]
     * @param {boolean} [options.useTokenizer=true]
     * @throws Whatever `requireLicense()` throws when no valid license is present.
     */
    constructor(options) {
        this.sourceModel = null;
        this.trained = false;
        requireLicense(); // Premium feature - requires valid license
        this.categories = options.categories;
        this.sourceModel = options.sourceModel || null;
        this.options = {
            categories: options.categories,
            sourceModel: this.sourceModel,
            freezeBase: options.freezeBase ?? false,
            fineTuneLayers: options.fineTuneLayers ?? 1,
            hiddenUnits: options.hiddenUnits ?? 256,
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
        };
        this.elm = new ELM({
            // A disabled tokenizer is passed as `undefined`, never `false`.
            useTokenizer: this.options.useTokenizer ? true : undefined,
            hiddenUnits: this.options.hiddenUnits,
            categories: this.options.categories,
            maxLen: this.options.maxLen,
            activation: this.options.activation,
        });
        // Transfer weights from source model if available
        if (this.sourceModel) {
            this._transferWeights();
        }
    }

    /**
     * Copy hidden-layer weights (`W`) and biases (`b`) from the source model
     * into the wrapped ELM, limited to the overlapping dimensions.
     *
     * Does nothing when there is no source model, when either side has no
     * `model` object yet, or when `freezeBase` is set.
     */
    _transferWeights() {
        if (!this.sourceModel) {
            return;
        }
        const sourceModelData = this.sourceModel.model;
        const targetModel = this.elm.model;
        if (!sourceModelData || !targetModel) {
            return;
        }
        // NOTE(review): with freezeBase enabled nothing is copied at all, which
        // matches the previous per-element check — confirm this is intended.
        if (this.options.freezeBase) {
            return;
        }
        // Transfer hidden layer weights over the overlapping rows/columns
        if (sourceModelData.W && targetModel.W) {
            const sourceW = sourceModelData.W;
            const targetW = targetModel.W;
            const rowCount = Math.min(sourceW.length, targetW.length);
            for (let i = 0; i < rowCount; i++) {
                const colCount = Math.min(sourceW[i]?.length || 0, targetW[i]?.length || 0);
                for (let j = 0; j < colCount; j++) {
                    targetW[i][j] = sourceW[i][j];
                }
            }
        }
        // Transfer biases if available
        if (sourceModelData.b && targetModel.b) {
            const sourceB = sourceModelData.b;
            const targetB = targetModel.b;
            const biasCount = Math.min(sourceB.length, targetB.length);
            for (let i = 0; i < biasCount; i++) {
                targetB[i] = sourceB[i];
            }
        }
    }

    /**
     * Train with transfer learning
     * @param X Target domain features
     * @param y Target domain labels; numbers are used as-is, any other value
     *   is looked up in `options.categories`
     * @throws {Error} If a non-numeric label is not one of the configured categories
     */
    train(X, y) {
        const { categories } = this.options;
        // Prepare labels
        const labelIndices = y.map((label) => {
            if (typeof label === 'number') {
                return label;
            }
            const index = categories.indexOf(label);
            if (index === -1) {
                // Previously an unknown label silently became index -1.
                throw new Error(`Unknown label "${String(label)}": not present in categories`);
            }
            return index;
        });
        this.elm.setCategories?.(categories);
        if (this.sourceModel && !this.options.freezeBase) {
            // Fine-tune: train on new data with transferred weights
            this.elm.trainFromData?.(X, labelIndices, {
                reuseWeights: true, // Reuse transferred weights
            });
        }
        else {
            // Standard training
            this.elm.trainFromData?.(X, labelIndices);
        }
        this.trained = true;
    }

    /**
     * Few-shot learning: train with very few examples.
     *
     * Uses the first `shots` entries of `X` and `y` overall.
     * NOTE(review): this is not `shots` examples per class — confirm intent.
     *
     * @param X Target domain features
     * @param y Target domain labels
     * @param {number} [shots=5] Number of leading examples to keep
     * @throws {Error} If no source model has been provided
     */
    fewShotTrain(X, y, shots = 5) {
        if (!this.sourceModel) {
            throw new Error('Few-shot learning requires a pre-trained source model');
        }
        // Use only a few examples
        const limitedX = X.slice(0, shots);
        const limitedY = y.slice(0, shots);
        // Fine-tune on limited data
        this.train(limitedX, limitedY);
    }

    /**
     * Predict with transferred model.
     *
     * @param X A single feature vector or an array of feature vectors
     * @param {number} [topK=3] Maximum predictions kept per input vector
     * @returns {Array<{label: *, prob: number}>} Predictions for all inputs,
     *   concatenated in input order
     * @throws {Error} If called before `train`
     */
    predict(X, topK = 3) {
        if (!this.trained) {
            throw new Error('Model must be trained before prediction');
        }
        const rows = Array.isArray(X[0]) ? X : [X];
        const results = [];
        for (const row of rows) {
            const raw = this.elm.predictFromVector?.([row], topK) || [];
            // One row is passed per call. If the ELM answers with one list per
            // input row, unwrap that single list; a flat list is used as-is.
            const preds = Array.isArray(raw[0]) ? raw[0] : raw;
            for (const pred of preds.slice(0, topK)) {
                results.push({
                    label: pred.label || this.options.categories[pred.index || 0],
                    prob: pred.prob || 0,
                });
            }
        }
        return results;
    }

    /**
     * Load pre-trained model and transfer its weights into the wrapped ELM.
     * @param model Source model; its `model` property is read for `W` and `b`
     */
    loadSourceModel(model) {
        this.sourceModel = model;
        // Keep the options snapshot consistent with the active source model.
        this.options.sourceModel = model;
        this._transferWeights();
    }

    /**
     * Export current model for use as source in other transfers.
     * The returned object references the live model and categories (no copy).
     */
    exportModel() {
        return {
            model: this.elm.model,
            categories: this.options.categories,
            config: {
                hiddenUnits: this.options.hiddenUnits,
                activation: this.options.activation,
            },
        };
    }
}
//# sourceMappingURL=transfer-learning-elm.js.map