// UNPKG
//
// brainjs
//
// Version:
// 605 lines (528 loc) 20.1 kB
var _ = require("underscore"),
    lookup = require("./lookup"),
    Writable = require('stream').Writable,
    inherits = require('inherits');

/**
 * Layered feed-forward network with sigmoid activations, trained by
 * backpropagation with a momentum term.
 *
 * @param {Object} [options]
 * @param {Number} [options.learningRate=0.3] default step size for weight updates
 * @param {Number} [options.momentum=0.1] fraction of the previous weight change
 *   that is re-applied on the next update; 0 disables momentum
 * @param {Array} [options.hiddenLayers] sizes of the hidden layers; when
 *   omitted, training uses one hidden layer of max(3, floor(inputSize / 2)) nodes
 * @param {Number} [options.binaryThresh=0.5] cut-off used by test() to turn a
 *   single output into a 0/1 class
 */
var NeuralNetwork = function (options) {
  options = options || {};

  this.learningRate = options.learningRate || 0.3;
  // An explicit 0 is a meaningful momentum (no momentum at all), so it must
  // not fall through to the default the way other falsy values do.
  this.momentum = options.momentum === 0 ? 0 : (options.momentum || 0.1);
  this.hiddenSizes = options.hiddenLayers;
  this.binaryThresh = options.binaryThresh || 0.5;
};

NeuralNetwork.prototype = {
  /**
   * Allocates all per-layer state for the given layer sizes. Biases and
   * weights get small random values; every other array is zero-filled.
   * Layer 0 (the input layer) has no biases, weights or changes.
   *
   * @param {Array} sizes number of nodes in each layer, input layer first
   */
  initialize: function (sizes) {
    this.sizes = sizes;
    this.outputLayer = this.sizes.length - 1;

    this.biases = []; // weights for bias nodes
    this.weights = [];
    this.outputs = [];

    // state for training
    this.deltas = [];
    this.changes = []; // for momentum
    this.errors = [];

    for (var layer = 0; layer <= this.outputLayer; layer++) {
      var size = this.sizes[layer];
      this.deltas[layer] = zeros(size);
      this.errors[layer] = zeros(size);
      this.outputs[layer] = zeros(size);

      if (layer > 0) {
        this.biases[layer] = randos(size);
        this.weights[layer] = new Array(size);
        this.changes[layer] = new Array(size);

        for (var node = 0; node < size; node++) {
          var prevSize = this.sizes[layer - 1];
          this.weights[layer][node] = randos(prevSize);
          this.changes[layer][node] = zeros(prevSize);
        }
      }
    }
  },

  /**
   * Runs the network on one input. When lookups exist (they are built by
   * formatData, fromJSON or a TrainStream), the input is converted to an
   * array via the input lookup and the output is converted back to a hash
   * via the output lookup.
   *
   * @param {Array|Object} input
   * @returns {Array|Object} the activations of the output layer
   */
  run: function (input) {
    if (this.inputLookup) {
      input = lookup.toArray(this.inputLookup, input);
    }

    var output = this.runInput(input);

    if (this.outputLookup) {
      output = lookup.toHash(this.outputLookup, output);
    }
    return output;
  },

  /**
   * Forward pass on an already array-shaped input. Stores every layer's
   * activations in this.outputs, which the backward pass reads afterwards.
   *
   * @param {Array} input one value per input node
   * @returns {Array} the activations of the output layer
   */
  runInput: function (input) {
    this.outputs[0] = input; // set output state of input layer

    var output;
    for (var layer = 1; layer <= this.outputLayer; layer++) {
      for (var node = 0; node < this.sizes[layer]; node++) {
        var weights = this.weights[layer][node];

        var sum = this.biases[layer][node];
        for (var k = 0; k < weights.length; k++) {
          sum += weights[k] * input[k];
        }
        this.outputs[layer][node] = 1 / (1 + Math.exp(-sum));
      }
      // this layer's activations are the next layer's input
      output = input = this.outputs[layer];
    }
    return output;
  },

  /**
   * Re-initializes the network to fit the data and trains it until the mean
   * error per pattern drops to errorThresh or the iteration limit is hit.
   *
   * @param {Array|Object} data training patterns of the form {input, output};
   *   a single pattern is also accepted
   * @param {Object} [options]
   * @param {Number} [options.iterations=20000] maximum passes over the data
   * @param {Number} [options.errorThresh=0.005] training stops once the mean
   *   error is no longer above this value
   * @param {Boolean|Function} [options.log] true logs with console.log, a
   *   function is used as the logger
   * @param {Number} [options.logPeriod=10] log every this many iterations
   * @param {Number} [options.learningRate] overrides the network's learning rate
   * @param {Function} [options.callback] called with {error, iterations}
   * @param {Number} [options.callbackPeriod=10] call back every this many iterations
   * @returns {Object} {error, iterations} as they stood when training stopped
   */
  train: function (data, options) {
    data = this.formatData(data);

    options = options || {};
    var iterations = options.iterations || 20000;
    var errorThresh = options.errorThresh || 0.005;
    var log = options.log ? (_.isFunction(options.log) ? options.log : console.log) : false;
    var logPeriod = options.logPeriod || 10;
    var learningRate = options.learningRate || this.learningRate || 0.3;
    var callback = options.callback;
    var callbackPeriod = options.callbackPeriod || 10;

    var inputSize = data[0].input.length;
    var outputSize = data[0].output.length;

    this.initialize(layerSizes(inputSize, this.hiddenSizes, outputSize));

    var error = 1;
    for (var i = 0; i < iterations && error > errorThresh; i++) {
      var sum = 0;
      for (var j = 0; j < data.length; j++) {
        var err = this.trainPattern(data[j].input, data[j].output, learningRate);
        sum += err;
      }
      error = sum / data.length;

      if (log && (i % logPeriod === 0)) {
        log("iterations:", i, "training error:", error);
      }
      if (callback && (i % callbackPeriod === 0)) {
        callback({error: error, iterations: i});
      }
    }

    return {
      error: error,
      iterations: i
    };
  },

  /**
   * Performs one training step on a single pattern: forward pass, backward
   * pass, weight update.
   *
   * @param {Array} input one value per input node
   * @param {Array} target desired value for each output node
   * @param {Number} [learningRate] defaults to the network's learning rate
   * @returns {Number} mean squared error of the output layer for this pattern
   */
  trainPattern: function (input, target, learningRate) {
    learningRate = learningRate || this.learningRate;

    // forward propagate
    this.runInput(input);

    // back propagate
    this.calculateDeltas(target);
    this.adjustWeights(learningRate);

    var error = mse(this.errors[this.outputLayer]);
    return error;
  },

  /**
   * Trains the network on one input using an oracle function instead of a
   * fixed target. The oracle supplies the error for each output value.
   *
   * @param {Array} input The input
   * @param {Function} targetFunction The oracle; it is called with the
   *   output-layer activations and a completion function. It either returns
   *   the array of per-output errors directly, or returns nothing and later
   *   passes that array to the completion function.
   * @param {Number} [learningRate] defaults to the network's learning rate
   * @param {Function} [optionalCallback] called with the mean squared error
   *   once the weights have been adjusted
   * @returns {Number|undefined} the mean squared error when the weights were
   *   adjusted before this method returned, otherwise undefined
   */
  trainFunction: function (input, targetFunction, learningRate, optionalCallback) {
    learningRate = learningRate || this.learningRate;

    // forward propagate
    this.runInput(input);

    var error;
    var thisptr = this;

    // back propagate
    this.calculateDeltasForFunction(targetFunction, function () {
      thisptr.adjustWeights(learningRate);
      error = mse(thisptr.errors[thisptr.outputLayer]);
      if (optionalCallback) {
        optionalCallback(error);
      }
    });

    // error is only set here if the oracle answered synchronously
    if (error !== undefined) {
      return error;
    }
  },

  /**
   * Backward pass against a fixed target: the output-layer error is
   * target minus output. Fills this.errors and this.deltas for every layer.
   *
   * @param {Array} target desired value for each output node
   */
  calculateDeltas: function (target) {
    backPropagate(this, function (node, output) {
      return target[node] - output;
    });
  },

  /**
   * Backward pass driven by an oracle function. The values the oracle
   * supplies are used directly as the output-layer errors (a missing or
   * falsy entry counts as 0).
   *
   * @param {Function} target the oracle; called with the output-layer
   *   activations and a completion function that accepts the error array.
   *   If it returns a value other than null/undefined, that value is used as
   *   the error array immediately.
   * @param {Function} callback called with no arguments after the deltas
   *   have been computed
   */
  calculateDeltasForFunction: function (target, callback) {
    var thisptr = this;

    var finish = function (errors) {
      backPropagate(thisptr, function (node) {
        return errors[node] || 0;
      });
      callback();
    };

    var result = target(this.outputs[this.outputLayer], finish);
    // `!= null` is deliberate: both null and undefined mean "no synchronous
    // answer", in which case the oracle is expected to call finish itself.
    if (result != null) {
      finish(result);
    }
  },

  /**
   * Applies the deltas computed by the backward pass to the weights and
   * biases. Each weight change also includes this.momentum times the change
   * that was applied to the same weight on the previous update.
   *
   * @param {Number} learningRate step size for this update
   */
  adjustWeights: function (learningRate) {
    for (var layer = 1; layer <= this.outputLayer; layer++) {
      var incoming = this.outputs[layer - 1];

      for (var node = 0; node < this.sizes[layer]; node++) {
        var delta = this.deltas[layer][node];

        for (var k = 0; k < incoming.length; k++) {
          var change = this.changes[layer][node][k];

          change = (learningRate * delta * incoming[k]) + (this.momentum * change);

          this.changes[layer][node][k] = change;
          this.weights[layer][node][k] += change;
        }
        this.biases[layer][node] += learningRate * delta;
      }
    }
  },

  /**
   * Normalizes training/test data into an array of {input, output} patterns
   * whose input and output are arrays. Inputs/outputs that are not arrays
   * are converted through lookups, which are built from the data on first
   * use and then kept on the network.
   *
   * @param {Array|Object} data patterns, or a single pattern
   * @returns {Array} the normalized patterns; patterns that needed
   *   conversion are shallow copies, the others are the original objects
   */
  formatData: function (data) {
    if (!_.isArray(data)) { // turn stream datum into array
      var tmp = [];
      tmp.push(data);
      data = tmp;
    }

    // turn sparse hash input into arrays with 0s as filler
    var datum = data[0].input;
    if (!_(datum).isArray() && !(datum instanceof Float64Array)) {
      if (!this.inputLookup) {
        this.inputLookup = lookup.buildLookup(_(data).pluck("input"));
      }
      data = data.map(function (datum) {
        var array = lookup.toArray(this.inputLookup, datum.input);
        return _(_(datum).clone()).extend({input: array});
      }, this);
    }

    if (!_(data[0].output).isArray()) {
      if (!this.outputLookup) {
        this.outputLookup = lookup.buildLookup(_(data).pluck("output"));
      }
      data = data.map(function (datum) {
        var array = lookup.toArray(this.outputLookup, datum.output);
        return _(_(datum).clone()).extend({output: array});
      }, this);
    }
    return data;
  },

  /**
   * Runs every pattern through the trained network and collects error and
   * misclassification statistics.
   *
   * @param {Array|Object} data test patterns of the form {input, output}
   * @returns {Object} {error, misclasses}; when the network has a single
   *   output node the result also carries trueNeg, truePos, falseNeg,
   *   falsePos, total, precision, recall and accuracy
   */
  test: function (data) {
    data = this.formatData(data);

    // for binary classification problems with one output node
    var isBinary = data[0].output.length === 1;
    var falsePos = 0,
        falseNeg = 0,
        truePos = 0,
        trueNeg = 0;

    // for classification problems
    var misclasses = [];

    // run each pattern through the trained network and collect
    // error and misclassification statistics
    var sum = 0;
    for (var i = 0; i < data.length; i++) {
      var output = this.runInput(data[i].input);
      var target = data[i].output;

      var actual, expected;
      if (isBinary) {
        actual = output[0] > this.binaryThresh ? 1 : 0;
        expected = target[0];
      } else {
        actual = output.indexOf(_(output).max());
        expected = target.indexOf(_(target).max());
      }

      // The comparisons against `expected` stay loose on purpose: target
      // values come from caller data and may not be plain numbers.
      if (actual != expected) {
        // NOTE(review): this annotates data[i] in place; when formatData did
        // not need to convert the pattern, that is the caller's own object.
        var misclass = data[i];
        _(misclass).extend({
          actual: actual,
          expected: expected
        });
        misclasses.push(misclass);
      }

      if (isBinary) {
        if (actual == 0 && expected == 0) {
          trueNeg++;
        } else if (actual == 1 && expected == 1) {
          truePos++;
        } else if (actual == 0 && expected == 1) {
          falseNeg++;
        } else if (actual == 1 && expected == 0) {
          falsePos++;
        }
      }

      var errors = output.map(function (value, i) {
        return target[i] - value;
      });
      sum += mse(errors);
    }
    var error = sum / data.length;

    var stats = {
      error: error,
      misclasses: misclasses
    };

    if (isBinary) {
      _(stats).extend({
        trueNeg: trueNeg,
        truePos: truePos,
        falseNeg: falseNeg,
        falsePos: falsePos,
        total: data.length,
        precision: truePos / (truePos + falsePos),
        recall: truePos / (truePos + falseNeg),
        accuracy: (trueNeg + truePos) / data.length
      });
    }
    return stats;
  },

  /**
   * Serializes the network's structure, biases and weights.
   *
   * @returns {Object} {layers, outputLookup, inputLookup}; the two lookup
   *   fields are booleans telling whether the first/last layer is keyed by
   *   lookup names
   */
  toJSON: function () {
    /* make json look like:
      {
        layers: [
          { x: {},
            y: {}},
          {'0': {bias: -0.98771313, weights: {x: 0.8374838, y: 1.245858}},
           '1': {bias: 3.48192004, weights: {x: 1.7825821, y: -2.67899}}},
          { f: {bias: 0.27205739, weights: {'0': 1.3161821, '1': 2.00436}}}
        ]
      }
    */
    var layers = [];
    for (var layer = 0; layer <= this.outputLayer; layer++) {
      layers[layer] = {};

      var nodes;
      // turn any internal arrays back into hashes for readable json
      if (layer === 0 && this.inputLookup) {
        nodes = _(this.inputLookup).keys();
      } else if (layer === this.outputLayer && this.outputLookup) {
        nodes = _(this.outputLookup).keys();
      } else {
        nodes = _.range(0, this.sizes[layer]);
      }

      for (var j = 0; j < nodes.length; j++) {
        var node = nodes[j];
        layers[layer][node] = {};

        if (layer > 0) {
          layers[layer][node].bias = this.biases[layer][j];
          layers[layer][node].weights = {};
          for (var k in layers[layer - 1]) {
            var index = k;
            // input-layer keys are lookup names; map them to array positions
            if (layer === 1 && this.inputLookup) {
              index = this.inputLookup[k];
            }
            layers[layer][node].weights[k] = this.weights[layer][j][index];
          }
        }
      }
    }
    return {layers: layers, outputLookup: !!this.outputLookup, inputLookup: !!this.inputLookup};
  },

  /**
   * Restores the network from the structure produced by toJSON. Only the
   * state needed to run the network is rebuilt (sizes, weights, biases,
   * outputs and lookups); the training arrays are not.
   *
   * @param {Object} json
   * @returns {NeuralNetwork} this network
   */
  fromJSON: function (json) {
    var size = json.layers.length;
    this.outputLayer = size - 1;

    this.sizes = new Array(size);
    this.weights = new Array(size);
    this.biases = new Array(size);
    this.outputs = new Array(size);

    for (var i = 0; i <= this.outputLayer; i++) {
      var layer = json.layers[i];
      if (i === 0 && (!layer[0] || json.inputLookup)) {
        this.inputLookup = lookup.lookupFromHash(layer);
      } else if (i === this.outputLayer && (!layer[0] || json.outputLookup)) {
        this.outputLookup = lookup.lookupFromHash(layer);
      }

      var nodes = _(layer).keys();
      this.sizes[i] = nodes.length;
      this.weights[i] = [];
      this.biases[i] = [];
      this.outputs[i] = [];

      for (var j = 0; j < nodes.length; j++) {
        var node = nodes[j];
        this.biases[i][j] = layer[node].bias;
        this.weights[i][j] = _(layer[node].weights).toArray();
      }
    }
    return this;
  },

  /**
   * Builds a standalone function with the network's JSON embedded in its
   * source. The function takes an input keyed the same way as layer 0 of
   * toJSON() and returns the output-layer activations keyed by node id.
   *
   * @returns {Function}
   */
  toFunction: function () {
    var json = this.toJSON();

    // The only dynamic part of the generated source is JSON.stringify output
    // of this network's own state; no caller-supplied code is evaluated.
    var body = [
      '  var net = ' + JSON.stringify(json) + ';',
      '',
      '  for (var i = 1; i < net.layers.length; i++) {',
      '    var layer = net.layers[i];',
      '    var output = {};',
      '    ',
      '    for (var id in layer) {',
      '      var node = layer[id];',
      '      var sum = node.bias;',
      '      ',
      '      for (var iid in node.weights) {',
      '        sum += node.weights[iid] * input[iid];',
      '      }',
      '      output[id] = (1 / (1 + Math.exp(-sum)));',
      '    }',
      '    input = output;',
      '  }',
      '  return output;'
    ].join('\n');

    // return standalone function that mimics run()
    return new Function("input", body);
  },

  /**
   * Creates a TrainStream (a Writable stream) to send the training data to.
   * The stream is also stored on the network as this.trainStream.
   *
   * @param {Object} [opts] the training options; opts.neuralNetwork is set
   *   to this network on the object passed in
   * @returns {TrainStream}
   */
  createTrainStream: function (opts) {
    opts = opts || {};
    opts.neuralNetwork = this;
    this.trainStream = new TrainStream(opts);
    return this.trainStream;
  }
};

// Shared backward pass: fills net.errors and net.deltas for every layer, taking
// the output-layer error from outputError(node, output).
function backPropagate(net, outputError) {
  for (var layer = net.outputLayer; layer >= 0; layer--) {
    for (var node = 0; node < net.sizes[layer]; node++) {
      var output = net.outputs[layer][node];

      var error = 0;
      if (layer === net.outputLayer) {
        error = outputError(node, output);
      } else {
        // error of a hidden node: next layer's deltas weighted by the
        // connections leaving this node
        var deltas = net.deltas[layer + 1];
        for (var k = 0; k < deltas.length; k++) {
          error += deltas[k] * net.weights[layer + 1][k][node];
        }
      }
      net.errors[layer][node] = error;
      net.deltas[layer][node] = error * output * (1 - output);
    }
  }
}

// Flat list of layer sizes; without configured hidden sizes a single hidden
// layer of max(3, floor(inputSize / 2)) nodes is used.
function layerSizes(inputSize, hiddenSizes, outputSize) {
  if (!hiddenSizes) {
    hiddenSizes = [Math.max(3, Math.floor(inputSize / 2))];
  }
  return _([inputSize, hiddenSizes, outputSize]).flatten();
}

// Random initial weight in the range [-0.2, 0.2).
function randomWeight() {
  return Math.random() * 0.4 - 0.2;
}

// Array of `size` zeros.
function zeros(size) {
  var array = new Array(size);
  for (var i = 0; i < size; i++) {
    array[i] = 0;
  }
  return array;
}

// Array of `size` random weights.
function randos(size) {
  var array = new Array(size);
  for (var i = 0; i < size; i++) {
    array[i] = randomWeight();
  }
  return array;
}

// Mean squared error of an array of errors.
function mse(errors) {
  var sum = 0;
  for (var i = 0; i < errors.length; i++) {
    sum += Math.pow(errors[i], 2);
  }
  return sum / errors.length;
}

exports.NeuralNetwork = NeuralNetwork;

/**
 * Object-mode Writable stream that trains a network from data written to it.
 * The first pass over the data only collects input/output keys and counts
 * the patterns; the network is initialized at the end of that pass, and
 * every later pass trains on each datum as it arrives.
 *
 * @param {Object} opts
 * @param {NeuralNetwork} opts.neuralNetwork the network to train (required)
 * @param {Number} [opts.iterations=20000] maximum number of training passes
 * @param {Number} [opts.errorThresh=0.005] training is done once the mean
 *   error is no longer above this value
 * @param {Boolean|Function} [opts.log] true logs with console.log, a
 *   function is used as the logger
 * @param {Number} [opts.logPeriod=10] log every this many iterations
 * @param {Function} [opts.callback] called with {error, iterations}
 * @param {Number} [opts.callbackPeriod=10] call back every this many iterations
 * @param {Function} [opts.floodCallback] called whenever another pass over
 *   the data is wanted
 * @param {Function} [opts.doneTrainingCallback] called with
 *   {error, iterations} when training is finished
 * @throws {Error} when opts.neuralNetwork is missing
 */
function TrainStream(opts) {
  Writable.call(this, {
    objectMode: true
  });

  opts = opts || {};

  // require the neuralNetwork
  if (!opts.neuralNetwork) {
    throw new Error('no neural network specified');
  }

  this.neuralNetwork = opts.neuralNetwork;
  this.dataFormatDetermined = false;

  this.inputKeys = [];
  this.outputKeys = []; // keeps track of keys seen
  this.i = 0; // keep track of the for loop i variable that we got rid of
  this.iterations = opts.iterations || 20000;
  this.errorThresh = opts.errorThresh || 0.005;
  this.log = opts.log ? (_.isFunction(opts.log) ? opts.log : console.log) : false;
  this.logPeriod = opts.logPeriod || 10;
  this.callback = opts.callback;
  this.callbackPeriod = opts.callbackPeriod || 10;
  this.floodCallback = opts.floodCallback;
  this.doneTrainingCallback = opts.doneTrainingCallback;

  this.size = 0;  // number of patterns seen during the first pass
  this.count = 0; // number of patterns seen during the current pass
  this.sum = 0;   // accumulated error of the current pass

  this.on('finish', this.finishStreamIteration);

  return this;
}

inherits(TrainStream, Writable);

/*
  _write expects data to be in the form of a datum.
  ie. {input: {a: 1, b: 0}, output: {z: 0}}
  A falsy chunk marks the end of one pass over the data.
*/
TrainStream.prototype._write = function (chunk, enc, next) {
  if (!chunk) { // check for the end of one iteration of the stream
    this.emit('finish');
    return next();
  }

  if (!this.dataFormatDetermined) {
    // first pass: only gather keys and count the data, no training yet
    this.size++;
    this.inputKeys = _.union(this.inputKeys, _.keys(chunk.input));
    this.outputKeys = _.union(this.outputKeys, _.keys(chunk.output));
    this.firstDatum = this.firstDatum || chunk;
    return next();
  }

  this.count++;

  var data = this.neuralNetwork.formatData(chunk);
  this.trainDatum(data[0]);

  // tell the Readable Stream that we are ready for more data
  next();
};

/**
 * Trains the network on one formatted datum and adds its error to the
 * running sum of the current pass.
 *
 * @param {Object} datum pattern of the form {input, output} holding arrays
 */
TrainStream.prototype.trainDatum = function (datum) {
  var err = this.neuralNetwork.trainPattern(datum.input, datum.output);
  this.sum += err;
};

/**
 * Handles the end of one pass over the data. After the first pass it builds
 * the lookups, initializes the network and asks for the data again. After
 * every later pass it reports the mean error, then either asks for another
 * pass or signals that training is done.
 */
TrainStream.prototype.finishStreamIteration = function () {
  if (this.dataFormatDetermined && this.size !== this.count) {
    console.log("This iteration's data length was different from the first.");
  }

  if (!this.dataFormatDetermined) {
    // create the lookup
    this.neuralNetwork.inputLookup = lookup.lookupFromArray(this.inputKeys);
    if (!_.isArray(this.firstDatum.output)) {
      this.neuralNetwork.outputLookup = lookup.lookupFromArray(this.outputKeys);
    }

    var data = this.neuralNetwork.formatData(this.firstDatum);
    var inputSize = data[0].input.length;
    var outputSize = data[0].output.length;

    // The hidden layer configuration lives on the network; the stream itself
    // never has a hiddenSizes property.
    var sizes = layerSizes(inputSize, this.neuralNetwork.hiddenSizes, outputSize);

    this.dataFormatDetermined = true;
    this.neuralNetwork.initialize(sizes);

    if (typeof this.floodCallback === 'function') {
      this.floodCallback();
    }
    return;
  }

  var error = this.sum / this.size;

  if (this.log && (this.i % this.logPeriod === 0)) {
    this.log("iterations:", this.i, "training error:", error);
  }
  if (this.callback && (this.i % this.callbackPeriod === 0)) {
    this.callback({
      error: error,
      iterations: this.i
    });
  }

  this.sum = 0;
  this.count = 0;
  // update the iterations
  this.i++;

  // do a check here to see if we need the stream again
  if (this.i < this.iterations && error > this.errorThresh) {
    if (typeof this.floodCallback === 'function') {
      return this.floodCallback();
    }
  } else {
    // done training
    if (typeof this.doneTrainingCallback === 'function') {
      return this.doneTrainingCallback({
        error: error,
        iterations: this.i
      });
    }
  }
};