@webwriter/neural-network
Version:
Deep learning visualization for feed-forward networks with custom datasets, training and prediction.
409 lines (371 loc) • 14.8 kB
text/typescript
import type { ReactiveController } from 'lit'
import type { NeuralNetwork } from '@/app'
import type { CLayerConf } from '@/types/c_layer_conf'
import type { CLayerConnectionConf } from '@/types/c_layer_connection_conf'
import { CLayer } from '@/components/network/c_layer'
import { AlertUtils } from '@/utils/alert_utils'
import { InputLayerConf } from '@/types/input_layer_conf'
import { OutputLayerConf } from '@/types/output_layer_conf'
import * as tf from '@tensorflow/tfjs'
// Reactive controller that manages the structure of the network displayed by
// a NeuralNetwork host: it reacts to layer/connection events and keyboard
// shortcuts by adding, updating, moving, duplicating and removing layers and
// layer connections, builds a tfjs model from the current layer graph and
// writes trained weights back into the network's tensor configurations.
export class NetworkController implements ReactiveController {
  host: NeuralNetwork

  constructor(host: NeuralNetwork) {
    this.host = host
    host.addController(this)
  }

  // EVENT HANDLERS - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Every listener is stored as a stable function reference so that
  // hostDisconnected can remove exactly the functions hostConnected added.
  // Inline arrow functions can never be removed, so each reconnect of the host
  // would stack another copy of every listener (one 'layer-conf-created' event
  // would then add the same layer several times).
  private readonly onClearNetwork = (_e: Event): void => {
    this.clearNetwork()
  }

  private readonly onLayerConfCreated = (e: CustomEvent<CLayerConf>): void => {
    this.addLayer(e.detail)
  }

  private readonly onLayerUpdated = (e: CustomEvent<number>): void => {
    this.updateLayer(e.detail)
  }

  private readonly onQueryLayerDeletion = (e: CustomEvent<CLayer>): void => {
    this.removeLayer(e.detail)
  }

  private readonly onAddLayerConnection = (
    e: CustomEvent<{ source: number; target: number }>
  ): void => {
    this.addLayerConnection(e.detail.source, e.detail.target)
  }

  private readonly onUpdateLayerConfs = (_e: Event): void => {
    this.updateLayerConfs()
  }

  private readonly onRemoveLayerConnection = (
    e: CustomEvent<{ source: number; target: number }>
  ): void => {
    this.removeLayerConnection(e.detail.source, e.detail.target)
  }

  private readonly onKeyUp = (e: KeyboardEvent): void => {
    this.removeLayerListener(e)
    this.duplicateLayerListener(e)
  }

  private readonly onKeyDown = (e: KeyboardEvent): void => {
    this.moveLayerListener(e)
  }

  // HOST LIFECYCLE - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Registers the listeners for network related events and keyboard
  // shortcuts. They are removed again in hostDisconnected.
  hostConnected(): void {
    const root = this.host.renderRoot
    root.addEventListener('clear-network', this.onClearNetwork)
    // layers emit layer-conf-created after their static factory method was
    // called; the freshly created layer then gets a unique id and is added to
    // the network. This is listened for on window since menu options can also
    // trigger it.
    // NOTE(review): window-level listeners are reached by every widget
    // instance on the page — confirm only one instance is expected to react.
    window.addEventListener('layer-conf-created', this.onLayerConfCreated)
    // neurons emit this event when they are rerendered, so the layer
    // connections can be forced to rerender too
    root.addEventListener('layer-updated', this.onLayerUpdated)
    // a layer deletion can be requested by the layers themselves (e.g.
    // because no data was assigned to them) or by the UI
    window.addEventListener('query-layer-deletion', this.onQueryLayerDeletion)
    root.addEventListener('add-layer-connection', this.onAddLayerConnection)
    root.addEventListener('update-layer-confs', this.onUpdateLayerConfs)
    root.addEventListener('remove-layer-connection', this.onRemoveLayerConnection)
    // keyboard shortcuts for removing, duplicating and moving layers
    window.addEventListener('keyup', this.onKeyUp)
    window.addEventListener('keydown', this.onKeyDown)
  }

  // Removes every listener registered in hostConnected so a disconnected host
  // no longer reacts to events and a reconnect does not duplicate listeners.
  hostDisconnected(): void {
    const root = this.host.renderRoot
    root.removeEventListener('clear-network', this.onClearNetwork)
    window.removeEventListener('layer-conf-created', this.onLayerConfCreated)
    root.removeEventListener('layer-updated', this.onLayerUpdated)
    window.removeEventListener('query-layer-deletion', this.onQueryLayerDeletion)
    root.removeEventListener('add-layer-connection', this.onAddLayerConnection)
    root.removeEventListener('update-layer-confs', this.onUpdateLayerConfs)
    root.removeEventListener('remove-layer-connection', this.onRemoveLayerConnection)
    window.removeEventListener('keyup', this.onKeyUp)
    window.removeEventListener('keydown', this.onKeyDown)
  }

  // As soon as the c-network component has been rendered, store it on the
  // host's network property so that other components can access it.
  hostUpdated(): void {
    if (this.host.network) return
    const network = this.host.renderRoot.querySelector('c-network')
    if (network) {
      this.host.network = network
    }
  }

  // METHODS - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Shows a warning alert with the given message.
  private warn(message: string): void {
    AlertUtils.spawn({
      message,
      variant: 'warning',
      icon: 'x-circle',
    })
  }

  // -> ADDING - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Returns an id that no current layer uses: one more than the largest
  // existing id, or 1 when the network has no layers yet.
  private getFreshId(): number {
    const ids = this.host.layerConfs.map((layerConf) => layerConf.layerId)
    return ids.length ? Math.max(...ids) + 1 : 1
  }

  // Adds a layer configuration to the network. The conf gets a fresh id, input
  // layers without feature keys get all features of the data set, output
  // layers get the data set's label, and a position is generated if the conf
  // has none.
  addLayer(layerConf: CLayerConf): void {
    layerConf['layerId'] = this.getFreshId()
    if (layerConf.LAYER_TYPE === 'Input') {
      const inputConf = <InputLayerConf>layerConf
      if (!inputConf.featureKeys) {
        inputConf.featureKeys = this.host.dataSet.featureDescs.map(
          (featureDesc) => featureDesc.key
        )
      }
    } else if (layerConf.LAYER_TYPE === 'Output') {
      ;(<OutputLayerConf>layerConf).labelDesc = this.host.dataSet.labelDesc
    }
    if (!layerConf['pos']) {
      layerConf['pos'] = this.host.canvas.generatePos()
    }
    // reassign a new array so the host registers the change
    this.host.layerConfs = [...this.host.layerConfs, layerConf]
  }

  // Duplicates the selected layer when Ctrl+K is pressed.
  duplicateLayerListener(e: KeyboardEvent): void {
    if (!e.ctrlKey || e.code !== 'KeyK') return
    if (this.host.selected.layer && this.host.selectedEle) {
      ;(<CLayer>this.host.selectedEle).duplicate()
    }
  }

  // -> UPDATING - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Requests a rerender of every layer connection attached to the given layer,
  // since the layer has changed and been rerendered.
  updateLayer(layerId: number): void {
    for (const conConf of this.host.layerConnectionConfs) {
      // loose equality kept deliberately: ids arriving via events are not
      // guaranteed to be numbers at runtime
      const isAffected =
        conConf.sourceLayerId == layerId || conConf.targetLayerId == layerId
      if (!isAffected) continue
      // the connection element may not have been rendered yet
      this.host.network
        ?.getLayerConnectionByLayerIds(
          conConf.sourceLayerId,
          conConf.targetLayerId
        )
        ?.requestUpdate()
    }
  }

  // Reassigns the layer confs array so the host rerenders the layers.
  updateLayerConfs(): void {
    this.host.layerConfs = [...this.host.layerConfs]
  }

  // Keyboard codes that move the selected layer, mapped to the axis and the
  // offset passed to cytoscape's shift().
  private static readonly LAYER_MOVES: ReadonlyMap<
    string,
    readonly ['x' | 'y', number]
  > = new Map<string, readonly ['x' | 'y', number]>([
    ['ArrowUp', ['y', -10]],
    ['ArrowLeft', ['x', -10]],
    ['ArrowDown', ['y', 10]],
    ['ArrowRight', ['x', 10]],
  ])

  // Moves the selected layer when Ctrl+Shift+<arrow key> is pressed.
  moveLayerListener(e: KeyboardEvent): void {
    if (!this.host.selected.layer || !e.ctrlKey || !e.shiftKey) return
    const move = NetworkController.LAYER_MOVES.get(e.code)
    if (!move) return
    const layer = this.host.network?.getLayerById(
      parseInt(this.host.selected.layer)
    )
    // the selection may point to a layer that no longer exists
    if (!layer) return
    const [axis, offset] = move
    this.host.canvas.cy.getElementById(layer.getCyId()).shift(axis, offset)
  }

  // -> REMOVING - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Resets the network by emptying the layer and connection confs.
  clearNetwork(): void {
    // TODO replace with event?
    // deselect the currently selected element since it will be removed
    this.host.selectionController.unselect()
    this.host.layerConnectionConfs = []
    this.host.layerConfs = []
  }

  // Removes a layer, together with all connections from and to it, from the
  // network. Removing the conf triggers the layer's disconnectedCallback, which
  // handles the removal of the layer element itself. Only done if the host is
  // editable or its settings allow adding/removing layers.
  removeLayer(layer: CLayer): void {
    if (!this.host.editable && !this.host.settings.mayAddAndRemoveLayers) {
      return
    }
    // the deletion event may have been dispatched without a layer
    if (!layer) return
    const layerId = layer.conf.layerId

    // drop every connection from or to this layer in one update
    const remainingConnections = this.host.layerConnectionConfs.filter(
      (conConf) =>
        !(conConf.sourceLayerId == layerId || conConf.targetLayerId == layerId)
    )
    if (
      remainingConnections.length !== this.host.layerConnectionConfs.length
    ) {
      this.host.layerConnectionConfs = remainingConnections
    }

    // remove the reference to the layer in the layers array
    const index = this.host.layerConfs.findIndex(
      (layerConf) => layerConf.layerId == layerId
    )
    if (index > -1) {
      this.host.layerConfs.splice(index, 1)
      this.host.layerConfs = [...this.host.layerConfs]
    }

    this.host.selectionController.unselect()
    AlertUtils.spawn({
      message: `'${layer.getName()}' has been deleted!`,
      variant: 'danger',
      icon: 'trash',
    })
  }

  // Handles Ctrl+Shift+Backspace: removes the selected layer, or explains why a
  // selected neuron or edge cannot be deleted.
  removeLayerListener(e: KeyboardEvent): void {
    if (!e.ctrlKey || !e.shiftKey || e.code !== 'Backspace') return
    if (this.host.selected.layer) {
      const layer = this.host.network?.getLayerById(
        parseInt(this.host.selected.layer)
      )
      if (layer) {
        this.removeLayer(layer)
      }
    } else if (this.host.selected.neuron) {
      this.warn(
        `You can not delete neurons by now! To adjust the number of neurons in the layer, select the layer and set the number of neurons in the right panel!`
      )
    } else if (this.host.selected.edge) {
      this.warn(
        `Can not delete edges manually. If you wish to delete all connections between two layers, select one of the affected layers and change its input`
      )
    }
  }

  // -> CONNECTING - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Adds a connection from the source layer to the target layer (given by
  // their ids) to the layer connections configuration.
  addLayerConnection(source: number, target: number): void {
    const layerConnectionConf: CLayerConnectionConf = {
      sourceLayerId: source,
      targetLayerId: target,
    }
    this.host.layerConnectionConfs.push(layerConnectionConf)
    this.host.layerConnectionConfs = [...this.host.layerConnectionConfs]
  }

  // Removes the first connection from the source layer to the target layer
  // (given by their ids) from the layer connections configuration, if any.
  removeLayerConnection(source: number, target: number): void {
    const index = this.host.layerConnectionConfs.findIndex(
      (conConf) =>
        conConf.sourceLayerId == source && conConf.targetLayerId == target
    )
    if (index > -1) {
      this.host.layerConnectionConfs.splice(index, 1)
      this.host.layerConnectionConfs = [...this.host.layerConnectionConfs]
    }
  }

  // -> MODEL - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  // ---> BUILD - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Builds a tfjs model from the current layer graph. Every layer builds its
  // symbolic tensor from its sources' tensors, starting at the input layers.
  // Returns null (after showing a warning) when there is no input layer, no
  // output layer, or the output layer is not reachable from the inputs.
  buildModel(): tf.LayersModel {
    const network = this.host.network
    const inputLayers = network.getInputLayers()
    if (!inputLayers.length) {
      this.warn('Your network must contain at least one input layer')
      return null
    }
    const outputLayer = network.getOutputLayer()
    if (!outputLayer) {
      this.warn('Your network must contain at least one output layer')
      return null
    }

    // Layers waiting to be built, seeded with the input layers since they have
    // no preconditions. A copy is used so that getInputLayers' result is never
    // mutated. Layers whose sources are not all built yet are skipped; they
    // are queued again once another of their sources has been built.
    const buildQueue: CLayer[] = [...inputLayers]
    while (buildQueue.length) {
      const layer = buildQueue[0]
      const sources = network.getSourcesFor(layer)
      if (sources.every((source) => source.tensor)) {
        const tensor = layer.build(sources.map((source) => source.tensor))
        network.tensorConfs.set(layer.conf.layerId, { tensor })
        for (const target of network.getTargetsFor(layer)) {
          // a target that is still pending is processed later anyway; queuing
          // it again would only build it (and its successors) repeatedly
          const isQueued = buildQueue.some(
            (queued) => queued.conf.layerId == target.conf.layerId
          )
          if (!isQueued) {
            buildQueue.push(target)
          }
        }
      }
      // done with the current layer
      buildQueue.shift()
    }
    network.tensorConfs = new Map(network.tensorConfs)

    // without a tensor the output layer is not connected to the inputs, and
    // building the model would fail
    if (!outputLayer.tensor) {
      this.warn('Make sure to have an output layer connected to the network!')
      return null
    }

    const inputs: tf.SymbolicTensor[] = inputLayers.map(
      (inputLayer) => inputLayer.tensor
    )
    const output: tf.SymbolicTensor = outputLayer.tensor
    return tf.model({ inputs, outputs: output })
  }

  // ---> UPDATE WEIGHTS - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  // Copies trained weight values into the network's tensor confs. Weight
  // tensors are identified by their name: '<layerId>/<weightType>' where the
  // weight type is 'kernel' or 'bias', optionally followed by '_<suffix>'.
  // Malformed names or unknown layers are logged and skipped so the remaining
  // weights are still applied and the update is propagated.
  updateWeights(weights: tf.Tensor[]): void {
    const tensorConfs = this.host.network.tensorConfs
    for (const weight of weights) {
      const nameParts = weight.name.split('/')
      if (nameParts.length !== 2) {
        console.error(
          `malformed weight name string "${weight.name}": expected "<layerId>/<weightType>"`
        )
        continue
      }
      const layerId = parseInt(nameParts[0])
      const weightType = nameParts[1].split('_')[0]
      if (weightType !== 'kernel' && weightType !== 'bias') {
        console.error('malformed weight name string: weightType')
        continue
      }
      const tensorConf = tensorConfs.get(layerId)
      if (!tensorConf) {
        console.error(`no tensor configuration found for layer ${layerId}`)
        continue
      }
      const values = <Float32Array>weight.dataSync()
      if (weightType === 'bias') {
        tensorConf.bias = values
      } else {
        tensorConf.weights = values
      }
    }
    // reassign a new map so the network registers the change
    this.host.network.tensorConfs = new Map(tensorConfs)
  }
}