UNPKG

encog

Version:

Encog is a Node.js ES6 framework based on the Encog Machine Learning Framework by Jeff Heaton, plus some basic data manipulation helpers.

285 lines (240 loc) 7.92 kB
const BasicTraining = require(PATHS.TRAINING + 'basic');
const ErrorCalculation = require(PATHS.ERROR_CALCULATION + 'errorCalculation');

/**
 * Provides basic propagation functions to other trainers.
 *
 * Gradients are accumulated into each connection's temp-training slot 0.
 * learn() then visits every connection and calls learnConnection(), which
 * is a no-op here and is presumably overridden by concrete trainers.
 */
class FreeformPropagationTraining extends BasicTraining {
    /**
     * Construct the trainer.
     * @param theNetwork {FreeformNetwork} The network to train.
     * @param theInput {Array} The training input records, indexed in parallel with theOutput.
     * @param theOutput {Array} The ideal output records; each one is indexed per output neuron.
     */
    constructor(theNetwork, theInput, theOutput) {
        super();
        this.FLAT_SPOT_CONST = 0.1;
        this.network = theNetwork;
        this.input = theInput;
        this.output = theOutput;
        this.batchSize = 0;
        this.fixFlatSopt = true;
        this.visited = [];
        // iteration() increments this counter. Start it at 0 unless the base
        // class already set it; otherwise the first increment yields NaN.
        if (this.iterationCount === undefined) {
            this.iterationCount = 0;
        }
    }

    /**
     * Evaluate an activation function's derivative. When flat-spot fixing is
     * enabled and the function is an ActivationSigmoid, FLAT_SPOT_CONST is
     * added to the result.
     * @param activationFunction The activation function to differentiate.
     * @param neuronSum {Number} The neuron's summed input.
     * @param neuronOutput {Number} The neuron's activation.
     * @return {Number} The (possibly flat-spot adjusted) derivative.
     */
    calculateDerivative(activationFunction, neuronSum, neuronOutput) {
        let deriv = activationFunction.derivativeFunction(neuronSum, neuronOutput);
        if (this.fixFlatSopt && activationFunction.constructor.name === 'ActivationSigmoid') {
            deriv += this.FLAT_SPOT_CONST;
        }
        return deriv;
    }

    /**
     * Calculate the gradient for a neuron, then recurse into the neurons
     * that feed it.
     * @param toNeuron {FreeformNeuron} The neuron to calculate for.
     */
    calculateNeuronGradient(toNeuron) {
        const summation = toNeuron.getInputSummation();
        // A neuron without an input summation has no inbound connections,
        // so there is nothing to propagate and the recursion stops here.
        if (summation == null) {
            return;
        }
        const activationFunction = summation.getActivationFunction();
        const inbound = summation.list();

        // Accumulate the gradient of every inbound connection and compute the
        // layer delta of each neuron that feeds toNeuron.
        for (const connection of inbound) {
            const fromNeuron = connection.getSource();
            connection.addTempTraining(0, fromNeuron.getActivation() * toNeuron.getTempTraining(0));

            let sum = 0;
            for (const toConnection of fromNeuron.getOutputs()) {
                sum += toConnection.getTarget().getTempTraining(0) * toConnection.getWeight();
            }
            // NOTE(review): the derivative uses toNeuron's activation function,
            // while the sum/output come from fromNeuron. Confirm this is intended.
            const deriv = this.calculateDerivative(activationFunction, fromNeuron.getSum(), fromNeuron.getActivation());
            fromNeuron.setTempTraining(0, sum * deriv);
        }

        // Recurse to the next level.
        for (const connection of inbound) {
            this.calculateNeuronGradient(connection.getSource());
        }
    }

    /**
     * Calculate the output delta for a neuron, given its difference.
     * Only used for output neurons.
     * @param neuron {FreeformNeuron} The output neuron.
     * @param diff {Number} Ideal minus actual output for this neuron.
     */
    calculateOutputDelta(neuron, diff) {
        const summation = neuron.getInputSummation();
        const deriv = this.calculateDerivative(summation.getActivationFunction(), summation.getSum(), neuron.getActivation());
        neuron.setTempTraining(0, deriv * diff);
    }

    /**
     * {@inheritDoc}
     * Always false for this trainer.
     */
    canContinue() {
        return false;
    }

    /**
     * {@inheritDoc}
     * Clears the network's temporary training data.
     */
    finishTraining() {
        this.network.tempTrainingClear();
    }

    /**
     * {@inheritDoc}
     */
    getError() {
        return this.error;
    }

    /**
     * {@inheritDoc}
     */
    getImplementationType() {
        // NOTE(review): TrainingImplementationType is not required in this
        // module; it must be provided globally or this throws a ReferenceError.
        return TrainingImplementationType.Iterative;
    }

    /**
     * {@inheritDoc}
     */
    getIteration() {
        return this.iterationCount;
    }

    /**
     * @return {Boolean} True, if we are fixing the flat spot problem.
     */
    isFixFlatSopt() {
        return this.fixFlatSopt;
    }

    /**
     * {@inheritDoc}
     * Runs `count` training iterations. A batchSize of 0 selects pure batch
     * mode; any other value learns every batchSize records.
     * @param count {Number} Number of iterations to run (default 1).
     */
    iteration(count = 1) {
        for (let i = 0; i < count; i++) {
            this.preIteration();
            this.iterationCount++;
            this.network.clearContext();
            if (this.batchSize === 0) {
                this.processPureBatch();
            } else {
                this.processBatches();
            }
            this.postIteration();
        }
    }

    /**
     * Forward-propagate one training record, add its error to errorCalc, and
     * back-propagate gradients from every output neuron.
     * Shared by processPureBatch() and processBatches().
     * @param index {Number} Index of the record in this.input / this.output.
     * @param errorCalc {ErrorCalculation} Accumulator for the record's error.
     */
    processTrainingRecord(index, errorCalc) {
        const ideal = this.output[index];
        const actual = this.network.compute(this.input[index]);
        // All records currently carry the same significance.
        const sig = 1;
        errorCalc.updateError(actual, ideal, sig);

        const outputNeurons = this.network.getOutputLayer().getNeurons();
        for (let i = 0; i < this.network.getOutputCount(); i++) {
            const neuron = outputNeurons[i];
            this.calculateOutputDelta(neuron, (ideal[i] - actual[i]) * sig);
            this.calculateNeuronGradient(neuron);
        }
    }

    /**
     * Process training for pure batch mode (one single batch): accumulate
     * gradients over every record, then learn once.
     */
    processPureBatch() {
        const errorCalc = new ErrorCalculation();
        this.visited = [];
        for (let j = 0; j < this.input.length; j++) {
            this.processTrainingRecord(j, errorCalc);
        }
        // Set the overall error.
        this.setError(errorCalc.calculate());
        // Learn for all data.
        this.learn();
    }

    /**
     * Process training batches: learn after every batchSize records, then
     * once more for any leftover records.
     */
    processBatches() {
        const errorCalc = new ErrorCalculation();
        this.visited = [];
        let lastLearn = 0;
        for (let j = 0; j < this.input.length; j++) {
            this.processTrainingRecord(j, errorCalc);
            // Are we at the end of a batch?
            lastLearn++;
            if (lastLearn >= this.batchSize) {
                lastLearn = 0;
                this.learn();
            }
        }
        // Handle any remaining data.
        if (lastLearn > 0) {
            this.learn();
        }
        // Set the overall error.
        this.setError(errorCalc.calculate());
    }

    /**
     * Learn for the entire network. Calls learnConnection() on every
     * connection, then resets that connection's accumulated gradient.
     */
    learn() {
        this.network.performConnectionTask((connection) => {
            this.learnConnection(connection);
            connection.setTempTraining(0, 0);
        });
    }

    /**
     * Learn for a single connection. No-op in this class.
     * @param connection The connection to learn from.
     */
    learnConnection(connection) {}

    /**
     * {@inheritDoc}
     */
    setError(theError) {
        this.error = theError;
    }

    /**
     * Set if we should fix the flat spot problem.
     * @param fixFlatSopt {Boolean} True, if we should fix the flat spot problem.
     */
    setFixFlatSopt(fixFlatSopt) {
        this.fixFlatSopt = fixFlatSopt;
    }

    /**
     * {@inheritDoc}
     */
    setIteration(iteration) {
        this.iterationCount = iteration;
    }

    /**
     * @return {Number} The batch size (0 means pure batch mode).
     */
    getBatchSize() {
        return this.batchSize;
    }

    /**
     * Set the batch size.
     * @param batchSize {Number} The batch size.
     */
    setBatchSize(batchSize) {
        this.batchSize = batchSize;
    }
}

module.exports = FreeformPropagationTraining;