UNPKG

synaptic

Version:

architecture-free neural network library

609 lines (506 loc) 15.5 kB
// import var chai = require('chai'); chai.use(require('chai-stats')); var assert = chai.assert; import * as synaptic from '../src/synaptic'; var Perceptron = synaptic.Architect.Perceptron; var LSTM = synaptic.Architect.LSTM; var Layer = synaptic.Layer; var Network = synaptic.Network; var Trainer = synaptic.Trainer; var learningRate = .5; // utils function noRepeat(range, avoid) { var number = Math.random() * range | 0; for (var i in avoid) { if (number == avoid[i]) { return noRepeat(range, avoid); } } return number; } function equal(prediction, output) { for (var i in prediction) if (Math.round(prediction[i]) != output[i]) return false; return true; } function generateRandomArray(size) { var array = []; for (var j = 0; j < size; j++) array.push(Math.random() + .5 | 0); return array; } function calculateMse(a, b) { var mse = 0; for (var k in a) mse += Math.pow(a[k] - b[k], 2); mse /= a.length; return mse; } function equalWithError(output, expected, error) { return Math.abs(output - expected) <= error; } // specs describe('Basic Neural Network', function () { it("trains an AND gate", function () { var inputLayer = new Layer(2), outputLayer = new Layer(1); inputLayer.project(outputLayer); var network = new Network({ input: inputLayer, output: outputLayer }); var trainer = new Trainer(network); var trainingSet = [{ input: [0, 0], output: [0] }, { input: [0, 1], output: [0] }, { input: [1, 0], output: [0] }, { input: [1, 1], output: [1] }]; trainer.train(trainingSet, { iterations: 1000, error: .001 }); var test00 = Math.round(network.activate([0, 0])); assert.equal(test00, 0, "[0,0] did not output 0"); var test01 = Math.round(network.activate([0, 1])); assert.equal(test01, 0, "[0,1] did not output 0"); var test10 = Math.round(network.activate([1, 0])); assert.equal(test10, 0, "[1,0] did not output 0"); var test11 = Math.round(network.activate([1, 1])); assert.equal(test11, 1, "[1,1] did not output 1"); }); it("trains an OR gate", function () { var inputLayer = new Layer(2), outputLayer = new Layer(1); inputLayer.project(outputLayer); var network = new Network({ input: inputLayer, output: outputLayer }); var trainer = new Trainer(network); var trainingSet = [{ input: [0, 0], output: [0] }, { input: [0, 1], output: [1] }, { input: [1, 0], output: [1] }, { input: [1, 1], output: [1] }]; trainer.train(trainingSet, { iterations: 1000, error: .001 }); var test00 = Math.round(network.activate([0, 0])); assert.equal(test00, 0, "[0,0] did not output 0"); var test01 = Math.round(network.activate([0, 1])); assert.equal(test01, 1, "[0,1] did not output 1"); var test10 = Math.round(network.activate([1, 0])); assert.equal(test10, 1, "[1,0] did not output 1"); var test11 = Math.round(network.activate([1, 1])); assert.equal(test11, 1, "[1,1] did not output 1"); }); it("trains a NOT gate", function () { var inputLayer = new Layer(1), outputLayer = new Layer(1); inputLayer.project(outputLayer); var network = new Network({ input: inputLayer, output: outputLayer }); var trainer = new Trainer(network); var trainingSet = [{ input: [0], output: [1] }, { input: [1], output: [0] }]; trainer.train(trainingSet, { iterations: 1000, error: .001 }); var test0 = Math.round(network.activate([0])); assert.equal(test0, 1, "0 did not output 1"); var test1 = Math.round(network.activate([1])); assert.equal(test1, 0, "1 did not output 0"); }); }); describe("Perceptron - XOR", function () { var perceptron = new Perceptron(2, 3, 1); var trainer = new Trainer(perceptron); trainer.XOR(); it("should return near-0 value on [0,0]", function () { assert.isAtMost(perceptron.activate([0, 0]), .49, "[0,0] did not output 0"); }); it("should return near-1 value on [0,1]", function () { assert.isAtLeast(perceptron.activate([0, 1]), .51, "[0,1] did not output 1"); }); it("should return near-1 value on [1,0]", function () { assert.isAtLeast(perceptron.activate([1, 0]), .51, "[1,0] did not output 1"); }); it("should return near-0 value on [1,1]", function () { assert.isAtMost(perceptron.activate([1, 1]), .49, "[1,1] did not output 0"); }); }); describe("Perceptron - SIN", function () { var mySin = function (x) { return (Math.sin(x) + 1) / 2; }; var sinNetwork = new Perceptron(1, 12, 1); var trainer = new Trainer(sinNetwork); var trainingSet = []; while (trainingSet.length < 800) { var inputValue = Math.random() * Math.PI * 2; trainingSet.push({ input: [inputValue], output: [mySin(inputValue)] }); } var results = trainer.train(trainingSet, { iterations: 2000, log: false, error: 1e-6, cost: Trainer.cost.MSE, }); [0, .5 * Math.PI, 2] .forEach(function (x) { var y = mySin(x); it("should return value around " + y + " when [" + x + "] is on input", function () { // near scalability: abs(expected-actual) < 0.5 * 10**(-decimal) // 0.5 * Math.pow(10, -.15) => 0.35397289219206896 assert.almostEqual(sinNetwork.activate([x])[0], y, .15); }); }); var errorResult = results.error; it("Sin error: " + errorResult, function () { assert.isAtMost(errorResult, .001, "Sin error not less than or equal to desired error."); }); }); describe("Perceptron - SIN - CrossValidate", function () { var mySin = function (x) { return (Math.sin(x) + 1) / 2; }; var sinNetwork = new Perceptron(1, 12, 1); var trainer = new Trainer(sinNetwork); var trainingSet = Array.apply(null, Array(800)).map(function () { var inputValue = Math.random() * Math.PI * 2; return { input: [inputValue], output: [mySin(inputValue)] }; }); var results = trainer.train(trainingSet, { iterations: 2000, log: false, error: 1e-6, cost: Trainer.cost.MSE, crossValidate: { testSize: .3, testError: 1e-6 } }); var test0 = sinNetwork.activate([0])[0]; var expected0 = mySin(0); it("input: [0] output: " + test0 + ", expected: " + expected0, function () { assert.isAtMost(Math.abs(test0 - expected0), .035, "[0] did not output " + expected0); }); var test05PI = sinNetwork.activate([.5 * Math.PI])[0]; var expected05PI = mySin(.5 * Math.PI); it("input: [0.5*Math.PI] output: " + test05PI + ", expected: " + expected05PI, function () { assert.isAtMost(Math.abs(test05PI - expected05PI), .035, "[0.5*Math.PI] did not output " + expected05PI); }); var test2 = sinNetwork.activate([2])[0]; var expected2 = mySin(2); it("input: [2] output: " + test2 + ", expected: " + expected2, function () { var eq = equalWithError(test2, expected2, .035); assert.equal(eq, true, "[2] did not output " + expected2); }); var errorResult = results.error; it("CrossValidation error: " + errorResult, function () { var lessThanOrEqualError = errorResult <= .001; assert.equal(lessThanOrEqualError, true, "CrossValidation error not less than or equal to desired error."); }); }); describe("LSTM - Discrete Sequence Recall", function () { var targets = [2, 4]; var distractors = [3, 5]; var prompts = [0, 1]; var length = 9; var lstm = new LSTM(5, 3, 2); var trainer = new Trainer(lstm); trainer.DSR({ targets: targets, distractors: distractors, prompts: prompts, length: length, rate: .17, iterations: 250000 }); var symbols = targets.length + distractors.length + prompts.length; var sequence = [], indexes = [], positions = []; var sequenceLength = length - prompts.length; for (i = 0; i < sequenceLength; i++) { var any = Math.random() * distractors.length | 0; sequence.push(distractors[any]); } indexes = [], positions = []; for (i = 0; i < prompts.length; i++) { indexes.push(Math.random() * targets.length | 0); positions.push(noRepeat(sequenceLength, positions)); } positions = positions.sort(); for (i = 0; i < prompts.length; i++) { sequence[positions[i]] = targets[indexes[i]]; sequence.push(prompts[i]); } var check = function (which) { // generate input from sequence var input = []; for (let j = 0; j < symbols; j++) input[j] = 0; input[sequence[which]] = 1; // generate target output var output = []; for (let j = 0; j < targets.length; j++) output[j] = 0; if (which >= sequenceLength) { var index = which - sequenceLength; output[indexes[index]] = 1; } // check result var prediction = lstm.activate(input); return { prediction: prediction, output: output }; }; var value = function (array) { var max = .5; var res = -1; for (var i in array) if (array[i] > max) { max = array[i]; res = i; } return res == -1 ? '-' : targets[res]; }; it("targets: " + targets, function () { assert(true); }); it("distractors: " + distractors, function () { assert(true); }); it("prompts: " + prompts, function () { assert(true); }); it("length: " + length + "\n", function () { assert(true); }); for (var i = 0; i < length; i++) { var test = check(i); it((i + 1) + ") input: " + sequence[i] + " output: " + value(test.prediction), function () { var ok = equal(test.prediction, test.output); assert(ok); }); } }); describe("LSTM - Timing Task", function () { var network = new LSTM(2, 7, 1); var trainer = new Trainer(network); var result = trainer.timingTask({ log: false, trainSamples: 4000, testSamples: 500 }); it("should complete the training in less than 200 iterations", function () { assert(result.train.iterations <= 200); }); it("should pass the test with an error smaller than 0.05", function () { assert(result.test.error < .05); }); }); describe("Optimized and Unoptimized Networks Equivalency", function () { var optimized; var unoptimized; beforeEach(function () { optimized = new LSTM(2, 1, 1); unoptimized = optimized.clone(); unoptimized.setOptimize(false); }); it('should produce the same output for both networks', function () { this.timeout(30000); for (var i = 0; i < 1000; i++) { var input = generateRandomArray(2); var target = generateRandomArray(1); optimized.activate(input); unoptimized.activate(input); optimized.propagate(learningRate, target); unoptimized.propagate(learningRate, target); } var mse = calculateMse(optimized.activate(input), unoptimized.activate(input)); assert.isAtMost(mse, 1e-9, 'output should be same for both networks after ' + i + ' iterations'); }); }); describe("toJSON/fromJSON Networks Equivalency", function () { var original; var imported; beforeEach(function () { original = new LSTM(10, 5, 5); imported = Network.fromJSON(original.toJSON()); }); it('should produce the same output for both networks', function () { this.timeout(30000); for (var i = 0; i < 1000; i++) { var input = generateRandomArray(10); var output1 = original.activate(input); var output2 = imported.activate(input); var target = generateRandomArray(5); // propagate networks original.propagate(learningRate, target); imported.propagate(learningRate, target); assert.isAtMost(calculateMse(output1, output2), 1e-10, 'output should be same for both networks after ' + i + ' iterations'); } }); }); describe("Cloned Networks Equivalency", function () { var original; var cloned; beforeEach(function () { original = new LSTM(10, 5, 5); cloned = Network.fromJSON(original.toJSON()); }); it('should produce the same output for both networks', function () { this.timeout(30000); for (var i = 0; i < 1000; i++) { var input = generateRandomArray(10); var output1 = original.activate(input); var output2 = cloned.activate(input); var target = generateRandomArray(5); // propagate networks original.propagate(learningRate, target); cloned.propagate(learningRate, target); assert.isAtMost(calculateMse(output1, output2), 1e-10, 'output should be same for both networks after ' + i + ' iterations'); } }); }); describe("Scheduled Tasks", function () { var perceptron = new Perceptron(2, 3, 1); var trainer = new Trainer(perceptron); it('should stop training at 3000 iterations', function () { var final_stats = trainer.XOR({ iterations: 3000, rate: 0.000001, error: 0.000001, schedule: { every: 1000, do: function (data) { return data.iterations == 20000; } } }); assert.equal(final_stats.iterations, 3000) }); it('should abort the training at 2000 iterations', function () { var final_stats = trainer.XOR({ iterations: 3000, rate: 0.000001, error: 0.000001, schedule: { every: 1000, do: function (data) { return data.iterations == 2000; } } }); assert.equal(final_stats.iterations, 2000) }); it('should work even if schedule.do() returns no value', function () { var final_stats = trainer.XOR({ iterations: 3000, rate: 0.000001, error: 0.000001, schedule: { every: 1000, do: function (data) {} } }); assert.equal(final_stats.iterations, 3000) }); }); describe("Rate Callback Check", function () { var perceptron = new Perceptron(2, 3, 1); var trainer = new Trainer(perceptron); it('should switch rate from 0.01 to 0.005 after 1000 iterations', function () { var final_stats = trainer.XOR({ iterations: 2000, rate: function (iterations, error) { return iterations < 1000 ? 0.01 : 0.005 }, error: 0.000001, schedule: { every: 1, do: function (data) { switch (data.iterations) { case 1: case 500: case 999: assert.equal(data.rate, 0.01); break; case 1000: case 1500: case 2000: assert.equal(data.rate, 0.005); break; } } } }); }); }); describe("Rate Array Check", function () { var perceptron = new Perceptron(2, 3, 1); var trainer = new Trainer(perceptron); it('should switch rate from 0.01 to 0.005 after 1000 iterations', function () { var final_stats = trainer.XOR({ iterations: 2000, rate: [0.01, 0.005], error: 0.000001, schedule: { every: 1, do: function (data) { switch (data.iterations) { case 1: case 500: case 999: assert.equal(data.rate, 0.01); break; case 1000: case 1500: case 2000: assert.equal(data.rate, 0.005); break; } } } }); }); });