UNPKG

neataptic

Version:

Architecture-free neural network library with genetic algorithm implementations

1,342 lines (1,131 loc) 42.4 kB
/* Export */
module.exports = Network;

/* Import */
var multi = require('../multithreading/multi');
var methods = require('../methods/methods');
var Connection = require('./connection');
var config = require('../config');
var Neat = require('../neat');
var Node = require('./node');

/* Easier variable naming */
var mutation = methods.mutation;

/*******************************************************************************
                                  HELPERS
*******************************************************************************/

/**
 * Shuffles an array in place (Fisher-Yates).
 *
 * @param {Array} array - array to shuffle; it is mutated
 * @returns {Array} the same array
 */
function shuffleInPlace (array) {
  for (var i = array.length - 1; i > 0; i--) {
    var j = Math.floor(Math.random() * (i + 1));
    var tmp = array[i];
    array[i] = array[j];
    array[j] = tmp;
  }
  return array;
}

/**
 * Structural size used as the growth penalty during evolution: number of
 * non-input/non-output nodes + number of connections + number of gates
 * (self-connections are not counted).
 *
 * @param {Network} genome
 * @returns {number}
 */
function genomeComplexity (genome) {
  return genome.nodes.length - genome.input - genome.output +
    genome.connections.length + genome.gates.length;
}

/**
 * Collects the connection genes (normal connections first, then
 * self-connections) of a network, keyed by innovation ID.
 * Requires `node.index` to have been set on every node of the network.
 *
 * @param {Network} network
 * @returns {Object<string, {weight: number, from: number, to: number, gater: number}>}
 *   gene data; `gater` is -1 for ungated connections
 */
function connectionGenes (network) {
  var genes = {};
  var all = network.connections.concat(network.selfconns);

  for (var i = 0; i < all.length; i++) {
    let conn = all[i];
    let data = {
      weight: conn.weight,
      from: conn.from.index,
      to: conn.to.index,
      gater: conn.gater != null ? conn.gater.index : -1
    };
    genes[Connection.innovationID(data.from, data.to)] = data;
  }

  return genes;
}

/*******************************************************************************
                                  NETWORK
*******************************************************************************/

/**
 * Creates a network whose input nodes are directly connected to every output
 * node.
 *
 * @param {number} input - number of input nodes
 * @param {number} output - number of output nodes
 * @throws {Error} when either size is not given
 */
function Network (input, output) {
  if (typeof input === 'undefined' || typeof output === 'undefined') {
    throw new Error('No input or output size given');
  }

  this.input = input;
  this.output = output;

  // Store all the node and connection genes
  this.nodes = []; // Stored in activation order
  this.connections = [];
  this.gates = [];
  this.selfconns = [];

  // Regularization
  this.dropout = 0;

  // Create input and output nodes
  var i;
  for (i = 0; i < this.input + this.output; i++) {
    var type = i < this.input ? 'input' : 'output';
    this.nodes.push(new Node(type));
  }

  // Connect input nodes with output nodes directly
  for (i = 0; i < this.input; i++) {
    for (var j = this.input; j < this.output + this.input; j++) {
      // https://stats.stackexchange.com/a/248040/147931
      var weight = Math.random() * this.input * Math.sqrt(2 / this.input);
      this.connect(this.nodes[i], this.nodes[j], weight);
    }
  }
}

Network.prototype = {
  /**
   * Activates the network: input nodes receive `input`, every other node is
   * activated in stored (activation) order.
   *
   * @param {number[]} input - one value per input node
   * @param {boolean} [training] - when truthy, every non-input/non-output node
   *   gets a new random dropout mask (0 with probability `this.dropout`, else 1)
   * @returns {number[]} activations of the output nodes, in node order
   */
  activate: function (input, training) {
    var output = [];

    // Activate nodes chronologically
    for (var i = 0; i < this.nodes.length; i++) {
      if (this.nodes[i].type === 'input') {
        this.nodes[i].activate(input[i]);
      } else if (this.nodes[i].type === 'output') {
        var activation = this.nodes[i].activate();
        output.push(activation);
      } else {
        if (training) this.nodes[i].mask = Math.random() < this.dropout ? 0 : 1;
        this.nodes[i].activate();
      }
    }

    return output;
  },

  /**
   * Activates the network without calculating eligibility traces and such
   * (delegates to each node's `noTraceActivate`).
   *
   * @param {number[]} input - one value per input node
   * @returns {number[]} activations of the output nodes, in node order
   */
  noTraceActivate: function (input) {
    var output = [];

    // Activate nodes chronologically
    for (var i = 0; i < this.nodes.length; i++) {
      if (this.nodes[i].type === 'input') {
        this.nodes[i].noTraceActivate(input[i]);
      } else if (this.nodes[i].type === 'output') {
        var activation = this.nodes[i].noTraceActivate();
        output.push(activation);
      } else {
        this.nodes[i].noTraceActivate();
      }
    }

    return output;
  },

  /**
   * Backpropagates `target` through the network: output nodes first (each
   * receives its target value), then the remaining non-input nodes in reverse
   * activation order. Input nodes are skipped.
   *
   * @param {number} rate - learning rate
   * @param {number} momentum
   * @param {boolean} update - forwarded to Node#propagate; `_trainSet` sets it
   *   at batch boundaries
   * @param {number[]} target - expected output, one value per output node
   * @throws {Error} when target is missing or has the wrong length
   */
  propagate: function (rate, momentum, update, target) {
    if (typeof target === 'undefined' || target.length !== this.output) {
      throw new Error('Output target length should match network output length');
    }

    var targetIndex = target.length;

    // Propagate output nodes
    var i;
    for (i = this.nodes.length - 1; i >= this.nodes.length - this.output; i--) {
      this.nodes[i].propagate(rate, momentum, update, target[--targetIndex]);
    }

    // Propagate hidden nodes (the loop stops before the input nodes)
    for (i = this.nodes.length - this.output - 1; i >= this.input; i--) {
      this.nodes[i].propagate(rate, momentum, update);
    }
  },

  /**
   * Clear the context of the network (calls Node#clear on every node).
   */
  clear: function () {
    for (var i = 0; i < this.nodes.length; i++) {
      this.nodes[i].clear();
    }
  },

  /**
   * Connects the from node to the to node and records the new connection(s):
   * in `this.selfconns` when from === to, otherwise in `this.connections`.
   *
   * @returns {Connection[]} the connections created by Node#connect
   */
  connect: function (from, to, weight) {
    var connections = from.connect(to, weight);

    for (var i = 0; i < connections.length; i++) {
      var connection = connections[i];
      if (from !== to) {
        this.connections.push(connection);
      } else {
        this.selfconns.push(connection);
      }
    }

    return connections;
  },

  /**
   * Disconnects the from node from the to node. A gated connection is ungated
   * first.
   */
  disconnect: function (from, to) {
    // Delete the connection in the network's connection array
    var connections = from === to ? this.selfconns : this.connections;

    for (var i = 0; i < connections.length; i++) {
      var connection = connections[i];
      if (connection.from === from && connection.to === to) {
        if (connection.gater !== null) this.ungate(connection);
        connections.splice(i, 1);
        break;
      }
    }

    // Delete the connection at the sending and receiving neuron
    from.disconnect(to);
  },

  /**
   * Gate a connection with a node.
   *
   * @throws {Error} when the node is not part of this network
   */
  gate: function (node, connection) {
    if (this.nodes.indexOf(node) === -1) {
      throw new Error('This node is not part of the network!');
    } else if (connection.gater != null) {
      if (config.warnings) console.warn('This connection is already gated!');
      return;
    }
    node.gate(connection);
    this.gates.push(connection);
  },

  /**
   * Remove the gate of a connection.
   *
   * @throws {Error} when the connection is not in `this.gates`
   */
  ungate: function (connection) {
    var index = this.gates.indexOf(connection);
    if (index === -1) {
      throw new Error('This connection is not gated!');
    }

    this.gates.splice(index, 1);
    connection.gater.ungate(connection);
  },

  /**
   * Removes a node from the network. Every node that projected into it is
   * connected to every node it projected to (unless already connected).
   * When `mutation.SUB_NODE.keep_gates` is set, the removed connections'
   * gaters are reassigned to random new connections.
   *
   * @throws {Error} when the node is not part of this network
   */
  remove: function (node) {
    var index = this.nodes.indexOf(node);

    if (index === -1) {
      throw new Error('This node does not exist in the network!');
    }

    // Keep track of gaters
    var gaters = [];

    // Remove selfconnections from this.selfconns
    this.disconnect(node, node);

    // Get all its inputting nodes
    var inputs = [];
    for (var i = node.connections.in.length - 1; i >= 0; i--) {
      let connection = node.connections.in[i];
      if (mutation.SUB_NODE.keep_gates && connection.gater !== null && connection.gater !== node) {
        gaters.push(connection.gater);
      }
      inputs.push(connection.from);
      this.disconnect(connection.from, node);
    }

    // Get all its outputting nodes
    var outputs = [];
    for (i = node.connections.out.length - 1; i >= 0; i--) {
      let connection = node.connections.out[i];
      if (mutation.SUB_NODE.keep_gates && connection.gater !== null && connection.gater !== node) {
        gaters.push(connection.gater);
      }
      outputs.push(connection.to);
      this.disconnect(node, connection.to);
    }

    // Connect the input nodes to the output nodes (if not already connected)
    var connections = [];
    for (i = 0; i < inputs.length; i++) {
      let input = inputs[i];
      for (var j = 0; j < outputs.length; j++) {
        let output = outputs[j];
        if (!input.isProjectingTo(output)) {
          var created = this.connect(input, output);
          connections.push(created[0]);
        }
      }
    }

    // Gate random connections with gaters
    for (i = 0; i < gaters.length; i++) {
      if (connections.length === 0) break;

      let gater = gaters[i];
      let connIndex = Math.floor(Math.random() * connections.length);

      this.gate(gater, connections[connIndex]);
      connections.splice(connIndex, 1);
    }

    // Remove gated connections gated by this node
    for (i = node.connections.gated.length - 1; i >= 0; i--) {
      let conn = node.connections.gated[i];
      this.ungate(conn);
    }

    // Remove selfconnection
    this.disconnect(node, node);

    // Remove the node from this.nodes
    this.nodes.splice(index, 1);
  },

  /**
   * Mutates the network with the given method (one of `methods.mutation`).
   * Mutations that cannot be applied log a warning (if `config.warnings`)
   * and leave the network unchanged.
   *
   * @throws {Error} when no method is given
   */
  mutate: function (method) {
    if (typeof method === 'undefined') {
      throw new Error('No (correct) mutate method given!');
    }

    var i, j;
    switch (method) {
      case mutation.ADD_NODE:
        // A node can only be placed on an existing connection
        if (this.connections.length === 0) {
          if (config.warnings) console.warn('No connections to place a node on!');
          break;
        }

        // Look for an existing connection and place a node in between
        var connection = this.connections[Math.floor(Math.random() * this.connections.length)];
        var gater = connection.gater;
        this.disconnect(connection.from, connection.to);

        // Insert the new node right before the old connection.to
        var toIndex = this.nodes.indexOf(connection.to);
        var node = new Node('hidden');

        // Random squash function
        node.mutate(mutation.MOD_ACTIVATION);

        // Place it in this.nodes (never after the output nodes)
        var minBound = Math.min(toIndex, this.nodes.length - this.output);
        this.nodes.splice(minBound, 0, node);

        // Now create two new connections
        var newConn1 = this.connect(connection.from, node)[0];
        var newConn2 = this.connect(node, connection.to)[0];

        // Check if the original connection was gated
        if (gater != null) {
          this.gate(gater, Math.random() >= 0.5 ? newConn1 : newConn2);
        }
        break;
      case mutation.SUB_NODE:
        // Check if there are nodes left to remove
        if (this.nodes.length === this.input + this.output) {
          if (config.warnings) console.warn('No more nodes left to remove!');
          break;
        }

        // Select a node which isn't an input or output node
        var index = Math.floor(Math.random() * (this.nodes.length - this.output - this.input) + this.input);
        this.remove(this.nodes[index]);
        break;
      case mutation.ADD_CONN:
        // Create an array of all uncreated (feedforward) connections
        var available = [];
        for (i = 0; i < this.nodes.length - this.output; i++) {
          let node1 = this.nodes[i];
          for (j = Math.max(i + 1, this.input); j < this.nodes.length; j++) {
            let node2 = this.nodes[j];
            if (!node1.isProjectingTo(node2)) available.push([node1, node2]);
          }
        }

        if (available.length === 0) {
          if (config.warnings) console.warn('No more connections to be made!');
          break;
        }

        var pair = available[Math.floor(Math.random() * available.length)];
        this.connect(pair[0], pair[1]);
        break;
      case mutation.SUB_CONN:
        // List of possible connections that can be removed
        var possible = [];

        for (i = 0; i < this.connections.length; i++) {
          let conn = this.connections[i];
          // Check if it is not disabling a node
          if (conn.from.connections.out.length > 1 && conn.to.connections.in.length > 1 && this.nodes.indexOf(conn.to) > this.nodes.indexOf(conn.from)) {
            possible.push(conn);
          }
        }

        if (possible.length === 0) {
          if (config.warnings) console.warn('No connections to remove!');
          break;
        }

        var randomConn = possible[Math.floor(Math.random() * possible.length)];
        this.disconnect(randomConn.from, randomConn.to);
        break;
      case mutation.MOD_WEIGHT:
        var allconnections = this.connections.concat(this.selfconns);
        if (allconnections.length === 0) {
          if (config.warnings) console.warn('No connections to modify!');
          break;
        }

        var connection = allconnections[Math.floor(Math.random() * allconnections.length)];
        var modification = Math.random() * (method.max - method.min) + method.min;
        connection.weight += modification;
        break;
      case mutation.MOD_BIAS:
        // Has no effect on input node, so they are excluded
        var index = Math.floor(Math.random() * (this.nodes.length - this.input) + this.input);
        var node = this.nodes[index];
        node.mutate(method);
        break;
      case mutation.MOD_ACTIVATION:
        // Has no effect on input node, so they are excluded
        if (!method.mutateOutput && this.input + this.output === this.nodes.length) {
          if (config.warnings) console.warn('No nodes that allow mutation of activation function');
          break;
        }

        var index = Math.floor(Math.random() * (this.nodes.length - (method.mutateOutput ? 0 : this.output) - this.input) + this.input);
        var node = this.nodes[index];

        node.mutate(method);
        break;
      case mutation.ADD_SELF_CONN:
        // Check which nodes aren't selfconnected yet
        var possible = [];
        for (i = this.input; i < this.nodes.length; i++) {
          let node = this.nodes[i];
          if (node.connections.self.weight === 0) {
            possible.push(node);
          }
        }

        if (possible.length === 0) {
          if (config.warnings) console.warn('No more self-connections to add!');
          break;
        }

        // Select a random node
        var node = possible[Math.floor(Math.random() * possible.length)];

        // Connect it to itself
        this.connect(node, node);
        break;
      case mutation.SUB_SELF_CONN:
        if (this.selfconns.length === 0) {
          if (config.warnings) console.warn('No more self-connections to remove!');
          break;
        }
        var conn = this.selfconns[Math.floor(Math.random() * this.selfconns.length)];
        this.disconnect(conn.from, conn.to);
        break;
      case mutation.ADD_GATE:
        var allconnections = this.connections.concat(this.selfconns);

        // Create a list of all non-gated connections
        var possible = [];
        for (i = 0; i < allconnections.length; i++) {
          let conn = allconnections[i];
          if (conn.gater === null) {
            possible.push(conn);
          }
        }

        if (possible.length === 0) {
          if (config.warnings) console.warn('No more connections to gate!');
          break;
        }

        // Select a random gater node and connection, can't be gated by input
        var index = Math.floor(Math.random() * (this.nodes.length - this.input) + this.input);
        var node = this.nodes[index];
        var conn = possible[Math.floor(Math.random() * possible.length)];

        // Gate the connection with the node
        this.gate(node, conn);
        break;
      case mutation.SUB_GATE:
        // Select a random gated connection
        if (this.gates.length === 0) {
          if (config.warnings) console.warn('No more connections to ungate!');
          break;
        }

        var index = Math.floor(Math.random() * this.gates.length);
        var gatedconn = this.gates[index];

        this.ungate(gatedconn);
        break;
      case mutation.ADD_BACK_CONN:
        // Create an array of all uncreated (backfed) connections
        var available = [];
        for (i = this.input; i < this.nodes.length; i++) {
          let node1 = this.nodes[i];
          for (j = this.input; j < i; j++) {
            let node2 = this.nodes[j];
            if (!node1.isProjectingTo(node2)) available.push([node1, node2]);
          }
        }

        if (available.length === 0) {
          if (config.warnings) console.warn('No more connections to be made!');
          break;
        }

        var pair = available[Math.floor(Math.random() * available.length)];
        this.connect(pair[0], pair[1]);
        break;
      case mutation.SUB_BACK_CONN:
        // List of possible connections that can be removed
        var possible = [];

        for (i = 0; i < this.connections.length; i++) {
          let conn = this.connections[i];
          // Check if it is not disabling a node
          if (conn.from.connections.out.length > 1 && conn.to.connections.in.length > 1 && this.nodes.indexOf(conn.from) > this.nodes.indexOf(conn.to)) {
            possible.push(conn);
          }
        }

        if (possible.length === 0) {
          if (config.warnings) console.warn('No connections to remove!');
          break;
        }

        var randomConn = possible[Math.floor(Math.random() * possible.length)];
        this.disconnect(randomConn.from, randomConn.to);
        break;
      case mutation.SWAP_NODES:
        // Has no effect on input node, so they are excluded
        if ((method.mutateOutput && this.nodes.length - this.input < 2) ||
          (!method.mutateOutput && this.nodes.length - this.input - this.output < 2)) {
          if (config.warnings) console.warn('No nodes that allow swapping of bias and activation function');
          break;
        }

        var index = Math.floor(Math.random() * (this.nodes.length - (method.mutateOutput ? 0 : this.output) - this.input) + this.input);
        var node1 = this.nodes[index];
        index = Math.floor(Math.random() * (this.nodes.length - (method.mutateOutput ? 0 : this.output) - this.input) + this.input);
        var node2 = this.nodes[index];

        var biasTemp = node1.bias;
        var squashTemp = node1.squash;

        node1.bias = node2.bias;
        node1.squash = node2.squash;
        node2.bias = biasTemp;
        node2.squash = squashTemp;
        break;
    }
  },

  /**
   * Train the given set to this network with backpropagation.
   *
   * @param {{input: number[], output: number[]}[]} set - training examples
   * @param {Object} [options] - error, iterations (at least one required),
   *   rate, cost, dropout, momentum, batchSize, ratePolicy, crossValidate
   *   ({testSize, testError}), shuffle, clear, log, schedule
   * @returns {{error: number, iterations: number, time: number}}
   * @throws {Error} on size mismatches or when neither error nor iterations is given
   */
  train: function (set, options) {
    if (set[0].input.length !== this.input || set[0].output.length !== this.output) {
      throw new Error('Dataset input/output size should be same as network input/output size!');
    }

    options = options || {};

    // Warning messages
    if (typeof options.rate === 'undefined') {
      if (config.warnings) console.warn('Using default learning rate, please define a rate!');
    }
    if (typeof options.iterations === 'undefined') {
      if (config.warnings) console.warn('No target iterations given, running until error is reached!');
    }

    // Read the options. `error` is checked with typeof (not `||`) so that an
    // explicit target of 0 is honoured instead of becoming 0.05.
    var targetError = typeof options.error !== 'undefined' ? options.error : 0.05;
    var cost = options.cost || methods.cost.MSE;
    var baseRate = options.rate || 0.3;
    var dropout = options.dropout || 0;
    var momentum = options.momentum || 0;
    var batchSize = options.batchSize || 1; // online learning
    var ratePolicy = options.ratePolicy || methods.rate.FIXED();
    // Local copy so the caller's options object is left untouched; 0 = no limit
    var maxIterations = options.iterations;

    var start = Date.now();

    if (batchSize > set.length) {
      throw new Error('Batch size must be smaller or equal to dataset length!');
    } else if (typeof options.iterations === 'undefined' && typeof options.error === 'undefined') {
      throw new Error('At least one of the following options must be specified: error, iterations');
    } else if (typeof options.error === 'undefined') {
      targetError = -1; // run until iterations
    } else if (typeof options.iterations === 'undefined') {
      maxIterations = 0; // run until target error
    }

    // Save to network
    this.dropout = dropout;

    if (options.crossValidate) {
      let numTrain = Math.ceil((1 - options.crossValidate.testSize) * set.length);
      var trainSet = set.slice(0, numTrain);
      var testSet = set.slice(numTrain);
    }

    // Loops the training process
    var currentRate = baseRate;
    var iteration = 0;
    var error = 1;

    var i;
    while (error > targetError && (maxIterations === 0 || iteration < maxIterations)) {
      if (options.crossValidate && error <= options.crossValidate.testError) break;

      iteration++;

      // Update the rate
      currentRate = ratePolicy(baseRate, iteration);

      // Checks if cross validation is enabled
      if (options.crossValidate) {
        this._trainSet(trainSet, batchSize, currentRate, momentum, cost);
        if (options.clear) this.clear();
        error = this.test(testSet, cost).error;
        if (options.clear) this.clear();
      } else {
        error = this._trainSet(set, batchSize, currentRate, momentum, cost);
        if (options.clear) this.clear();
      }

      // Shuffle the examples that are actually trained on. With cross
      // validation those live in the trainSet copy, so shuffling `set` would
      // have no effect on training.
      if (options.shuffle) {
        shuffleInPlace(options.crossValidate ? trainSet : set);
      }

      if (options.log && iteration % options.log === 0) {
        console.log('iteration', iteration, 'error', error, 'rate', currentRate);
      }

      if (options.schedule && iteration % options.schedule.iterations === 0) {
        options.schedule.function({ error: error, iteration: iteration });
      }
    }

    if (options.clear) this.clear();

    // After dropout training, scale hidden/constant nodes by the keep rate
    if (dropout) {
      for (i = 0; i < this.nodes.length; i++) {
        if (this.nodes[i].type === 'hidden' || this.nodes[i].type === 'constant') {
          this.nodes[i].mask = 1 - this.dropout;
        }
      }
    }

    return {
      error: error,
      iterations: iteration,
      time: Date.now() - start
    };
  },

  /**
   * Performs one training epoch and returns the mean cost over the set.
   * Private function used in this.train.
   */
  _trainSet: function (set, batchSize, currentRate, momentum, costFunction) {
    var errorSum = 0;
    for (var i = 0; i < set.length; i++) {
      var input = set[i].input;
      var target = set[i].output;

      // Apply the accumulated updates at the end of every batch and at the end of the set
      var update = !!((i + 1) % batchSize === 0 || (i + 1) === set.length);

      var output = this.activate(input, true);
      this.propagate(currentRate, momentum, update, target);

      errorSum += costFunction(target, output);
    }
    return errorSum / set.length;
  },

  /**
   * Tests a set (using noTraceActivate) and returns the mean cost and elapsed
   * time in milliseconds. When dropout is set, hidden/constant node masks are
   * first set to 1 - dropout.
   *
   * @returns {{error: number, time: number}}
   */
  test: function (set, cost = methods.cost.MSE) {
    // Check if dropout is enabled, set correct mask
    var i;
    if (this.dropout) {
      for (i = 0; i < this.nodes.length; i++) {
        if (this.nodes[i].type === 'hidden' || this.nodes[i].type === 'constant') {
          this.nodes[i].mask = 1 - this.dropout;
        }
      }
    }

    var error = 0;
    var start = Date.now();

    for (i = 0; i < set.length; i++) {
      let input = set[i].input;
      let target = set[i].output;
      let output = this.noTraceActivate(input);
      error += cost(target, output);
    }

    error /= set.length;

    var results = {
      error: error,
      time: Date.now() - start
    };

    return results;
  },

  /**
   * Creates a json that can be used to create a graph with d3 and webcola.
   * Gated connections are drawn through an extra 'GATE' node.
   */
  graph: function (width, height) {
    var input = 0;
    var output = 0;

    var json = {
      nodes: [],
      links: [],
      constraints: [{
        type: 'alignment',
        axis: 'x',
        offsets: []
      }, {
        type: 'alignment',
        axis: 'y',
        offsets: []
      }]
    };

    var i;
    for (i = 0; i < this.nodes.length; i++) {
      var node = this.nodes[i];

      if (node.type === 'input') {
        if (this.input === 1) {
          json.constraints[0].offsets.push({
            node: i,
            offset: 0
          });
        } else {
          json.constraints[0].offsets.push({
            node: i,
            offset: 0.8 * width / (this.input - 1) * input++
          });
        }
        json.constraints[1].offsets.push({
          node: i,
          offset: 0
        });
      } else if (node.type === 'output') {
        if (this.output === 1) {
          json.constraints[0].offsets.push({
            node: i,
            offset: 0
          });
        } else {
          json.constraints[0].offsets.push({
            node: i,
            offset: 0.8 * width / (this.output - 1) * output++
          });
        }
        json.constraints[1].offsets.push({
          node: i,
          offset: -0.8 * height
        });
      }

      json.nodes.push({
        id: i,
        name: node.type === 'hidden' ? node.squash.name : node.type.toUpperCase(),
        activation: node.activation,
        bias: node.bias
      });
    }

    var connections = this.connections.concat(this.selfconns);
    for (i = 0; i < connections.length; i++) {
      var connection = connections[i];
      if (connection.gater == null) {
        json.links.push({
          source: this.nodes.indexOf(connection.from),
          target: this.nodes.indexOf(connection.to),
          weight: connection.weight
        });
      } else {
        // Add a gater 'node'
        var index = json.nodes.length;
        json.nodes.push({
          id: index,
          activation: connection.gater.activation,
          name: 'GATE'
        });
        json.links.push({
          source: this.nodes.indexOf(connection.from),
          target: index,
          weight: 1 / 2 * connection.weight
        });
        json.links.push({
          source: index,
          target: this.nodes.indexOf(connection.to),
          weight: 1 / 2 * connection.weight
        });
        json.links.push({
          source: this.nodes.indexOf(connection.gater),
          target: index,
          weight: connection.gater.activation,
          gate: true
        });
      }
    }

    return json;
  },

  /**
   * Convert the network to a json object. Connections (self-connections
   * included) refer to nodes by their index in `nodes`. This also sets
   * `node.index` on every node as a side effect.
   */
  toJSON: function () {
    var json = {
      nodes: [],
      connections: [],
      input: this.input,
      output: this.output,
      dropout: this.dropout
    };

    // So we don't have to use expensive .indexOf()
    var i;
    for (i = 0; i < this.nodes.length; i++) {
      this.nodes[i].index = i;
    }

    for (i = 0; i < this.nodes.length; i++) {
      let node = this.nodes[i];
      let tojson = node.toJSON();
      tojson.index = i;
      json.nodes.push(tojson);

      if (node.connections.self.weight !== 0) {
        let tojson = node.connections.self.toJSON();
        tojson.from = i;
        tojson.to = i;

        tojson.gater = node.connections.self.gater != null ? node.connections.self.gater.index : null;
        json.connections.push(tojson);
      }
    }

    for (i = 0; i < this.connections.length; i++) {
      let conn = this.connections[i];
      let tojson = conn.toJSON();
      tojson.from = conn.from.index;
      tojson.to = conn.to.index;

      tojson.gater = conn.gater != null ? conn.gater.index : null;

      json.connections.push(tojson);
    }

    return json;
  },

  /**
   * Sets the bias and/or squash of every node in this network. Properties that
   * are absent (undefined/null) in `values` are left unchanged. A bias of 0 is
   * applied (it used to be ignored because it is falsy).
   *
   * @param {{bias?: number, squash?: Function}} values
   */
  set: function (values) {
    for (var i = 0; i < this.nodes.length; i++) {
      if (values.bias != null) this.nodes[i].bias = values.bias;
      if (values.squash != null) this.nodes[i].squash = values.squash;
    }
  },

  /**
   * Evolves the network with NEAT to reach a lower error on a dataset. On
   * completion the network takes over the structure of the fittest genome
   * seen.
   *
   * @param {{input: number[], output: number[]}[]} set
   * @param {Object} [options] - error, iterations (at least one required),
   *   growth, cost, amount, threads, clear, log, schedule; the whole object is
   *   also forwarded to Neat (with `network` and, when workers are used,
   *   `fitnessPopulation` added)
   * @returns {Promise<{error: number, iterations: number, time: number}>}
   */
  evolve: async function (set, options) {
    if (set[0].input.length !== this.input || set[0].output.length !== this.output) {
      throw new Error('Dataset input/output size should be same as network input/output size!');
    }

    // Read the options
    options = options || {};
    var targetError = typeof options.error !== 'undefined' ? options.error : 0.05;
    var growth = typeof options.growth !== 'undefined' ? options.growth : 0.0001;
    var cost = options.cost || methods.cost.MSE;
    var amount = options.amount || 1;

    var threads = options.threads;
    if (typeof threads === 'undefined') {
      if (typeof window === 'undefined') { // Node.js
        // os.cpus() can be empty on some platforms; never default to 0 threads
        threads = require('os').cpus().length || 1;
      } else { // Browser
        // hardwareConcurrency is not available everywhere
        threads = navigator.hardwareConcurrency || 1;
      }
    }

    var start = Date.now();

    if (typeof options.iterations === 'undefined' && typeof options.error === 'undefined') {
      throw new Error('At least one of the following options must be specified: error, iterations');
    } else if (typeof options.error === 'undefined') {
      targetError = -1; // run until iterations
    } else if (typeof options.iterations === 'undefined') {
      options.iterations = 0; // run until target error
    }

    var fitnessFunction;
    var workers = [];
    // Anything that is not more than one thread (1, 0, NaN, ...) runs in this
    // thread; a worker pool of size 0 would never resolve.
    if (!(threads > 1)) {
      // Create the fitness function
      fitnessFunction = function (genome) {
        var errorSum = 0;
        for (var i = 0; i < amount; i++) {
          errorSum += genome.test(set, cost).error;
        }

        // Average the error first and apply the growth penalty once, so the
        // score has the same scale as in the worker branch and the error
        // recovered from it in the main loop is correct for amount > 1.
        var score = -errorSum / amount - genomeComplexity(genome) * growth;
        score = isNaN(score) ? -Infinity : score; // this can cause problems with fitness proportionate selection

        return score;
      };
    } else {
      // Serialize the dataset
      var converted = multi.serializeDataSet(set);

      // Create workers, send datasets
      var TestWorker = typeof window === 'undefined'
        ? multi.workers.node.TestWorker
        : multi.workers.browser.TestWorker;
      for (var t = 0; t < threads; t++) {
        workers.push(new TestWorker(converted, cost));
      }

      fitnessFunction = function (population) {
        return new Promise((resolve, reject) => {
          // Create a queue
          var queue = population.slice();
          var done = 0;

          // Each worker keeps taking genomes from the queue; the promise
          // resolves once every worker has found the queue empty.
          var startWorker = function (worker) {
            if (!queue.length) {
              if (++done === workers.length) resolve();
              return;
            }

            var genome = queue.shift();

            worker.evaluate(genome).then(function (result) {
              genome.score = -result;
              genome.score -= genomeComplexity(genome) * growth;
              genome.score = isNaN(parseFloat(result)) ? -Infinity : genome.score;
              startWorker(worker);
            }).catch(reject); // surface worker failures instead of hanging
          };

          for (var i = 0; i < workers.length; i++) {
            startWorker(workers[i]);
          }
        });
      };

      options.fitnessPopulation = true;
    }

    var neat;
    var error = -Infinity;
    var bestFitness = -Infinity;
    var bestGenome;

    try {
      // Initialise the NEAT instance
      options.network = this;
      neat = new Neat(this.input, this.output, fitnessFunction, options);

      while (error < -targetError && (options.iterations === 0 || neat.generation < options.iterations)) {
        let fittest = await neat.evolve();
        let fitness = fittest.score;
        // Undo the growth penalty to recover the (negated) raw error
        error = fitness + genomeComplexity(fittest) * growth;

        if (fitness > bestFitness) {
          bestFitness = fitness;
          bestGenome = fittest;
        }

        if (options.log && neat.generation % options.log === 0) {
          console.log('iteration', neat.generation, 'fitness', fitness, 'error', -error);
        }

        if (options.schedule && neat.generation % options.schedule.iterations === 0) {
          options.schedule.function({ fitness: fitness, error: -error, iteration: neat.generation });
        }
      }
    } finally {
      // Always release the workers, also when evolution throws
      for (var w = 0; w < workers.length; w++) workers[w].terminate();
    }

    if (typeof bestGenome !== 'undefined') {
      this.nodes = bestGenome.nodes;
      this.connections = bestGenome.connections;
      this.selfconns = bestGenome.selfconns;
      this.gates = bestGenome.gates;

      if (options.clear) this.clear();
    }

    return {
      error: -error,
      iterations: neat.generation,
      time: Date.now() - start
    };
  },

  /**
   * Creates a standalone function of the network which can be run without the
   * need of a library. Returns JavaScript source that declares the arrays
   * `F` (squash functions), `A` (activations) and `S` (states), plus an
   * `activate(input)` function returning the output activations.
   * Sets `node.index` on every node as a side effect.
   *
   * @returns {string}
   */
  standalone: function () {
    var present = [];
    var activations = [];
    var states = [];
    var lines = [];
    var functions = [];

    var i;
    for (i = 0; i < this.input; i++) {
      var node = this.nodes[i];
      activations.push(node.activation);
      states.push(node.state);
    }

    lines.push('for(var i = 0; i < input.length; i++) A[i] = input[i];');

    // So we don't have to use expensive .indexOf()
    for (i = 0; i < this.nodes.length; i++) {
      this.nodes[i].index = i;
    }

    for (i = this.input; i < this.nodes.length; i++) {
      let node = this.nodes[i];
      activations.push(node.activation);
      states.push(node.state);

      var functionIndex = present.indexOf(node.squash.name);

      if (functionIndex === -1) {
        functionIndex = present.length;
        present.push(node.squash.name);
        functions.push(node.squash.toString());
      }

      var incoming = [];
      for (var j = 0; j < node.connections.in.length; j++) {
        var conn = node.connections.in[j];
        var computation = `A[${conn.from.index}] * ${conn.weight}`;

        if (conn.gater != null) {
          computation += ` * A[${conn.gater.index}]`;
        }

        incoming.push(computation);
      }

      if (node.connections.self.weight) {
        let conn = node.connections.self;
        let computation = `S[${i}] * ${conn.weight}`;

        if (conn.gater != null) {
          computation += ` * A[${conn.gater.index}]`;
        }

        incoming.push(computation);
      }

      // Bake any non-neutral mask into the generated code: 0, or 1 - dropout
      // after dropout training. (`!node.mask` only ever emitted a mask of 0.)
      var maskFactor = node.mask !== 1 ? ' * ' + node.mask : '';

      var line1 = `S[${i}] = ${incoming.join(' + ')} + ${node.bias};`;
      var line2 = `A[${i}] = F[${functionIndex}](S[${i}])${maskFactor};`;
      lines.push(line1);
      lines.push(line2);
    }

    var output = [];
    for (i = this.nodes.length - this.output; i < this.nodes.length; i++) {
      output.push(`A[${i}]`);
    }

    output = `return [${output.join(',')}];`;
    lines.push(output);

    var total = '';
    total += `var F = [${functions.toString()}];\r\n`;
    total += `var A = [${activations.toString()}];\r\n`;
    total += `var S = [${states.toString()}];\r\n`;
    total += `function activate(input){\r\n${lines.join('\r\n')}\r\n}`;

    return total;
  },

  /**
   * Serialize to send to workers efficiently.
   * Returns [activations, states, conns]. `conns` starts with input and output
   * sizes. Then, for each non-input node, it holds: index, bias, squash id
   * (position in the list below, -1 if unknown), self weight, self gater
   * (-1 if none). These are followed by (from, weight, gater) triples, and the
   * node's entry ends with a -2 stop token. Sets `node.index` as a side effect.
   */
  serialize: function () {
    var activations = [];
    var states = [];
    var conns = [];
    var squashes = [
      'LOGISTIC', 'TANH', 'IDENTITY', 'STEP', 'RELU', 'SOFTSIGN', 'SINUSOID',
      'GAUSSIAN', 'BENT_IDENTITY', 'BIPOLAR', 'BIPOLAR_SIGMOID', 'HARD_TANH',
      'ABSOLUTE', 'INVERSE', 'SELU'
    ];

    conns.push(this.input);
    conns.push(this.output);

    var i;
    for (i = 0; i < this.nodes.length; i++) {
      let node = this.nodes[i];
      node.index = i;
      activations.push(node.activation);
      states.push(node.state);
    }

    for (i = this.input; i < this.nodes.length; i++) {
      let node = this.nodes[i];
      conns.push(node.index);
      conns.push(node.bias);
      conns.push(squashes.indexOf(node.squash.name));

      conns.push(node.connections.self.weight);
      conns.push(node.connections.self.gater == null ? -1 : node.connections.self.gater.index);

      for (var j = 0; j < node.connections.in.length; j++) {
        let conn = node.connections.in[j];

        conns.push(conn.from.index);
        conns.push(conn.weight);
        conns.push(conn.gater == null ? -1 : conn.gater.index);
      }

      conns.push(-2); // stop token -> next node
    }

    return [activations, states, conns];
  }
};

/**
 * Convert a json object (as produced by Network#toJSON) to a network.
 */
Network.fromJSON = function (json) {
  var network = new Network(json.input, json.output);
  network.dropout = json.dropout;
  network.nodes = [];
  network.connections = [];

  var i;
  for (i = 0; i < json.nodes.length; i++) {
    network.nodes.push(Node.fromJSON(json.nodes[i]));
  }

  for (i = 0; i < json.connections.length; i++) {
    var conn = json.connections[i];

    var connection = network.connect(network.nodes[conn.from], network.nodes[conn.to])[0];
    connection.weight = conn.weight;

    if (conn.gater != null) {
      network.gate(network.nodes[conn.gater], connection);
    }
  }

  return network;
};

/**
 * Merge two networks into one: output k of network1 feeds input k of
 * network2. Both arguments are copied (toJSON/fromJSON) and left untouched.
 * The result has network1's inputs and network2's outputs.
 *
 * @throws {Error} when network1's output size differs from network2's input size
 */
Network.merge = function (network1, network2) {
  // Create a copy of the networks
  network1 = Network.fromJSON(network1.toJSON());
  network2 = Network.fromJSON(network2.toJSON());

  // Check if output and input size are the same
  if (network1.output !== network2.input) {
    throw new Error('Output size of network1 should be the same as the input size of network2!');
  }

  var firstOutput = network1.nodes.length - network1.output;

  // Redirect all connections from network2 input from network1 output
  var i;
  for (i = 0; i < network2.connections.length; i++) {
    let conn = network2.connections[i];
    if (conn.from.type === 'input') {
      let index = network2.nodes.indexOf(conn.from);

      // Redirect input k to output k (same order), and register the
      // connection on its new source node so backpropagation and
      // isProjectingTo see it.
      conn.from = network1.nodes[firstOutput + index];
      conn.from.connections.out.push(conn);
    }
  }

  // Delete input nodes of network2
  network2.nodes.splice(0, network2.input);

  // Change the node type of network1's output nodes (now hidden)
  for (i = firstOutput; i < network1.nodes.length; i++) {
    network1.nodes[i].type = 'hidden';
  }

  // Create one network from both networks, keeping self-connections and gates
  network1.connections = network1.connections.concat(network2.connections);
  network1.selfconns = network1.selfconns.concat(network2.selfconns);
  network1.gates = network1.gates.concat(network2.gates);
  network1.nodes = network1.nodes.concat(network2.nodes);

  // The merged network produces network2's outputs
  network1.output = network2.output;

  return network1;
};

/**
 * Create an offspring from two parent networks (NEAT-style crossover).
 * Sets `node.index` on both parents' nodes as a side effect.
 *
 * @param {Network} network1
 * @param {Network} network2
 * @param {boolean} [equal] - treat both parents as equally fit
 * @returns {Network}
 * @throws {Error} when the parents' input/output sizes differ
 */
Network.crossOver = function (network1, network2, equal) {
  if (network1.input !== network2.input || network1.output !== network2.output) {
    throw new Error("Networks don't have the same input/output size!");
  }

  // Initialise offspring
  var offspring = new Network(network1.input, network1.output);
  offspring.connections = [];
  offspring.nodes = [];

  // Save scores and create a copy
  var score1 = network1.score || 0;
  var score2 = network2.score || 0;

  // Determine offspring node size
  var size;
  if (equal || score1 === score2) {
    let max = Math.max(network1.nodes.length, network2.nodes.length);
    let min = Math.min(network1.nodes.length, network2.nodes.length);
    size = Math.floor(Math.random() * (max - min + 1) + min);
  } else if (score1 > score2) {
    size = network1.nodes.length;
  } else {
    size = network2.nodes.length;
  }

  // Rename some variables for easier reading
  var outputSize = network1.output;

  // Set indexes so we don't need indexOf
  var i;
  for (i = 0; i < network1.nodes.length; i++) {
    network1.nodes[i].index = i;
  }

  for (i = 0; i < network2.nodes.length; i++) {
    network2.nodes[i].index = i;
  }

  // Assign nodes from parents to offspring
  for (i = 0; i < size; i++) {
    // Determine if an output node is needed
    var node;
    if (i < size - outputSize) {
      let random = Math.random();
      node = random >= 0.5 ? network1.nodes[i] : network2.nodes[i];
      let other = random < 0.5 ? network1.nodes[i] : network2.nodes[i];

      if (typeof node === 'undefined' || node.type === 'output') {
        node = other;
      }
    } else {
      // Output nodes are taken from the end of either parent
      if (Math.random() >= 0.5) {
        node = network1.nodes[network1.nodes.length + i - size];
      } else {
        node = network2.nodes[network2.nodes.length + i - size];
      }
    }

    var newNode = new Node();
    newNode.bias = node.bias;
    newNode.squash = node.squash;
    newNode.type = node.type;

    offspring.nodes.push(newNode);
  }

  // Create arrays of connection genes
  var n1conns = connectionGenes(network1);
  var n2conns = connectionGenes(network2);

  // Split common conn genes from disjoint or excess conn genes
  var connections = [];
  var keys1 = Object.keys(n1conns);
  var keys2 = Object.keys(n2conns);
  for (i = keys1.length - 1; i >= 0; i--) {
    // Common gene
    if (typeof n2conns[keys1[i]] !== 'undefined') {
      let conn = Math.random() >= 0.5 ? n1conns[keys1[i]] : n2conns[keys1[i]];
      connections.push(conn);

      // Because deleting is expensive, just set it to some value
      n2conns[keys1[i]] = undefined;
    } else if (score1 >= score2 || equal) {
      connections.push(n1conns[keys1[i]]);
    }
  }

  // Excess/disjoint gene
  if (score2 >= score1 || equal) {
    for (i = 0; i < keys2.length; i++) {
      if (typeof n2conns[keys2[i]] !== 'undefined') {
        connections.push(n2conns[keys2[i]]);
      }
    }
  }

  // Add the selected conn genes whose endpoints exist in the offspring
  for (i = 0; i < connections.length; i++) {
    let connData = connections[i];
    if (connData.to < size && connData.from < size) {
      let from = offspring.nodes[connData.from];
      let to = offspring.nodes[connData.to];
      let conn = offspring.connect(from, to)[0];

      conn.weight = connData.weight;

      if (connData.gater !== -1 && connData.gater < size) {
        offspring.gate(offspring.nodes[connData.gater], conn);
      }
    }
  }

  return offspring;
};