/*
 * synaptic — architecture-free neural network library
 * Version: (not specified)
 * Source listing: 667 lines (571 loc), 18.1 kB, JavaScript
 */
//+ Jonas Raoni Soares Silva
//@ http://jsfromhell.com/array/shuffle [v1.0]
/**
 * Fisher–Yates shuffle, performed in place.
 *
 * Walks from the end of the array towards the front; at each step a slot is
 * picked uniformly from the not-yet-fixed prefix and swapped into the last
 * unfixed position.
 *
 * @param {Array} o - the array to shuffle; it is mutated.
 * @returns {Array} the same array instance `o`.
 */
function shuffleInplace(o) { //v1.0
  let remaining = o.length;
  while (remaining) {
    // Choose among the `remaining` unfixed slots, then fix the last of them.
    const pick = Math.floor(Math.random() * remaining);
    remaining -= 1;
    const held = o[remaining];
    o[remaining] = o[pick];
    o[pick] = held;
  }
  return o;
}
/**
 * Built-in cost functions. Each takes `(target, output)` (target vector first,
 * network output second) and returns a number; lower is better.
 */
const cost = {
  /**
   * Cross-entropy summed over all output units (Eq. 9).
   * Index iteration (not `for...in`) so extra enumerable keys on the arrays
   * or on Array.prototype can never leak into the sum.
   */
  CROSS_ENTROPY: function (target, output) {
    let crossentropy = 0;
    for (let i = 0; i < output.length; i++) {
      // +1e-15 is a tiny push away to avoid Math.log(0)
      crossentropy -= (target[i] * Math.log(output[i] + 1e-15)) +
        ((1 - target[i]) * Math.log((1 + 1e-15) - output[i]));
    }
    return crossentropy;
  },
  /** Mean squared error over the output units. */
  MSE: function (target, output) {
    let mse = 0;
    for (let i = 0; i < output.length; i++)
      mse += Math.pow(target[i] - output[i], 2);
    return mse / output.length;
  },
  /**
   * Number of units whose value, rounded at half-unit resolution
   * (Math.round(x * 2)), differs between target and output.
   */
  BINARY: function (target, output) {
    let misses = 0;
    for (let i = 0; i < output.length; i++) {
      if (Math.round(target[i] * 2) !== Math.round(output[i] * 2))
        misses++;
    }
    return misses;
  }
};
/**
 * Trains a network object. The network must expose `activate(input)`, which
 * returns an output vector, and `propagate(rate, target)`. The task helpers
 * (XOR, timingTask) also call `inputs()` / `outputs()`. `workerTrain` uses
 * `optimize()`, `optimized` and `worker(...)`.
 */
export default class Trainer {
  /** Built-in cost functions, reachable as `Trainer.cost.<NAME>`. */
  static cost = cost;

  /**
   * @param {Object} network - the network to train.
   * @param {Object} [options]
   * @param {number|Array<number>|Function} [options.rate=.2] - a fixed rate, an
   *   array of rates spread over the iteration budget, or a function
   *   `(iterations, lastError) => rate`.
   * @param {number} [options.iterations=100000] - iteration cap.
   * @param {number} [options.error=.005] - target error; training stops below it.
   * @param {Function} [options.cost] - cost function `(target, output) => number`.
   * @param {{testSize: number, testError: number}} [options.crossValidate] -
   *   hold out the last `testSize` fraction of the set and stop once its mean
   *   error is at most `testError`.
   */
  constructor(network, options) {
    options = options || {};
    this.network = network;
    this.rate = options.rate || .2;
    this.iterations = options.iterations || 100000;
    this.error = options.error || .005;
    this.cost = options.cost || null;
    this.crossValidate = options.crossValidate || null;
  }

  /**
   * Trains the network on `set`. Stops at the iteration cap, when the error
   * drops to the target, or when `schedule.do` returns a truthy value.
   *
   * The overrides iterations, error, rate, cost, schedule, customLog and
   * crossValidate are stored on the trainer, so they persist into later calls.
   *
   * @param {Array<{input: Array<number>, output: Array<number>}>} set
   * @param {Object} [options] - the constructor options, plus:
   *   `schedule` `{every, do}`, where `do(info)` is called every `every`
   *   iterations and a truthy return aborts;
   *   `customLog` (deprecated alias of `schedule`);
   *   `log` (print every N iterations);
   *   `shuffle` (reshuffle the training samples after every iteration).
   * @returns {{error: number, iterations: number, time: number}}
   */
  train(set, options) {
    let error = 1;
    let iterations = 0;
    let abort = false;
    const cost = (options && options.cost) || this.cost || Trainer.cost.MSE;
    const start = Date.now();

    if (options) {
      if (options.iterations)
        this.iterations = options.iterations;
      if (options.error)
        this.error = options.error;
      if (options.rate)
        this.rate = options.rate;
      if (options.cost)
        this.cost = options.cost;
      if (options.schedule)
        this.schedule = options.schedule;
      if (options.customLog) {
        // for backward compatibility with code that used customLog
        console.log('Deprecated: use schedule instead of customLog');
        this.schedule = options.customLog;
      }
    }

    // Cross-validation may be configured in the constructor, per call, or both.
    // Per-call testSize/testError override the stored values.
    const cvOptions = options && options.crossValidate;
    const crossValidate = Boolean(this.crossValidate || cvOptions);
    if (crossValidate) {
      if (!this.crossValidate)
        this.crossValidate = {};
      if (cvOptions && cvOptions.testSize)
        this.crossValidate.testSize = cvOptions.testSize;
      if (cvOptions && cvOptions.testError)
        this.crossValidate.testError = cvOptions.testError;
    }

    // An array of rates is spread evenly over the iteration budget, one
    // bucket per rate. Each bucket is at least one iteration long, so a rate
    // schedule longer than the budget still advances instead of handing the
    // whole array to propagate().
    let currentRate = this.rate;
    let bucketSize = 0;
    if (Array.isArray(this.rate)) {
      bucketSize = Math.max(1, Math.floor(this.iterations / this.rate.length));
      currentRate = this.rate[0];
    }

    let trainSet = set;
    let testSet = null;
    if (crossValidate) {
      const numTrain = Math.ceil((1 - this.crossValidate.testSize) * set.length);
      trainSet = set.slice(0, numTrain);
      testSet = set.slice(numTrain);
    }

    let lastError = 0;
    while (!abort && iterations < this.iterations && error > this.error) {
      if (crossValidate && error <= this.crossValidate.testError)
        break;

      iterations++;
      if (bucketSize > 0) {
        // Bucket k covers iterations k*bucketSize+1 .. (k+1)*bucketSize.
        // Iterations past the last bucket keep the previous rate.
        const currentBucket = Math.floor((iterations - 1) / bucketSize);
        currentRate = this.rate[currentBucket] || currentRate;
      }
      if (typeof this.rate === 'function')
        currentRate = this.rate(iterations, lastError);

      if (crossValidate) {
        // test() already returns the mean error over the held-out samples.
        this._trainSet(trainSet, currentRate, cost);
        error = this.test(testSet).error;
      } else {
        error = this._trainSet(set, currentRate, cost) / set.length;
      }
      lastError = error;

      if (options) {
        if (this.schedule && this.schedule.every && iterations % this.schedule.every === 0) {
          abort = this.schedule.do({ error, iterations, rate: currentRate });
        } else if (options.log && iterations % options.log === 0) {
          console.log('iterations', iterations, 'error', error, 'rate', currentRate);
        }
        // Shuffle the samples actually being trained on. Under cross-validation
        // that is the training slice, not the caller's full set.
        if (options.shuffle)
          shuffleInplace(trainSet);
      }
    }

    return {
      error,
      iterations,
      time: Date.now() - start
    };
  }

  /**
   * Trains using a WebWorker (browser only).
   * @returns {Promise<{error: number, iterations: number, time: number}>}
   *   Settles with the worker's results. It rejects only if starting the
   *   worker throws synchronously.
   */
  trainAsync(set, options) {
    return new Promise((resolve, reject) => {
      try {
        this.workerTrain(set, resolve, options, true);
      } catch (e) {
        reject(e);
      }
    });
  }

  /**
   * Runs one training epoch over `set`: activate, then propagate, for each sample.
   * Private helper for train().
   * @returns {number} the summed (not averaged) cost over the set.
   */
  _trainSet(set, currentRate, costFunction) {
    let errorSum = 0;
    for (let i = 0; i < set.length; i++) {
      const input = set[i].input;
      const target = set[i].output;
      const output = this.network.activate(input);
      this.network.propagate(currentRate, target);
      errorSum += costFunction(target, output);
    }
    return errorSum;
  }

  /**
   * Evaluates the network on `set` without training.
   * @param {Array<{input: Array<number>, output: Array<number>}>} set
   * @param {{cost?: Function}} [options]
   * @returns {{error: number, time: number}} the mean cost per sample (NaN for
   *   an empty set) and the elapsed milliseconds.
   */
  test(set, options) {
    const cost = (options && options.cost) || this.cost || Trainer.cost.MSE;
    const start = Date.now();
    let error = 0;
    for (let i = 0; i < set.length; i++) {
      const output = this.network.activate(set[i].input);
      error += cost(set[i].output, output);
    }
    error /= set.length;
    return {
      error,
      time: Date.now() - start
    };
  }

  /**
   * Trains the network in a WebWorker. Deprecated: use trainAsync instead.
   * @param {Array} set - training samples, handed to network.worker().
   * @param {Function} callback - receives `{error, iterations, time}` when done.
   * @param {Object} [options] - handed to network.worker(). `options.schedule.do`
   *   receives the worker's 'log'/'schedule' messages.
   * @param {boolean} [suppressWarning] - skip the deprecation warning.
   */
  workerTrain(set, callback, options, suppressWarning) {
    if (!suppressWarning) {
      console.warn('Deprecated: do not use `workerTrain`, use `trainAsync` instead.');
    }
    if (!this.network.optimized)
      this.network.optimize();
    // Create a new worker
    const worker = this.network.worker(this.network.optimized.memory, set, options);
    // train the worker
    worker.onmessage = (e) => {
      switch (e.data.action) {
        case 'done': {
          const { iterations, error, time } = e.data.message;
          // Hand the buffer sent back by the worker to the optimized network.
          this.network.optimized.ownership(e.data.memoryBuffer);
          // Done callback
          callback({ error, iterations, time });
          // Delete the worker and all its associated memory
          worker.terminate();
          break;
        }
        case 'log':
          console.log(e.data.message);
          // NOTE(review): no `break`, so 'log' messages also reach the schedule
          // handler below. Kept as-is; confirm whether that is intentional.
          // falls through
        case 'schedule':
          if (options && options.schedule && typeof options.schedule.do === 'function') {
            // Call as a method so `do` sees its schedule object as `this`,
            // the same way train() invokes it.
            options.schedule.do(e.data.message);
          }
          break;
      }
    };
    // Start the worker
    worker.postMessage({ action: 'startTraining' });
  }

  /**
   * Trains XOR on a 2-input, 1-output network.
   * @param {Object} [options] - overrides for train(). Defaults: 100000
   *   iterations, no log, shuffle on, MSE cost.
   * @returns the result of train().
   * @throws {Error} if the network is not 2-in/1-out.
   */
  XOR(options) {
    if (this.network.inputs() !== 2 || this.network.outputs() !== 1)
      throw new Error('Incompatible network (2 inputs, 1 output)');
    const defaults = {
      iterations: 100000,
      log: false,
      shuffle: true,
      cost: Trainer.cost.MSE
    };
    if (options)
      for (const key in options)
        defaults[key] = options[key];
    return this.train([
      { input: [0, 0], output: [0] },
      { input: [1, 0], output: [1] },
      { input: [0, 1], output: [1] },
      { input: [1, 1], output: [0] }
    ], defaults);
  }

  /**
   * Trains on the Distracted Sequence Recall task. Each trial is a sequence
   * of random distractors with targets inserted at distinct random positions,
   * followed by one prompt per target. At each prompt the network must output
   * the one-hot index of the corresponding target, taken in order of
   * appearance.
   *
   * @param {Object} [options] - targets, distractors, prompts, length (total
   *   steps per sequence, prompts included), success (criterion), iterations,
   *   rate, log, schedule, cost.
   * @returns {{iterations: number, success: number, error: number, time: number}}
   *   `success` is the fraction of fully correct trials in the current window
   *   of up to 1000 trials.
   */
  DSR(options) {
    options = options || {};
    const targets = options.targets || [2, 4, 7, 8];
    const distractors = options.distractors || [3, 5, 6, 9];
    const prompts = options.prompts || [0, 1];
    const length = options.length || 24;
    const criterion = options.success || 0.95;
    const iterations = options.iterations || 100000;
    const rate = options.rate || .1;
    const log = options.log || 0;
    const schedule = options.schedule || {};
    const cost = options.cost || this.cost || Trainer.cost.CROSS_ENTROPY;
    const symbols = targets.length + distractors.length + prompts.length;
    // The final `prompts.length` steps of every sequence are the prompts.
    const sequenceLength = length - prompts.length;

    // Random integer in [0, range) that is not already in `avoid`.
    const noRepeat = (range, avoid) => {
      const number = Math.random() * range | 0;
      return avoid.includes(number) ? noRepeat(range, avoid) : number;
    };
    // True when every rounded prediction equals the matching target value.
    const equal = (prediction, output) => {
      for (let i = 0; i < prediction.length; i++) {
        if (Math.round(prediction[i]) !== output[i])
          return false;
      }
      return true;
    };

    let trial = 0;
    let correct = 0;
    let success = 0;
    let error = 1;
    const start = Date.now();
    // Success is measured over windows of 1000 trials. The criterion may only
    // stop training at a window boundary.
    while (trial < iterations && (success < criterion || trial % 1000 !== 0)) {
      // Reset before this trial is scored, so every trial of the window counts.
      if (trial % 1000 === 0)
        correct = 0;

      // generate sequence: distractors, then targets at distinct positions
      const sequence = [];
      for (let i = 0; i < sequenceLength; i++)
        sequence.push(distractors[Math.random() * distractors.length | 0]);
      const indexes = [];
      const positions = [];
      for (let i = 0; i < prompts.length; i++) {
        indexes.push(Math.random() * targets.length | 0);
        positions.push(noRepeat(sequenceLength, positions));
      }
      // Numeric sort: the default sort compares as strings (10 before 9),
      // which would make prompts ask for targets out of order.
      positions.sort((a, b) => a - b);
      for (let i = 0; i < prompts.length; i++) {
        sequence[positions[i]] = targets[indexes[i]];
        sequence.push(prompts[i]);
      }

      // train sequence
      let distractorsCorrect = 0;
      let targetsCorrect = 0;
      error = 0;
      for (let i = 0; i < length; i++) {
        // one-hot input for the current symbol
        const input = Array.from({ length: symbols }, () => 0);
        input[sequence[i]] = 1;
        // expected output: all zeros, except at a prompt step, where it is the
        // one-hot index of the target that prompt asks for
        const output = Array.from({ length: targets.length }, () => 0);
        if (i >= sequenceLength)
          output[indexes[i - sequenceLength]] = 1;
        const prediction = this.network.activate(input);
        if (equal(prediction, output)) {
          if (i < sequenceLength)
            distractorsCorrect++;
          else
            targetsCorrect++;
        } else {
          this.network.propagate(rate, output);
        }
        error += cost(output, prediction);
      }
      if (distractorsCorrect + targetsCorrect === length)
        correct++;

      trial++;
      success = correct / (trial % 1000 || 1000);
      error /= length;

      // log
      if (log && trial % log === 0)
        console.log('iterations:', trial, ' success:', success, ' correct:',
          correct, ' time:', Date.now() - start, ' error:', error);
      if (schedule.do && schedule.every && trial % schedule.every === 0)
        schedule.do({
          iterations: trial,
          success,
          error,
          time: Date.now() - start,
          correct
        });
    }
    return {
      iterations: trial,
      success,
      error,
      time: Date.now() - start
    };
  }

  /**
   * Trains next-symbol prediction on strings generated from an Embedded Reber
   * Grammar. Symbols are one-hot encoded over B, P, T, X, S, E.
   *
   * @param {Object} [options] - iterations, error (criterion), rate, log, cost.
   * @returns {{iterations: number, error: number, time: number,
   *   test: function(string): boolean, generate: function(): string}}
   *   `test` checks whether a string is a complete grammar string, and
   *   `generate` produces a random one.
   */
  ERG(options) {
    options = options || {};
    const iterations = options.iterations || 150000;
    const criterion = options.error || .05;
    const rate = options.rate || .1;
    const log = options.log || 500;
    const cost = options.cost || this.cost || Trainer.cost.CROSS_ENTROPY;

    // Grammar node: each path is an outgoing edge labelled with the symbol it emits.
    class Node {
      constructor() {
        this.paths = [];
      }

      // Adds an edge to `node` labelled `value`; chainable.
      connect(node, value) {
        this.paths.push({ node, value });
        return this;
      }

      // A uniformly random outgoing edge, or false at a terminal node.
      any() {
        if (this.paths.length === 0)
          return false;
        return this.paths[Math.random() * this.paths.length | 0];
      }

      // The first edge labelled `value`, or false if there is none.
      test(value) {
        return this.paths.find((path) => path.value === value) || false;
      }
    }

    // build a reber grammar
    const reberGrammar = () => {
      const output = new Node();
      const n1 = new Node().connect(output, 'E');
      const n2 = new Node().connect(n1, 'S');
      const n3 = new Node().connect(n1, 'V').connect(n2, 'P');
      const n4 = new Node().connect(n2, 'X');
      n4.connect(n4, 'S');
      const n5 = new Node().connect(n3, 'V');
      n5.connect(n5, 'T');
      n2.connect(n5, 'X');
      const n6 = new Node().connect(n4, 'T').connect(n5, 'P');
      const input = new Node().connect(n6, 'B');
      return { input, output };
    };

    // build an embedded reber grammar
    const embededReberGrammar = () => {
      const reber1 = reberGrammar();
      const reber2 = reberGrammar();
      const output = new Node();
      const n1 = new Node().connect(output, 'E');
      reber1.output.connect(n1, 'T');
      reber2.output.connect(n1, 'P');
      const n2 = new Node().connect(reber1.input, 'P').connect(reber2.input, 'T');
      const input = new Node().connect(n2, 'B');
      return { input, output };
    };

    // generate an ERG sequence by a random walk from the start node to the end
    const generate = () => {
      let next = embededReberGrammar().input.any();
      let str = '';
      while (next) {
        str += next.value;
        next = next.node.any();
      }
      return str;
    };

    // True only if `str` is a complete embedded Reber string. Every symbol
    // must follow a grammar edge, and the walk must end on the terminal node
    // (the only node without outgoing edges). A bare valid prefix is rejected.
    const test = (str) => {
      let node = embededReberGrammar().input;
      for (let i = 0; i < str.length; i++) {
        const next = node.test(str.charAt(i));
        if (!next)
          return false;
        node = next.node;
      }
      return node.paths.length === 0;
    };

    // True when the two vectors have their (first, strictly positive) maximum
    // at different indices; -1 stands for "no positive entry".
    const different = (array1, array2) => {
      let max1 = 0;
      let i1 = -1;
      let max2 = 0;
      let i2 = -1;
      for (let i = 0; i < array1.length; i++) {
        if (array1[i] > max1) {
          max1 = array1[i];
          i1 = i;
        }
        if (array2[i] > max2) {
          max2 = array2[i];
          i2 = i;
        }
      }
      return i1 !== i2;
    };

    const table = { B: 0, P: 1, T: 2, X: 3, S: 4, E: 5 };
    let iteration = 0;
    let error = 1;
    const start = Date.now();
    while (iteration < iterations && error > criterion) {
      error = 0;
      // ERG sequence to learn: feed each symbol, train on the symbol after it.
      const sequence = generate();
      for (let i = 0; i < sequence.length - 1; i++) {
        const input = [0, 0, 0, 0, 0, 0];
        const target = [0, 0, 0, 0, 0, 0];
        input[table[sequence.charAt(i)]] = 1;
        target[table[sequence.charAt(i + 1)]] = 1;
        const output = this.network.activate(input);
        if (different(output, target))
          this.network.propagate(rate, target);
        error += cost(target, output);
      }
      error /= sequence.length;
      iteration++;
      if (iteration % log === 0) {
        console.log('iterations:', iteration, ' time:', Date.now() - start,
          ' error:', error);
      }
    }
    return {
      iterations: iteration,
      error,
      time: Date.now() - start,
      test,
      generate
    };
  }

  /**
   * Trains on a generated timing task (2 inputs, 1 output), then evaluates on
   * a separately generated tail of samples.
   *
   * @param {Object} [options] - iterations (200), error (.005), rate
   *   ([.03, .02]), log (10; false disables), cost, trainSamples (7000),
   *   testSamples (1000).
   * @returns {{train: Object, test: Object}} the results of train() and test().
   * @throws {Error} if the network is not 2-in/1-out.
   */
  timingTask(options) {
    if (this.network.inputs() !== 2 || this.network.outputs() !== 1)
      throw new Error('Invalid Network: must have 2 inputs and one output');
    options = options || {};

    // Builds trainingSize + testSize samples and splits them, in order, into
    // a training set followed by a test set.
    function getSamples(trainingSize, testSize) {
      const size = trainingSize + testSize;
      const set = [];
      for (let i = 0; i < size; i++)
        set.push({ input: [0, 0], output: [0] });

      let t = 0;
      while (t < size - 20) {
        let n = Math.round(Math.random() * 20);
        set[t].input[0] = 1;
        for (let j = t; j <= t + n; j++) {
          set[j].input[1] = n / 20;
          set[j].output[0] = 0.5;
        }
        t += n;
        n = Math.round(Math.random() * 20);
        for (let k = t + 1; k <= (t + n) && k < size; k++)
          set[k].input[1] = set[t].input[1];
        t += n;
      }

      return {
        train: set.slice(0, trainingSize),
        test: set.slice(trainingSize)
      };
    }

    const iterations = options.iterations || 200;
    const error = options.error || .005;
    const rate = options.rate || [.03, .02];
    const log = options.log === false ? false : options.log || 10;
    const cost = options.cost || this.cost || Trainer.cost.MSE;
    const trainingSamples = options.trainSamples || 7000;
    const testSamples = options.testSamples || 1000;

    const samples = getSamples(trainingSamples, testSamples);
    const result = this.train(samples.train, {
      rate,
      log,
      iterations,
      error,
      cost
    });
    return {
      train: result,
      test: this.test(samples.test)
    };
  }
}