UNPKG

synaptic

Version:

architecture-free neural network library

619 lines (526 loc) 20.6 kB
import Neuron from './Neuron';
import Layer from './Layer';
import Trainer from './Trainer';

/**
 * A Network groups an input layer, an ordered list of hidden layers and an
 * output layer. It runs them either layer-by-layer or through a compiled
 * ("optimized") function generated from its neurons.
 *
 * `this.optimized` encodes the execution mode:
 *   - `false`           optimization disabled: activate/propagate walk the layers.
 *   - `null`/undefined  optimization enabled but not compiled yet: the next
 *                       activate/propagate call runs optimize().
 *   - an object         the compiled network built by optimize().
 */
export default class Network {
  // `layers` is an optional { input, hidden, output } object. Missing entries
  // default to null (input/output) or [] (hidden). Layers can also be supplied
  // later through set().
  constructor(layers) {
    if (typeof layers !== 'undefined') {
      this.layers = {
        input: layers.input || null,
        hidden: layers.hidden || [],
        output: layers.output || null
      };
      this.optimized = null;
    }
  }

  // Feed-forward activation of all the layers to produce an output.
  // Returns the output layer's activation, or the compiled network's output.
  activate(input) {
    if (this.optimized === false) {
      this.layers.input.activate(input);
      for (var i = 0; i < this.layers.hidden.length; i++)
        this.layers.hidden[i].activate();
      return this.layers.output.activate();
    }

    // Compile lazily on first use while optimization is enabled.
    if (this.optimized == null)
      this.optimize();
    return this.optimized.activate(input);
  }

  // Back-propagate the error through the network. The output layer goes
  // first, then the hidden layers in reverse order. `rate` and `target` are
  // forwarded to the layers (or to the compiled network).
  propagate(rate, target) {
    if (this.optimized === false) {
      this.layers.output.propagate(rate, target);
      for (var i = this.layers.hidden.length - 1; i >= 0; i--)
        this.layers.hidden[i].propagate(rate);
      return;
    }

    if (this.optimized == null)
      this.optimize();
    this.optimized.propagate(rate, target);
  }

  // Project a connection from this network's output layer to another unit.
  // The unit is either a Network (its input layer is the target) or a Layer.
  // Any compiled network is reset first because the topology changes.
  // Throws an Error if `unit` is neither a Network nor a Layer.
  project(unit, type, weights) {
    if (this.optimized)
      this.optimized.reset();

    if (unit instanceof Network)
      return this.layers.output.project(unit.layers.input, type, weights);

    if (unit instanceof Layer)
      return this.layers.output.project(unit, type, weights);

    throw new Error('Invalid argument, you can only project connections to LAYERS and NETWORKS!');
  }

  // Let this network's output layer gate `connection`. Any compiled network
  // is reset first.
  gate(connection, type) {
    if (this.optimized)
      this.optimized.reset();
    this.layers.output.gate(connection, type);
  }

  // Clear all eligibility traces and extended eligibility traces. The network
  // forgets its context, but not what was trained. Values held by a compiled
  // network are copied back into the neurons first (restore()).
  clear() {
    this.restore();

    this.layers.input.clear();
    for (var i = 0; i < this.layers.hidden.length; i++)
      this.layers.hidden[i].clear();
    this.layers.output.clear();

    if (this.optimized)
      this.optimized.reset();
  }

  // Reset all weights and clear all traces, so the network ends up like a
  // new one.
  reset() {
    this.restore();

    this.layers.input.reset();
    for (var i = 0; i < this.layers.hidden.length; i++)
      this.layers.hidden[i].reset();
    this.layers.output.reset();

    if (this.optimized)
      this.optimized.reset();
  }

  // Hardcode the behaviour of the whole network into a single optimized
  // function. Each neuron contributes generated source "sentences"
  // (activation, trace and propagation code) plus memory variables. These are
  // stitched together and compiled with `new Function`. The generated code
  // comes from the neurons, not from user input.
  // Afterwards this.activate / this.propagate point at the compiled versions.
  // Calling `this.optimized.reset()` switches back to the layer-walking
  // methods.
  optimize() {
    var that = this;
    var optimized = {};
    var neurons = this.neurons();

    for (var i = 0; i < neurons.length; i++) {
      var neuron = neurons[i].neuron;
      var layer = neurons[i].layer;
      // Unwrap neurons that delegate to an inner `.neuron`.
      while (neuron.neuron)
        neuron = neuron.neuron;
      optimized = neuron.optimize(optimized, layer);
    }

    // Reverse each group of propagation sentences and the list of groups.
    // neurons() lists input -> hidden -> output, so this emits the
    // propagation code in the opposite order (presumably output-first for
    // backprop; confirm against Neuron.optimize).
    for (var i = 0; i < optimized.propagation_sentences.length; i++)
      optimized.propagation_sentences[i].reverse();
    optimized.propagation_sentences.reverse();

    var hardcode = '';
    // `typeof` guard: a bare `Float64Array ? ...` would throw a
    // ReferenceError where the global does not exist.
    hardcode += 'var F = typeof Float64Array !== "undefined" ? new Float64Array(' + optimized.memory + ') : []; ';
    for (var i in optimized.variables)
      hardcode += 'F[' + optimized.variables[i].id + '] = ' + (optimized.variables[i].value || 0) + '; ';

    // activate(input): load inputs, run activation + trace sentences, collect outputs.
    hardcode += 'var activate = function(input){\n';
    for (var i = 0; i < optimized.inputs.length; i++)
      hardcode += 'F[' + optimized.inputs[i] + '] = input[' + i + ']; ';
    for (var i = 0; i < optimized.activation_sentences.length; i++) {
      if (optimized.activation_sentences[i].length > 0) {
        for (var j = 0; j < optimized.activation_sentences[i].length; j++) {
          hardcode += optimized.activation_sentences[i][j].join(' ');
          hardcode += optimized.trace_sentences[i][j].join(' ');
        }
      }
    }
    hardcode += ' var output = []; ';
    for (var i = 0; i < optimized.outputs.length; i++)
      hardcode += 'output[' + i + '] = F[' + optimized.outputs[i] + ']; ';
    hardcode += 'return output; }; ';

    // propagate(rate, target): store rate and targets, run propagation sentences.
    hardcode += 'var propagate = function(rate, target){\n';
    hardcode += 'F[' + optimized.variables.rate.id + '] = rate; ';
    for (var i = 0; i < optimized.targets.length; i++)
      hardcode += 'F[' + optimized.targets[i] + '] = target[' + i + ']; ';
    for (var i = 0; i < optimized.propagation_sentences.length; i++)
      for (var j = 0; j < optimized.propagation_sentences[i].length; j++)
        hardcode += optimized.propagation_sentences[i][j].join(' ') + ' ';
    hardcode += ' };\n';

    // ownership(memoryBuffer): swap the backing memory (used by workers).
    hardcode += 'var ownership = function(memoryBuffer){\nF = memoryBuffer;\nthis.memory = F;\n};\n';
    hardcode += 'return {\nmemory: F,\nactivate: activate,\npropagate: propagate,\nownership: ownership\n};';
    hardcode = hardcode.split(';').join(';\n');

    var constructor = new Function(hardcode);

    var network = constructor();
    network.data = {
      variables: optimized.variables,
      activate: optimized.activation_sentences,
      propagate: optimized.propagation_sentences,
      trace: optimized.trace_sentences,
      inputs: optimized.inputs,
      outputs: optimized.outputs,
      check_activation: this.activate,
      check_propagation: this.propagate
    };

    // Drop the compiled network and reinstate the original activate/propagate.
    network.reset = function () {
      if (that.optimized) {
        that.optimized = null;
        that.activate = network.data.check_activation;
        that.propagate = network.data.check_propagation;
      }
    };

    this.optimized = network;
    this.activate = network.activate;
    this.propagate = network.propagate;
  }

  // Restore all the values from the optimized network to their respective
  // objects (neurons and projected connections), so the network can be
  // manipulated. No-op when there is no compiled network.
  restore() {
    if (!this.optimized)
      return;

    var optimized = this.optimized;

    // getValue(unit, ...middle, prop) reads the compiled variable whose key is
    // `<prop>_<middle...>_<unit.ID>`. It returns 0 when that key is not tracked.
    var getValue = function () {
      var args = Array.prototype.slice.call(arguments);

      var unit = args.shift();
      var prop = args.pop();

      var id = prop + '_';
      for (var k = 0; k < args.length; k++)
        id += args[k] + '_';
      id += unit.ID;

      var memory = optimized.memory;
      var variables = optimized.data.variables;

      if (id in variables)
        return memory[variables[id].id];
      return 0;
    };

    var list = this.neurons();

    for (var i = 0; i < list.length; i++) {
      var neuron = list[i].neuron;
      while (neuron.neuron)
        neuron = neuron.neuron;

      neuron.state = getValue(neuron, 'state');
      neuron.old = getValue(neuron, 'old');
      neuron.activation = getValue(neuron, 'activation');
      neuron.bias = getValue(neuron, 'bias');

      for (var input in neuron.trace.elegibility)
        neuron.trace.elegibility[input] = getValue(neuron, 'trace', 'elegibility', input);

      for (var gated in neuron.trace.extended)
        for (var input in neuron.trace.extended[gated])
          neuron.trace.extended[gated][input] = getValue(neuron, 'trace', 'extended', gated, input);

      // get connections
      // NOTE(review): only `connections.projected` is restored. If the compiled
      // network tracks self-connection weight/gain, it is not copied back here.
      // Confirm against Neuron.optimize.
      for (var j in neuron.connections.projected) {
        var connection = neuron.connections.projected[j];
        connection.weight = getValue(connection, 'weight');
        connection.gain = getValue(connection, 'gain');
      }
    }
  }

  // Return all the neurons in the network as { neuron, layer } records.
  // Order: input layer ('input'), each hidden layer (tagged with its index),
  // then the output layer ('output').
  neurons() {
    var neurons = [];

    var inputLayer = this.layers.input.neurons(),
      outputLayer = this.layers.output.neurons();

    for (var i = 0; i < inputLayer.length; i++) {
      neurons.push({
        neuron: inputLayer[i],
        layer: 'input'
      });
    }

    for (var i = 0; i < this.layers.hidden.length; i++) {
      var hiddenLayer = this.layers.hidden[i].neurons();
      for (var j = 0; j < hiddenLayer.length; j++)
        neurons.push({
          neuron: hiddenLayer[j],
          layer: i
        });
    }

    for (var i = 0; i < outputLayer.length; i++) {
      neurons.push({
        neuron: outputLayer[i],
        layer: 'output'
      });
    }

    return neurons;
  }

  // Return the number of inputs of the network (the input layer's size).
  inputs() {
    return this.layers.input.size;
  }

  // Return the number of outputs of the network (the output layer's size).
  outputs() {
    return this.layers.output.size;
  }

  // Set the layers of the network and reset any compiled network.
  set(layers) {
    this.layers = {
      input: layers.input || null,
      hidden: layers.hidden || [],
      output: layers.output || null
    };
    if (this.optimized)
      this.optimized.reset();
  }

  // Enable (true) or disable (false) optimization. Compiled values are
  // restored into the neurons and the compiled network is discarded. When
  // enabled, recompilation happens lazily on the next activate/propagate.
  setOptimize(bool) {
    this.restore();
    if (this.optimized)
      this.optimized.reset();
    this.optimized = bool ? null : false;
  }

  // Return a JSON-serialisable { neurons, connections } object representing
  // the network. Connections refer to neurons by their position in `neurons`.
  // Traces are exported as empty objects; `ignoreTraces` is currently unused.
  toJSON(ignoreTraces) {
    this.restore();

    var list = this.neurons();
    var neurons = [];
    var connections = [];

    // link id's to positions in the array
    var ids = {};
    for (var i = 0; i < list.length; i++) {
      var neuron = list[i].neuron;
      while (neuron.neuron)
        neuron = neuron.neuron;
      ids[neuron.ID] = i;

      var copy = {
        trace: {
          elegibility: {},
          extended: {}
        },
        state: neuron.state,
        old: neuron.old,
        activation: neuron.activation,
        bias: neuron.bias,
        layer: list[i].layer
      };

      copy.squash = neuron.squash === Neuron.squash.LOGISTIC ? 'LOGISTIC' :
        neuron.squash === Neuron.squash.TANH ? 'TANH' :
          neuron.squash === Neuron.squash.IDENTITY ? 'IDENTITY' :
            neuron.squash === Neuron.squash.HLIM ? 'HLIM' :
              neuron.squash === Neuron.squash.RELU ? 'RELU' :
                null;

      neurons.push(copy);
    }

    for (var i = 0; i < list.length; i++) {
      var neuron = list[i].neuron;
      while (neuron.neuron)
        neuron = neuron.neuron;

      for (var j in neuron.connections.projected) {
        var connection = neuron.connections.projected[j];
        connections.push({
          from: ids[connection.from.ID],
          to: ids[connection.to.ID],
          weight: connection.weight,
          gater: connection.gater ? ids[connection.gater.ID] : null,
        });
      }
      if (neuron.selfconnected()) {
        connections.push({
          from: ids[neuron.ID],
          to: ids[neuron.ID],
          weight: neuron.selfconnection.weight,
          gater: neuron.selfconnection.gater ? ids[neuron.selfconnection.gater.ID] : null,
        });
      }
    }

    return {
      neurons: neurons,
      connections: connections
    };
  }

  // Export the layer topology in the DOT language so it can be rendered with
  // graphviz. Returns { code, link }, where `link` is a chart URL for `code`.
  // When `edgeConnection` is truthy, gated projections are routed through an
  // invisible point node so the gating edge has something to attach to.
  /* example:
   * ...
   * console.log(net.toDot().code);
   * $ node example.js > example.dot
   * $ dot example.dot -Tpng > out.png
   */
  toDot(edgeConnection) {
    if (typeof edgeConnection === 'undefined')
      edgeConnection = false;
    var code = 'digraph nn {\n rankdir = BT\n';
    var layers = [this.layers.input].concat(this.layers.hidden, this.layers.output);
    for (var i = 0; i < layers.length; i++) {
      for (var j = 0; j < layers[i].connectedTo.length; j++) { // projections
        var connection = layers[i].connectedTo[j];
        var layerTo = connection.to;
        var size = connection.size;
        var layerID = layers.indexOf(layers[i]);
        var layerToID = layers.indexOf(layerTo);
        /* http://stackoverflow.com/questions/26845540/connect-edges-with-graph-dot
         * DOT does not support edge-to-edge connections
         * This workaround produces somewhat weird graphs ...
         */
        if (edgeConnection) {
          if (connection.gatedfrom.length) {
            var fakeNode = 'fake' + layerID + '_' + layerToID;
            code += ' ' + fakeNode + ' [label = "", shape = point, width = 0.01, height = 0.01]\n';
            code += ' ' + layerID + ' -> ' + fakeNode + ' [label = ' + size + ', arrowhead = none]\n';
            code += ' ' + fakeNode + ' -> ' + layerToID + '\n';
          } else
            code += ' ' + layerID + ' -> ' + layerToID + ' [label = ' + size + ']\n';
          for (var g = 0; g < connection.gatedfrom.length; g++) { // gatings
            var layerfrom = connection.gatedfrom[g].layer;
            var layerfromID = layers.indexOf(layerfrom);
            code += ' ' + layerfromID + ' -> ' + fakeNode + ' [color = blue]\n';
          }
        } else {
          code += ' ' + layerID + ' -> ' + layerToID + ' [label = ' + size + ']\n';
          for (var g = 0; g < connection.gatedfrom.length; g++) { // gatings
            var layerfrom = connection.gatedfrom[g].layer;
            var layerfromID = layers.indexOf(layerfrom);
            code += ' ' + layerfromID + ' -> ' + layerToID + ' [color = blue]\n';
          }
        }
      }
    }
    code += '}\n';
    return {
      code: code,
      // Spaces become '+' (a regex, not the literal string '/ /g') before escaping.
      link: 'https://chart.googleapis.com/chart?chl=' + escape(code.replace(/ /g, '+')) + '&cht=gv'
    };
  }

  // Return a function that works as the activation of the network and can be
  // used without depending on the library. Only the activation sentences are
  // included; trace and propagation code are not.
  // The returned function starts every call from a snapshot of the compiled
  // memory taken when standalone() runs.
  standalone() {
    if (!this.optimized)
      this.optimize();

    var data = this.optimized.data;

    // build activation function
    var activation = 'function (input) {\n';

    // build inputs: copy each input value into its memory slot
    for (var i = 0; i < data.inputs.length; i++)
      activation += 'F[' + data.inputs[i] + '] = input[' + i + '];\n';

    // build network activation
    for (var i = 0; i < data.activate.length; i++) { // shouldn't this be layer?
      for (var j = 0; j < data.activate[i].length; j++)
        activation += data.activate[i][j].join('') + '\n';
    }

    // build outputs
    activation += 'var output = [];\n';
    for (var i = 0; i < data.outputs.length; i++)
      activation += 'output[' + i + '] = F[' + data.outputs[i] + '];\n';
    activation += 'return output;\n}';

    // Collect every memory position referenced by the activation code and
    // renumber them densely (0, 1, 2, ...). `match` returns null if none.
    var memory = activation.match(/F\[(\d+)\]/g) || [];
    var dimension = 0;
    var ids = {};
    for (var i = 0; i < memory.length; i++) {
      var tmp = memory[i].match(/\d+/)[0];
      if (!(tmp in ids))
        ids[tmp] = dimension++;
    }

    // Snapshot the referenced slots. `F` is declared with `var` so it stays
    // local to `run` instead of leaking into (or clobbering) a global `F`.
    var entries = [];
    for (var key in ids)
      entries.push(ids[key] + ': ' + this.optimized.memory[key]);
    var memoryInit = 'var F = {\n' + entries.join(',\n') + '\n};\n';

    var body = activation.replace(/F\[(\d+)]/g, function (index) {
      return 'F[' + ids[index.match(/\d+/)[0]] + ']';
    });
    // Inject the memory initialisation right after the opening brace of `run`.
    body = body.replace('{\n', function () {
      return '{\n' + memoryInit;
    });

    var hardcode = 'var run = ' + body + ';\n';
    hardcode += 'return run';

    // return standalone function
    return new Function(hardcode)();
  }

  // Return an HTML5 WebWorker specialized on training the network.
  // Training uses the given data `set` and `options`, and the worker posts
  // the updated memory back when done. The worker is seeded from
  // this.optimized.memory; the `memory` argument is currently unused.
  // `options` is optional and is never mutated.
  worker(memory, set, options) {
    // Copy the options and set defaults (options might be different for each worker)
    var userOptions = options || {};
    var workerOptions = Object.assign({}, userOptions);
    workerOptions.rate = userOptions.rate || .2;
    workerOptions.iterations = userOptions.iterations || 100000;
    workerOptions.error = userOptions.error || .005;
    workerOptions.cost = userOptions.cost || null;
    workerOptions.crossValidate = userOptions.crossValidate || null;

    // Cost function might be different for each worker. Declared locally:
    // an undeclared assignment throws in strict (module) code.
    var costFunction = 'var cost = ' + (userOptions.cost || this.cost || Trainer.cost.MSE) + ';\n';
    var workerFunction = Network.getWorkerSharedFunctions();
    workerFunction = workerFunction.replace(/var cost = options && options\.cost \|\| this\.cost \|\| Trainer\.cost\.MSE;/g, costFunction);

    // Set what we do when training is finished
    workerFunction = workerFunction.replace('return results;',
      'postMessage({action: "done", message: results, memoryBuffer: F}, [F.buffer]);');

    // Replace log with postmessage
    workerFunction = workerFunction.replace('console.log(\'iterations\', iterations, \'error\', error, \'rate\', currentRate)',
      'postMessage({action: \'log\', message: {\n' +
      'iterations: iterations,\n' +
      'error: error,\n' +
      'rate: currentRate\n' +
      '}\n' +
      '})');

    // Replace schedule with postmessage
    workerFunction = workerFunction.replace('abort = this.schedule.do({ error: error, iterations: iterations, rate: currentRate })',
      'postMessage({action: \'schedule\', message: {\n' +
      'iterations: iterations,\n' +
      'error: error,\n' +
      'rate: currentRate\n' +
      '}\n' +
      '})');

    if (!this.optimized)
      this.optimize();

    var hardcode = 'var inputs = ' + this.optimized.data.inputs.length + ';\n';
    hardcode += 'var outputs = ' + this.optimized.data.outputs.length + ';\n';
    hardcode += 'var F = new Float64Array([' + this.optimized.memory.toString() + ']);\n';
    hardcode += 'var activate = ' + this.optimized.activate.toString() + ';\n';
    hardcode += 'var propagate = ' + this.optimized.propagate.toString() + ';\n';
    hardcode +=
      'onmessage = function(e) {\n' +
      'if (e.data.action == \'startTraining\') {\n' +
      'train(' + JSON.stringify(set) + ',' + JSON.stringify(workerOptions) + ');\n' +
      '}\n' +
      '}';

    var workerSourceCode = workerFunction + '\n' + hardcode;
    var blob = new Blob([workerSourceCode]);
    var blobURL = window.URL.createObjectURL(blob);

    return new Worker(blobURL);
  }

  // Return a copy of the network via a toJSON()/fromJSON() round trip.
  clone() {
    return Network.fromJSON(this.toJSON());
  }

  /**
   * Creates a static String to store the source code of the functions
   * that are identical for all the workers (train, _trainSet, test).
   * The result is cached in Network._SHARED_WORKER_FUNCTIONS.
   *
   * @return {String} Source code that can train a network inside a worker.
   * @static
   */
  static getWorkerSharedFunctions() {
    // If we already computed the source code for the shared functions
    if (typeof Network._SHARED_WORKER_FUNCTIONS !== 'undefined')
      return Network._SHARED_WORKER_FUNCTIONS;

    // Otherwise compute and return the source code.
    // We compute it by copying the source code of the train, _trainSet and
    // test functions using the .toString() method.

    // Load and name the train function
    var train_f = Trainer.prototype.train.toString();
    train_f = train_f.replace('function (set', 'function train(set') + '\n';

    // Load and name the _trainSet function
    var _trainSet_f = Trainer.prototype._trainSet.toString().replace(/this.network./g, '');
    _trainSet_f = _trainSet_f.replace('function (set', 'function _trainSet(set') + '\n';
    _trainSet_f = _trainSet_f.replace('this.crossValidate', 'crossValidate');
    _trainSet_f = _trainSet_f.replace('crossValidate = true', 'crossValidate = { }');

    // Load and name the test function
    var test_f = Trainer.prototype.test.toString().replace(/this.network./g, '');
    test_f = test_f.replace('function (set', 'function test(set') + '\n';

    return Network._SHARED_WORKER_FUNCTIONS = train_f + _trainSet_f + test_f;
  }

  // Rebuild a Network from the { neurons, connections } object produced by
  // toJSON(). Unknown squash names fall back to LOGISTIC. Hidden neurons are
  // grouped into layers by their numeric `layer` tag.
  static fromJSON(json) {
    var neurons = [];

    var layers = {
      input: new Layer(),
      hidden: [],
      output: new Layer()
    };

    for (var i = 0; i < json.neurons.length; i++) {
      var config = json.neurons[i];

      var neuron = new Neuron();
      neuron.trace.elegibility = {};
      neuron.trace.extended = {};
      neuron.state = config.state;
      neuron.old = config.old;
      neuron.activation = config.activation;
      neuron.bias = config.bias;
      neuron.squash = config.squash in Neuron.squash ? Neuron.squash[config.squash] : Neuron.squash.LOGISTIC;
      neurons.push(neuron);

      if (config.layer === 'input')
        layers.input.add(neuron);
      else if (config.layer === 'output')
        layers.output.add(neuron);
      else {
        if (typeof layers.hidden[config.layer] === 'undefined')
          layers.hidden[config.layer] = new Layer();
        layers.hidden[config.layer].add(neuron);
      }
    }

    for (var i = 0; i < json.connections.length; i++) {
      var config = json.connections[i];
      var from = neurons[config.from];
      var to = neurons[config.to];
      var weight = config.weight;
      // `gater` is null in the JSON when ungated; neurons[null] is undefined.
      var gater = neurons[config.gater];

      var connection = from.project(to, weight);
      if (gater)
        gater.gate(connection);
    }

    return new Network(layers);
  }
}