neataptic
Version:
Architecture-free neural network library with genetic algorithm implementations
1,342 lines (1,131 loc) • 42.4 kB
JavaScript
/* Export */
module.exports = Network;
/* Import */
var multi = require('../multithreading/multi');
var methods = require('../methods/methods');
var Connection = require('./connection');
var config = require('../config');
var Neat = require('../neat');
var Node = require('./node');
/* Easier variable naming */
var mutation = methods.mutation;
/*******************************************************************************
NETWORK
*******************************************************************************/
/**
 * Creates a network with `input` input nodes and `output` output nodes and
 * connects every input node directly to every output node.
 *
 * @param {number} input - amount of input nodes
 * @param {number} output - amount of output nodes
 * @throws {Error} when either size is missing
 */
function Network (input, output) {
  if (input === undefined || output === undefined) {
    throw new Error('No input or output size given');
  }

  this.input = input;
  this.output = output;

  // Node and connection genes; `nodes` is kept in activation order
  this.nodes = [];
  this.connections = [];
  this.gates = [];
  this.selfconns = [];

  // Regularization
  this.dropout = 0;

  // Input nodes come first, output nodes follow
  var nodeCount = this.input + this.output;
  for (var position = 0; position < nodeCount; position++) {
    this.nodes.push(new Node(position < this.input ? 'input' : 'output'));
  }

  // Fully connect inputs to outputs with a random initial weight
  // https://stats.stackexchange.com/a/248040/147931
  var spread = Math.sqrt(2 / this.input);
  for (var source = 0; source < this.input; source++) {
    for (var target = this.input; target < nodeCount; target++) {
      var initialWeight = Math.random() * this.input * spread;
      this.connect(this.nodes[source], this.nodes[target], initialWeight);
    }
  }
}
Network.prototype = {
/**
* Activates the network
*/
activate: function (input, training) {
var output = [];
// Activate nodes chronologically
for (var i = 0; i < this.nodes.length; i++) {
if (this.nodes[i].type === 'input') {
this.nodes[i].activate(input[i]);
} else if (this.nodes[i].type === 'output') {
var activation = this.nodes[i].activate();
output.push(activation);
} else {
if (training) this.nodes[i].mask = Math.random() < this.dropout ? 0 : 1;
this.nodes[i].activate();
}
}
return output;
},
/**
* Activates the network without calculating elegibility traces and such
*/
noTraceActivate: function (input) {
var output = [];
// Activate nodes chronologically
for (var i = 0; i < this.nodes.length; i++) {
if (this.nodes[i].type === 'input') {
this.nodes[i].noTraceActivate(input[i]);
} else if (this.nodes[i].type === 'output') {
var activation = this.nodes[i].noTraceActivate();
output.push(activation);
} else {
this.nodes[i].noTraceActivate();
}
}
return output;
},
/**
* Backpropagate the network
*/
propagate: function (rate, momentum, update, target) {
if (typeof target === 'undefined' || target.length !== this.output) {
throw new Error('Output target length should match network output length');
}
var targetIndex = target.length;
// Propagate output nodes
var i;
for (i = this.nodes.length - 1; i >= this.nodes.length - this.output; i--) {
this.nodes[i].propagate(rate, momentum, update, target[--targetIndex]);
}
// Propagate hidden and input nodes
for (i = this.nodes.length - this.output - 1; i >= this.input; i--) {
this.nodes[i].propagate(rate, momentum, update);
}
},
/**
* Clear the context of the network
*/
clear: function () {
for (var i = 0; i < this.nodes.length; i++) {
this.nodes[i].clear();
}
},
/**
* Connects the from node to the to node
*/
connect: function (from, to, weight) {
var connections = from.connect(to, weight);
for (var i = 0; i < connections.length; i++) {
var connection = connections[i];
if (from !== to) {
this.connections.push(connection);
} else {
this.selfconns.push(connection);
}
}
return connections;
},
/**
* Disconnects the from node from the to node
*/
disconnect: function (from, to) {
// Delete the connection in the network's connection array
var connections = from === to ? this.selfconns : this.connections;
for (var i = 0; i < connections.length; i++) {
var connection = connections[i];
if (connection.from === from && connection.to === to) {
if (connection.gater !== null) this.ungate(connection);
connections.splice(i, 1);
break;
}
}
// Delete the connection at the sending and receiving neuron
from.disconnect(to);
},
/**
* Gate a connection with a node
*/
gate: function (node, connection) {
if (this.nodes.indexOf(node) === -1) {
throw new Error('This node is not part of the network!');
} else if (connection.gater != null) {
if (config.warnings) console.warn('This connection is already gated!');
return;
}
node.gate(connection);
this.gates.push(connection);
},
/**
* Remove the gate of a connection
*/
ungate: function (connection) {
var index = this.gates.indexOf(connection);
if (index === -1) {
throw new Error('This connection is not gated!');
}
this.gates.splice(index, 1);
connection.gater.ungate(connection);
},
/**
 * Removes a node from the network and bridges the gap it leaves behind.
 *
 * All connections into, out of and onto the node are disconnected. Every node
 * that projected to the removed node is then connected to every node the
 * removed node projected to, unless such a connection already exists. When
 * `mutation.SUB_NODE.keep_gates` is set, the gaters of the removed
 * connections (other than the node itself) are re-used to gate randomly
 * chosen new connections.
 *
 * @param {Node} node - the node to remove
 * @throws {Error} when the node is not in `this.nodes`
 */
remove: function (node) {
var index = this.nodes.indexOf(node);
if (index === -1) {
throw new Error('This node does not exist in the network!');
}
// Keep track of gaters
var gaters = [];
// Remove selfconnections from this.selfconns
this.disconnect(node, node);
// Get all its inputting nodes
// Walks backwards - presumably because disconnect() removes entries from
// node.connections.in while we iterate (Node.disconnect is not visible here)
var inputs = [];
for (var i = node.connections.in.length - 1; i >= 0; i--) {
let connection = node.connections.in[i];
if (mutation.SUB_NODE.keep_gates && connection.gater !== null && connection.gater !== node) {
gaters.push(connection.gater);
}
inputs.push(connection.from);
this.disconnect(connection.from, node);
}
// Get all its outputting nodes (same backwards walk as above)
var outputs = [];
for (i = node.connections.out.length - 1; i >= 0; i--) {
let connection = node.connections.out[i];
if (mutation.SUB_NODE.keep_gates && connection.gater !== null && connection.gater !== node) {
gaters.push(connection.gater);
}
outputs.push(connection.to);
this.disconnect(node, connection.to);
}
// Connect the input nodes to the output nodes (if not already connected)
var connections = [];
for (i = 0; i < inputs.length; i++) {
let input = inputs[i];
for (var j = 0; j < outputs.length; j++) {
let output = outputs[j];
if (!input.isProjectingTo(output)) {
var conn = this.connect(input, output);
connections.push(conn[0]);
}
}
}
// Gate random connections with gaters
// Each new connection is used at most once; stops early when none are left
for (i = 0; i < gaters.length; i++) {
if (connections.length === 0) break;
let gater = gaters[i];
let connIndex = Math.floor(Math.random() * connections.length);
this.gate(gater, connections[connIndex]);
connections.splice(connIndex, 1);
}
// Remove gated connections gated by this node
for (i = node.connections.gated.length - 1; i >= 0; i--) {
let conn = node.connections.gated[i];
this.ungate(conn);
}
// Remove selfconnection
// NOTE(review): the self-connection was already disconnected at the top of
// this method; this second call looks redundant - confirm before removing
this.disconnect(node, node);
// Remove the node from this.nodes
this.nodes.splice(index, 1);
},
/**
* Mutates the network with the given method
*/
mutate: function (method) {
if (typeof method === 'undefined') {
throw new Error('No (correct) mutate method given!');
}
var i, j;
switch (method) {
case mutation.ADD_NODE:
// Look for an existing connection and place a node in between
var connection = this.connections[Math.floor(Math.random() * this.connections.length)];
var gater = connection.gater;
this.disconnect(connection.from, connection.to);
// Insert the new node right before the old connection.to
var toIndex = this.nodes.indexOf(connection.to);
var node = new Node('hidden');
// Random squash function
node.mutate(mutation.MOD_ACTIVATION);
// Place it in this.nodes
var minBound = Math.min(toIndex, this.nodes.length - this.output);
this.nodes.splice(minBound, 0, node);
// Now create two new connections
var newConn1 = this.connect(connection.from, node)[0];
var newConn2 = this.connect(node, connection.to)[0];
// Check if the original connection was gated
if (gater != null) {
this.gate(gater, Math.random() >= 0.5 ? newConn1 : newConn2);
}
break;
case mutation.SUB_NODE:
// Check if there are nodes left to remove
if (this.nodes.length === this.input + this.output) {
if (config.warnings) console.warn('No more nodes left to remove!');
break;
}
// Select a node which isn't an input or output node
var index = Math.floor(Math.random() * (this.nodes.length - this.output - this.input) + this.input);
this.remove(this.nodes[index]);
break;
case mutation.ADD_CONN:
// Create an array of all uncreated (feedforward) connections
var available = [];
for (i = 0; i < this.nodes.length - this.output; i++) {
let node1 = this.nodes[i];
for (j = Math.max(i + 1, this.input); j < this.nodes.length; j++) {
let node2 = this.nodes[j];
if (!node1.isProjectingTo(node2)) available.push([node1, node2]);
}
}
if (available.length === 0) {
if (config.warnings) console.warn('No more connections to be made!');
break;
}
var pair = available[Math.floor(Math.random() * available.length)];
this.connect(pair[0], pair[1]);
break;
case mutation.SUB_CONN:
// List of possible connections that can be removed
var possible = [];
for (i = 0; i < this.connections.length; i++) {
let conn = this.connections[i];
// Check if it is not disabling a node
if (conn.from.connections.out.length > 1 && conn.to.connections.in.length > 1 && this.nodes.indexOf(conn.to) > this.nodes.indexOf(conn.from)) {
possible.push(conn);
}
}
if (possible.length === 0) {
if (config.warnings) console.warn('No connections to remove!');
break;
}
var randomConn = possible[Math.floor(Math.random() * possible.length)];
this.disconnect(randomConn.from, randomConn.to);
break;
case mutation.MOD_WEIGHT:
var allconnections = this.connections.concat(this.selfconns);
var connection = allconnections[Math.floor(Math.random() * allconnections.length)];
var modification = Math.random() * (method.max - method.min) + method.min;
connection.weight += modification;
break;
case mutation.MOD_BIAS:
// Has no effect on input node, so they are excluded
var index = Math.floor(Math.random() * (this.nodes.length - this.input) + this.input);
var node = this.nodes[index];
node.mutate(method);
break;
case mutation.MOD_ACTIVATION:
// Has no effect on input node, so they are excluded
if (!method.mutateOutput && this.input + this.output === this.nodes.length) {
if (config.warnings) console.warn('No nodes that allow mutation of activation function');
break;
}
var index = Math.floor(Math.random() * (this.nodes.length - (method.mutateOutput ? 0 : this.output) - this.input) + this.input);
var node = this.nodes[index];
node.mutate(method);
break;
case mutation.ADD_SELF_CONN:
// Check which nodes aren't selfconnected yet
var possible = [];
for (i = this.input; i < this.nodes.length; i++) {
let node = this.nodes[i];
if (node.connections.self.weight === 0) {
possible.push(node);
}
}
if (possible.length === 0) {
if (config.warnings) console.warn('No more self-connections to add!');
break;
}
// Select a random node
var node = possible[Math.floor(Math.random() * possible.length)];
// Connect it to himself
this.connect(node, node);
break;
case mutation.SUB_SELF_CONN:
if (this.selfconns.length === 0) {
if (config.warnings) console.warn('No more self-connections to remove!');
break;
}
var conn = this.selfconns[Math.floor(Math.random() * this.selfconns.length)];
this.disconnect(conn.from, conn.to);
break;
case mutation.ADD_GATE:
var allconnections = this.connections.concat(this.selfconns);
// Create a list of all non-gated connections
var possible = [];
for (i = 0; i < allconnections.length; i++) {
let conn = allconnections[i];
if (conn.gater === null) {
possible.push(conn);
}
}
if (possible.length === 0) {
if (config.warnings) console.warn('No more connections to gate!');
break;
}
// Select a random gater node and connection, can't be gated by input
var index = Math.floor(Math.random() * (this.nodes.length - this.input) + this.input);
var node = this.nodes[index];
var conn = possible[Math.floor(Math.random() * possible.length)];
// Gate the connection with the node
this.gate(node, conn);
break;
case mutation.SUB_GATE:
// Select a random gated connection
if (this.gates.length === 0) {
if (config.warnings) console.warn('No more connections to ungate!');
break;
}
var index = Math.floor(Math.random() * this.gates.length);
var gatedconn = this.gates[index];
this.ungate(gatedconn);
break;
case mutation.ADD_BACK_CONN:
// Create an array of all uncreated (backfed) connections
var available = [];
for (i = this.input; i < this.nodes.length; i++) {
let node1 = this.nodes[i];
for (j = this.input; j < i; j++) {
let node2 = this.nodes[j];
if (!node1.isProjectingTo(node2)) available.push([node1, node2]);
}
}
if (available.length === 0) {
if (config.warnings) console.warn('No more connections to be made!');
break;
}
var pair = available[Math.floor(Math.random() * available.length)];
this.connect(pair[0], pair[1]);
break;
case mutation.SUB_BACK_CONN:
// List of possible connections that can be removed
var possible = [];
for (i = 0; i < this.connections.length; i++) {
let conn = this.connections[i];
// Check if it is not disabling a node
if (conn.from.connections.out.length > 1 && conn.to.connections.in.length > 1 && this.nodes.indexOf(conn.from) > this.nodes.indexOf(conn.to)) {
possible.push(conn);
}
}
if (possible.length === 0) {
if (config.warnings) console.warn('No connections to remove!');
break;
}
var randomConn = possible[Math.floor(Math.random() * possible.length)];
this.disconnect(randomConn.from, randomConn.to);
break;
case mutation.SWAP_NODES:
// Has no effect on input node, so they are excluded
if ((method.mutateOutput && this.nodes.length - this.input < 2) ||
(!method.mutateOutput && this.nodes.length - this.input - this.output < 2)) {
if (config.warnings) console.warn('No nodes that allow swapping of bias and activation function');
break;
}
var index = Math.floor(Math.random() * (this.nodes.length - (method.mutateOutput ? 0 : this.output) - this.input) + this.input);
var node1 = this.nodes[index];
index = Math.floor(Math.random() * (this.nodes.length - (method.mutateOutput ? 0 : this.output) - this.input) + this.input);
var node2 = this.nodes[index];
var biasTemp = node1.bias;
var squashTemp = node1.squash;
node1.bias = node2.bias;
node1.squash = node2.squash;
node2.bias = biasTemp;
node2.squash = squashTemp;
break;
}
},
/**
* Train the given set to this network
*/
train: function (set, options) {
if (set[0].input.length !== this.input || set[0].output.length !== this.output) {
throw new Error('Dataset input/output size should be same as network input/output size!');
}
options = options || {};
// Warning messages
if (typeof options.rate === 'undefined') {
if (config.warnings) console.warn('Using default learning rate, please define a rate!');
}
if (typeof options.iterations === 'undefined') {
if (config.warnings) console.warn('No target iterations given, running until error is reached!');
}
// Read the options
var targetError = options.error || 0.05;
var cost = options.cost || methods.cost.MSE;
var baseRate = options.rate || 0.3;
var dropout = options.dropout || 0;
var momentum = options.momentum || 0;
var batchSize = options.batchSize || 1; // online learning
var ratePolicy = options.ratePolicy || methods.rate.FIXED();
var start = Date.now();
if (batchSize > set.length) {
throw new Error('Batch size must be smaller or equal to dataset length!');
} else if (typeof options.iterations === 'undefined' && typeof options.error === 'undefined') {
throw new Error('At least one of the following options must be specified: error, iterations');
} else if (typeof options.error === 'undefined') {
targetError = -1; // run until iterations
} else if (typeof options.iterations === 'undefined') {
options.iterations = 0; // run until target error
}
// Save to network
this.dropout = dropout;
if (options.crossValidate) {
let numTrain = Math.ceil((1 - options.crossValidate.testSize) * set.length);
var trainSet = set.slice(0, numTrain);
var testSet = set.slice(numTrain);
}
// Loops the training process
var currentRate = baseRate;
var iteration = 0;
var error = 1;
var i, j, x;
while (error > targetError && (options.iterations === 0 || iteration < options.iterations)) {
if (options.crossValidate && error <= options.crossValidate.testError) break;
iteration++;
// Update the rate
currentRate = ratePolicy(baseRate, iteration);
// Checks if cross validation is enabled
if (options.crossValidate) {
this._trainSet(trainSet, batchSize, currentRate, momentum, cost);
if (options.clear) this.clear();
error = this.test(testSet, cost).error;
if (options.clear) this.clear();
} else {
error = this._trainSet(set, batchSize, currentRate, momentum, cost);
if (options.clear) this.clear();
}
// Checks for options such as scheduled logs and shuffling
if (options.shuffle) {
for (j, x, i = set.length; i; j = Math.floor(Math.random() * i), x = set[--i], set[i] = set[j], set[j] = x);
}
if (options.log && iteration % options.log === 0) {
console.log('iteration', iteration, 'error', error, 'rate', currentRate);
}
if (options.schedule && iteration % options.schedule.iterations === 0) {
options.schedule.function({ error: error, iteration: iteration });
}
}
if (options.clear) this.clear();
if (dropout) {
for (i = 0; i < this.nodes.length; i++) {
if (this.nodes[i].type === 'hidden' || this.nodes[i].type === 'constant') {
this.nodes[i].mask = 1 - this.dropout;
}
}
}
return {
error: error,
iterations: iteration,
time: Date.now() - start
};
},
/**
* Performs one training epoch and returns the error
* private function used in this.train
*/
_trainSet: function (set, batchSize, currentRate, momentum, costFunction) {
var errorSum = 0;
for (var i = 0; i < set.length; i++) {
var input = set[i].input;
var target = set[i].output;
var update = !!((i + 1) % batchSize === 0 || (i + 1) === set.length);
var output = this.activate(input, true);
this.propagate(currentRate, momentum, update, target);
errorSum += costFunction(target, output);
}
return errorSum / set.length;
},
/**
* Tests a set and returns the error and elapsed time
*/
test: function (set, cost = methods.cost.MSE) {
// Check if dropout is enabled, set correct mask
var i;
if (this.dropout) {
for (i = 0; i < this.nodes.length; i++) {
if (this.nodes[i].type === 'hidden' || this.nodes[i].type === 'constant') {
this.nodes[i].mask = 1 - this.dropout;
}
}
}
var error = 0;
var start = Date.now();
for (i = 0; i < set.length; i++) {
let input = set[i].input;
let target = set[i].output;
let output = this.noTraceActivate(input);
error += cost(target, output);
}
error /= set.length;
var results = {
error: error,
time: Date.now() - start
};
return results;
},
/**
* Creates a json that can be used to create a graph with d3 and webcola
*/
graph: function (width, height) {
var input = 0;
var output = 0;
var json = {
nodes: [],
links: [],
constraints: [{
type: 'alignment',
axis: 'x',
offsets: []
}, {
type: 'alignment',
axis: 'y',
offsets: []
}]
};
var i;
for (i = 0; i < this.nodes.length; i++) {
var node = this.nodes[i];
if (node.type === 'input') {
if (this.input === 1) {
json.constraints[0].offsets.push({
node: i,
offset: 0
});
} else {
json.constraints[0].offsets.push({
node: i,
offset: 0.8 * width / (this.input - 1) * input++
});
}
json.constraints[1].offsets.push({
node: i,
offset: 0
});
} else if (node.type === 'output') {
if (this.output === 1) {
json.constraints[0].offsets.push({
node: i,
offset: 0
});
} else {
json.constraints[0].offsets.push({
node: i,
offset: 0.8 * width / (this.output - 1) * output++
});
}
json.constraints[1].offsets.push({
node: i,
offset: -0.8 * height
});
}
json.nodes.push({
id: i,
name: node.type === 'hidden' ? node.squash.name : node.type.toUpperCase(),
activation: node.activation,
bias: node.bias
});
}
var connections = this.connections.concat(this.selfconns);
for (i = 0; i < connections.length; i++) {
var connection = connections[i];
if (connection.gater == null) {
json.links.push({
source: this.nodes.indexOf(connection.from),
target: this.nodes.indexOf(connection.to),
weight: connection.weight
});
} else {
// Add a gater 'node'
var index = json.nodes.length;
json.nodes.push({
id: index,
activation: connection.gater.activation,
name: 'GATE'
});
json.links.push({
source: this.nodes.indexOf(connection.from),
target: index,
weight: 1 / 2 * connection.weight
});
json.links.push({
source: index,
target: this.nodes.indexOf(connection.to),
weight: 1 / 2 * connection.weight
});
json.links.push({
source: this.nodes.indexOf(connection.gater),
target: index,
weight: connection.gater.activation,
gate: true
});
}
}
return json;
},
/**
* Convert the network to a json object
*/
toJSON: function () {
var json = {
nodes: [],
connections: [],
input: this.input,
output: this.output,
dropout: this.dropout
};
// So we don't have to use expensive .indexOf()
var i;
for (i = 0; i < this.nodes.length; i++) {
this.nodes[i].index = i;
}
for (i = 0; i < this.nodes.length; i++) {
let node = this.nodes[i];
let tojson = node.toJSON();
tojson.index = i;
json.nodes.push(tojson);
if (node.connections.self.weight !== 0) {
let tojson = node.connections.self.toJSON();
tojson.from = i;
tojson.to = i;
tojson.gater = node.connections.self.gater != null ? node.connections.self.gater.index : null;
json.connections.push(tojson);
}
}
for (i = 0; i < this.connections.length; i++) {
let conn = this.connections[i];
let tojson = conn.toJSON();
tojson.from = conn.from.index;
tojson.to = conn.to.index;
tojson.gater = conn.gater != null ? conn.gater.index : null;
json.connections.push(tojson);
}
return json;
},
/**
* Sets the value of a property for every node in this network
*/
set: function (values) {
for (var i = 0; i < this.nodes.length; i++) {
this.nodes[i].bias = values.bias || this.nodes[i].bias;
this.nodes[i].squash = values.squash || this.nodes[i].squash;
}
},
/**
 * Evolves the network to reach a lower error on a dataset.
 *
 * Options read here: `error` (default 0.05), `growth` (default 0.0001,
 * penalty per hidden node, connection and gate), `cost` (default
 * `methods.cost.MSE`), `amount` (default 1, number of test runs per genome),
 * `threads` (defaults to the CPU count in Node.js or
 * `navigator.hardwareConcurrency` in the browser), `iterations`, `log`,
 * `schedule` ({ iterations, function }) and `clear`. The same object is also
 * handed to the `Neat` constructor after `network`, possibly
 * `fitnessPopulation` and possibly `iterations` have been written to it.
 *
 * When a best genome was found, this network's nodes, connections,
 * selfconns and gates are replaced by the genome's.
 *
 * @param {Array} set - samples of the form { input: [...], output: [...] }
 * @param {Object} [options] - see above
 * @returns {Promise<{error: number, iterations: number, time: number}>}
 * @throws {Error} on a dataset/network size mismatch or when neither `error`
 *   nor `iterations` is given
 */
evolve: async function (set, options) {
if (set[0].input.length !== this.input || set[0].output.length !== this.output) {
throw new Error('Dataset input/output size should be same as network input/output size!');
}
// Read the options
options = options || {};
var targetError = typeof options.error !== 'undefined' ? options.error : 0.05;
var growth = typeof options.growth !== 'undefined' ? options.growth : 0.0001;
var cost = options.cost || methods.cost.MSE;
var amount = options.amount || 1;
var threads = options.threads;
if (typeof threads === 'undefined') {
if (typeof window === 'undefined') { // Node.js
threads = require('os').cpus().length;
} else { // Browser
// NOTE(review): if hardwareConcurrency is unavailable, threads stays
// undefined, no workers are created below and the population fitness
// promise would never resolve - confirm and consider a fallback of 1
threads = navigator.hardwareConcurrency;
}
}
var start = Date.now();
if (typeof options.iterations === 'undefined' && typeof options.error === 'undefined') {
throw new Error('At least one of the following options must be specified: error, iterations');
} else if (typeof options.error === 'undefined') {
targetError = -1; // run until iterations
} else if (typeof options.iterations === 'undefined') {
options.iterations = 0; // run until target error
}
var fitnessFunction;
if (threads === 1) {
// Create the fitness function (scores a single genome synchronously)
fitnessFunction = function (genome) {
var score = 0;
for (var i = 0; i < amount; i++) {
score -= genome.test(set, cost).error;
}
// Penalise network size: hidden nodes + connections + gates
// NOTE(review): the penalty is subtracted before dividing by `amount`,
// while the loop below adds the full penalty back to derive the error;
// for amount > 1 the two do not cancel out - confirm intended behaviour
score -= (genome.nodes.length - genome.input - genome.output + genome.connections.length + genome.gates.length) * growth;
score = isNaN(score) ? -Infinity : score; // this can cause problems with fitness proportionate selection
return score / amount;
};
} else {
// Serialize the dataset
var converted = multi.serializeDataSet(set);
// Create workers, send datasets
var workers = [];
if (typeof window === 'undefined') {
for (var i = 0; i < threads; i++) {
workers.push(new multi.workers.node.TestWorker(converted, cost));
}
} else {
for (var i = 0; i < threads; i++) {
workers.push(new multi.workers.browser.TestWorker(converted, cost));
}
}
// Scores a whole population: each worker pulls genomes off a shared queue
// until it is empty; the promise resolves once every worker has finished
fitnessFunction = function (population) {
return new Promise((resolve, reject) => {
// Create a queue
var queue = population.slice();
var done = 0;
// Start worker function
var startWorker = function (worker) {
if (!queue.length) {
if (++done === threads) resolve();
return;
}
var genome = queue.shift();
// NOTE(review): a rejected evaluate() is not handled, so the outer
// promise would stay pending - confirm whether workers can reject
worker.evaluate(genome).then(function (result) {
genome.score = -result;
genome.score -= (genome.nodes.length - genome.input - genome.output +
genome.connections.length + genome.gates.length) * growth;
genome.score = isNaN(parseFloat(result)) ? -Infinity : genome.score;
startWorker(worker);
});
};
for (var i = 0; i < workers.length; i++) {
startWorker(workers[i]);
}
});
};
options.fitnessPopulation = true;
}
// Initialise the NEAT instance
options.network = this;
var neat = new Neat(this.input, this.output, fitnessFunction, options);
// `error` holds the negated error here, hence the comparison with -targetError
var error = -Infinity;
var bestFitness = -Infinity;
var bestGenome;
while (error < -targetError && (options.iterations === 0 || neat.generation < options.iterations)) {
let fittest = await neat.evolve();
let fitness = fittest.score;
// Add the size penalty back to recover the (negated) error from the fitness
error = fitness + (fittest.nodes.length - fittest.input - fittest.output + fittest.connections.length + fittest.gates.length) * growth;
if (fitness > bestFitness) {
bestFitness = fitness;
bestGenome = fittest;
}
if (options.log && neat.generation % options.log === 0) {
console.log('iteration', neat.generation, 'fitness', fitness, 'error', -error);
}
if (options.schedule && neat.generation % options.schedule.iterations === 0) {
options.schedule.function({ fitness: fitness, error: -error, iteration: neat.generation });
}
}
// NOTE(review): workers are not terminated when neat.evolve() throws
if (threads > 1) {
for (var i = 0; i < workers.length; i++) workers[i].terminate();
}
// Adopt the structure of the fittest genome seen during the run
if (typeof bestGenome !== 'undefined') {
this.nodes = bestGenome.nodes;
this.connections = bestGenome.connections;
this.selfconns = bestGenome.selfconns;
this.gates = bestGenome.gates;
if (options.clear) this.clear();
}
return {
error: -error,
iterations: neat.generation,
time: Date.now() - start
};
},
/**
* Creates a standalone function of the network which can be run without the
* need of a library
*/
standalone: function () {
var present = [];
var activations = [];
var states = [];
var lines = [];
var functions = [];
var i;
for (i = 0; i < this.input; i++) {
var node = this.nodes[i];
activations.push(node.activation);
states.push(node.state);
}
lines.push('for(var i = 0; i < input.length; i++) A[i] = input[i];');
// So we don't have to use expensive .indexOf()
for (i = 0; i < this.nodes.length; i++) {
this.nodes[i].index = i;
}
for (i = this.input; i < this.nodes.length; i++) {
let node = this.nodes[i];
activations.push(node.activation);
states.push(node.state);
var functionIndex = present.indexOf(node.squash.name);
if (functionIndex === -1) {
functionIndex = present.length;
present.push(node.squash.name);
functions.push(node.squash.toString());
}
var incoming = [];
for (var j = 0; j < node.connections.in.length; j++) {
var conn = node.connections.in[j];
var computation = `A[${conn.from.index}] * ${conn.weight}`;
if (conn.gater != null) {
computation += ` * A[${conn.gater.index}]`;
}
incoming.push(computation);
}
if (node.connections.self.weight) {
let conn = node.connections.self;
let computation = `S[${i}] * ${conn.weight}`;
if (conn.gater != null) {
computation += ` * A[${conn.gater.index}]`;
}
incoming.push(computation);
}
var line1 = `S[${i}] = ${incoming.join(' + ')} + ${node.bias};`;
var line2 = `A[${i}] = F[${functionIndex}](S[${i}])${!node.mask ? ' * ' + node.mask : ''};`;
lines.push(line1);
lines.push(line2);
}
var output = [];
for (i = this.nodes.length - this.output; i < this.nodes.length; i++) {
output.push(`A[${i}]`);
}
output = `return [${output.join(',')}];`;
lines.push(output);
var total = '';
total += `var F = [${functions.toString()}];\r\n`;
total += `var A = [${activations.toString()}];\r\n`;
total += `var S = [${states.toString()}];\r\n`;
total += `function activate(input){\r\n${lines.join('\r\n')}\r\n}`;
return total;
},
/**
* Serialize to send to workers efficiently
*/
serialize: function () {
var activations = [];
var states = [];
var conns = [];
var squashes = [
'LOGISTIC', 'TANH', 'IDENTITY', 'STEP', 'RELU', 'SOFTSIGN', 'SINUSOID',
'GAUSSIAN', 'BENT_IDENTITY', 'BIPOLAR', 'BIPOLAR_SIGMOID', 'HARD_TANH',
'ABSOLUTE', 'INVERSE', 'SELU'
];
conns.push(this.input);
conns.push(this.output);
var i;
for (i = 0; i < this.nodes.length; i++) {
let node = this.nodes[i];
node.index = i;
activations.push(node.activation);
states.push(node.state);
}
for (i = this.input; i < this.nodes.length; i++) {
let node = this.nodes[i];
conns.push(node.index);
conns.push(node.bias);
conns.push(squashes.indexOf(node.squash.name));
conns.push(node.connections.self.weight);
conns.push(node.connections.self.gater == null ? -1 : node.connections.self.gater.index);
for (var j = 0; j < node.connections.in.length; j++) {
let conn = node.connections.in[j];
conns.push(conn.from.index);
conns.push(conn.weight);
conns.push(conn.gater == null ? -1 : conn.gater.index);
}
conns.push(-2); // stop token -> next node
}
return [activations, states, conns];
}
};
/**
 * Converts a json object (as produced by `toJSON`) to a network.
 *
 * The nodes are rebuilt with `Node.fromJSON`; connections reference nodes by
 * their position in `json.nodes` and get their stored weight and gater back.
 *
 * @param {Object} json - { input, output, dropout, nodes, connections }
 * @returns {Network} the reconstructed network
 */
Network.fromJSON = function (json) {
  var restored = new Network(json.input, json.output);
  restored.dropout = json.dropout;

  // Discard the default nodes/connections created by the constructor
  restored.nodes = json.nodes.map(function (nodeJson) {
    return Node.fromJSON(nodeJson);
  });
  restored.connections = [];

  json.connections.forEach(function (gene) {
    var link = restored.connect(restored.nodes[gene.from], restored.nodes[gene.to])[0];
    link.weight = gene.weight;

    if (gene.gater != null) {
      restored.gate(restored.nodes[gene.gater], link);
    }
  });

  return restored;
};
/**
 * Merges two networks into one by feeding the outputs of `network1` into
 * `network2`. Both arguments are copied first (via toJSON/fromJSON), so the
 * originals are not modified.
 *
 * Differences from the previous implementation:
 * - output i of network1 now replaces input i of network2 (the old index
 *   arithmetic wired them up in reverse order);
 * - the merged network reports network2's output size in `output`;
 * - network2's self-connections and gates are carried over;
 * - redirected connections are also moved between the sending nodes'
 *   `connections.out` lists, so both ends of a connection agree;
 * - the size check runs before the networks are copied.
 *
 * @param {Network} network1 - the network that comes first
 * @param {Network} network2 - the network that consumes network1's outputs
 * @returns {Network} a new network with network1's inputs and network2's outputs
 * @throws {Error} when network1's output size differs from network2's input size
 */
Network.merge = function (network1, network2) {
  // Check if output and input size are the same
  if (network1.output !== network2.input) {
    throw new Error('Output size of network1 should be the same as the input size of network2!');
  }

  // Create a copy of the networks
  network1 = Network.fromJSON(network1.toJSON());
  network2 = Network.fromJSON(network2.toJSON());

  // Position of network1's first output node
  var firstOutput = network1.nodes.length - network1.output;

  // Redirect all connections from network2 input from network1 output
  var i;
  for (i = 0; i < network2.connections.length; i++) {
    let conn = network2.connections[i];
    if (conn.from.type === 'input') {
      let index = network2.nodes.indexOf(conn.from);
      let oldFrom = conn.from;
      let newFrom = network1.nodes[firstOutput + index];

      // Keep the node-level bookkeeping in sync with the redirect
      let position = oldFrom.connections.out.indexOf(conn);
      if (position !== -1) oldFrom.connections.out.splice(position, 1);
      newFrom.connections.out.push(conn);

      conn.from = newFrom;
    }
  }

  // Delete input nodes of network2
  network2.nodes.splice(0, network2.input);

  // Change the node type of network1's output nodes (now hidden)
  for (i = firstOutput; i < network1.nodes.length; i++) {
    network1.nodes[i].type = 'hidden';
  }

  // Create one network from both networks
  network1.connections = network1.connections.concat(network2.connections);
  network1.selfconns = network1.selfconns.concat(network2.selfconns);
  network1.gates = network1.gates.concat(network2.gates);
  network1.nodes = network1.nodes.concat(network2.nodes);

  // The outputs of the merged network are network2's outputs
  network1.output = network2.output;

  return network1;
};
/**
 * Creates an offspring from two parent networks.
 *
 * The offspring's node count comes from the parent with the higher `score`,
 * or is drawn at random between both parents' sizes when `equal` is truthy
 * or the scores tie. Nodes are copied (bias, squash, type) from a randomly
 * chosen parent per position. Connection genes are matched through
 * `Connection.innovationID(from, to)`: genes both parents have are taken
 * from a random parent, the others only from the parent(s) with the
 * better-or-equal score (or from both when `equal`). As a side effect the
 * nodes of both parents get their position stored in `index`.
 *
 * @param {Network} network1 - first parent
 * @param {Network} network2 - second parent
 * @param {boolean} [equal] - treat both parents as equally fit
 * @returns {Network} the offspring
 * @throws {Error} when the parents differ in input or output size
 */
Network.crossOver = function (network1, network2, equal) {
  if (network1.input !== network2.input || network1.output !== network2.output) {
    throw new Error("Networks don't have the same input/output size!");
  }

  // Start from an empty shell with the right input/output size
  var child = new Network(network1.input, network1.output);
  child.connections = [];
  child.nodes = [];

  var fitness1 = network1.score || 0;
  var fitness2 = network2.score || 0;

  var length1 = network1.nodes.length;
  var length2 = network2.nodes.length;

  // Determine offspring node size
  var size;
  if (equal || fitness1 === fitness2) {
    var largest = Math.max(length1, length2);
    var smallest = Math.min(length1, length2);
    size = Math.floor(Math.random() * (largest - smallest + 1) + smallest);
  } else {
    size = fitness1 > fitness2 ? length1 : length2;
  }

  var outputCount = network1.output;

  // Set indexes so we don't need indexOf
  network1.nodes.forEach(function (node, position) { node.index = position; });
  network2.nodes.forEach(function (node, position) { node.index = position; });

  // Assign nodes from parents to offspring
  for (var slot = 0; slot < size; slot++) {
    var template;

    if (slot < size - outputCount) {
      // Non-output slot: take a random parent, fall back to the other one
      // when the pick is missing or happens to be an output node
      var coin = Math.random();
      var preferred = coin >= 0.5 ? network1.nodes[slot] : network2.nodes[slot];
      var fallback = coin < 0.5 ? network1.nodes[slot] : network2.nodes[slot];
      template = (typeof preferred === 'undefined' || preferred.type === 'output') ? fallback : preferred;
    } else if (Math.random() >= 0.5) {
      // Output slot: counted from the end of the chosen parent's node list
      template = network1.nodes[length1 + slot - size];
    } else {
      template = network2.nodes[length2 + slot - size];
    }

    var copy = new Node();
    copy.bias = template.bias;
    copy.squash = template.squash;
    copy.type = template.type;
    child.nodes.push(copy);
  }

  // Builds a lookup of a parent's connection genes (normal connections
  // first, then self-connections), keyed by innovation id
  var collectGenes = function (parent) {
    var genes = {};
    parent.connections.concat(parent.selfconns).forEach(function (conn) {
      var data = {
        weight: conn.weight,
        from: conn.from.index,
        to: conn.to.index,
        gater: conn.gater != null ? conn.gater.index : -1
      };
      genes[Connection.innovationID(data.from, data.to)] = data;
    });
    return genes;
  };

  var genes1 = collectGenes(network1);
  var genes2 = collectGenes(network2);

  // Split common conn genes from disjoint or excess conn genes
  var selected = [];
  var ids1 = Object.keys(genes1);
  var ids2 = Object.keys(genes2);
  var keepOnlyIn1 = fitness1 >= fitness2 || equal;

  for (var k = ids1.length - 1; k >= 0; k--) {
    var id = ids1[k];
    if (typeof genes2[id] !== 'undefined') {
      // Common gene: pick either parent's version
      selected.push(Math.random() >= 0.5 ? genes1[id] : genes2[id]);
      // Because deleting is expensive, just mark it as handled
      genes2[id] = undefined;
    } else if (keepOnlyIn1) {
      selected.push(genes1[id]);
    }
  }

  // Excess/disjoint genes of the second parent
  if (fitness2 >= fitness1 || equal) {
    ids2.forEach(function (id2) {
      if (typeof genes2[id2] !== 'undefined') {
        selected.push(genes2[id2]);
      }
    });
  }

  // Recreate every selected gene whose endpoints exist in the offspring
  selected.forEach(function (gene) {
    if (gene.to < size && gene.from < size) {
      var link = child.connect(child.nodes[gene.from], child.nodes[gene.to])[0];
      link.weight = gene.weight;

      if (gene.gater !== -1 && gene.gater < size) {
        child.gate(child.nodes[gene.gater], link);
      }
    }
  });

  return child;
};