@webwriter/neural-network
Version:
Deep learning visualization for feed-forward networks with custom datasets, training and prediction.
140 lines (121 loc) • 4.73 kB
text/typescript
import { LitElementWw } from '@webwriter/lit'
import { TemplateResult, html } from 'lit'
import { customElement, queryAll, property } from 'lit/decorators.js'
import { consume } from '@lit/context'
import type { CLayerConf } from '@/types/c_layer_conf'
import type { TensorConf } from '@/types/tensor_conf'
import type { CLayerConnectionConf } from '@/types/c_layer_connection_conf'
import { layerConfsContext } from '@/contexts/layer_confs_context'
import { layerConnectionConfsContext } from '@/contexts/layer_con_confs_context'
import { CLayer } from '@/components/network/c_layer'
import { InputLayer } from '@/components/network/input_layer'
import { OutputLayer } from '@/components/network/output_layer'
import { CLayerConnection } from '@/components/network/c_layer_connection'
import { DenseLayer } from '@/components/network/dense_layer'
import '@/components/network/input_layer'
import '@/components/network/dense_layer'
import '@/components/network/output_layer'
import '@/components/network/c_layer'
import '@/components/network/c_layer_connection'
import { CEdge } from './c_edge'
import { CNeuron } from './neuron'
/**
 * Renders a feed-forward network: one layer element per entry of
 * `layerConfs` and one `c-layer-connection` per entry of
 * `layerConnectionConfs`. Also provides lookup helpers over the rendered
 * children.
 */
export class CNetwork extends LitElementWw {
  /** Custom elements available in this component's scoped registry. */
  static scopedElements = {
    "c-edge": CEdge,
    "c-layer-connection": CLayerConnection,
    "dense-layer": DenseLayer,
    "input-layer": InputLayer,
    "output-layer": OutputLayer,
    "c-neuron": CNeuron
  }

  // NOTE(review): the file imports `consume`, `queryAll`, `property` and the
  // layer/connection contexts, so these accessors are presumably meant to be
  // decorated (context consumers / queryAll). No decorators are visible
  // here — confirm. The methods below tolerate the fields being unset.

  // Layer configurations; one layer element is rendered per entry.
  accessor layerConfs: CLayerConf[]
  // Connection configurations; one c-layer-connection is rendered per entry.
  accessor layerConnectionConfs: CLayerConnectionConf[]
  // Tensor data keyed by layerId, copied onto layer elements at render time.
  accessor tensorConfs: Map<number, TensorConf> = new Map()
  // Rendered layer elements, searched by getLayerById.
  accessor _layers: NodeListOf<CLayer>
  // Rendered connection elements, searched by getLayerConnectionByLayerIds.
  accessor _layerConnections: NodeListOf<CLayerConnection>

  // METHODS - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  // -> GETTING - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  /**
   * Finds the rendered layer element whose `conf.layerId` equals `layerId`.
   *
   * @returns The matching layer. Returns `undefined` when no such element is
   *   rendered (the declared type is kept as `CLayer` for caller
   *   compatibility).
   */
  getLayerById(layerId: number): CLayer {
    return Array.from(this._layers ?? []).find(
      (layer) => layer.conf.layerId === layerId
    )
  }

  /**
   * Returns every rendered layer that is an `InputLayer`, in `layerConfs`
   * order. Configurations without a rendered element are skipped.
   */
  getInputLayers(): InputLayer[] {
    return (this.layerConfs ?? [])
      .map((layerConf) => this.getLayerById(layerConf.layerId))
      .filter((layer): layer is InputLayer => layer instanceof InputLayer)
  }

  /**
   * Returns the first rendered layer (in `layerConfs` order) that is an
   * `OutputLayer`, or `undefined` if none is rendered.
   */
  getOutputLayer(): OutputLayer {
    return (this.layerConfs ?? [])
      .map((layerConf) => this.getLayerById(layerConf.layerId))
      .find((layer): layer is OutputLayer => layer instanceof OutputLayer)
  }

  /**
   * Finds the rendered connection element from `sourceLayerId` to
   * `targetLayerId`.
   *
   * @returns The matching connection, or `undefined` if none is rendered.
   */
  getLayerConnectionByLayerIds(
    sourceLayerId: number,
    targetLayerId: number
  ): CLayerConnection {
    return Array.from(this._layerConnections ?? []).find(
      (layerConnection) =>
        layerConnection.conf.sourceLayerId === sourceLayerId &&
        layerConnection.conf.targetLayerId === targetLayerId
    )
  }

  /**
   * Returns the target (subsequent) layers of `source`, following
   * `layerConnectionConfs`. Targets without a rendered element are omitted,
   * so callers never receive `undefined` entries.
   */
  getTargetsFor(source: CLayer): CLayer[] {
    return (this.layerConnectionConfs ?? [])
      .filter((conConf) => conConf.sourceLayerId === source.conf.layerId)
      .map((conConf) => this.getLayerById(conConf.targetLayerId))
      .filter((layer) => layer !== undefined)
  }

  /**
   * Returns the source (preceding) layers of `target`, following
   * `layerConnectionConfs`. Sources without a rendered element are omitted,
   * so callers never receive `undefined` entries.
   */
  getSourcesFor(target: CLayer): CLayer[] {
    return (this.layerConnectionConfs ?? [])
      .filter((conConf) => conConf.targetLayerId === target.conf.layerId)
      .map((conConf) => this.getLayerById(conConf.sourceLayerId))
      .filter((layer) => layer !== undefined)
  }

  // RENDER - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  /**
   * Creates a layer element for `layerConf`. The element's tag comes from
   * `layerConf.HTML_TAG`. Its tensor, bias and weights come from the
   * `tensorConfs` entry for the layer's id, if any, and it gets the `layer`
   * class.
   */
  getHTMLForLayerConf(layerConf: CLayerConf): CLayer {
    // `createElement` on the shadow root is presumably provided by the
    // scoped-registry support behind `scopedElements`. It is not part of the
    // standard ShadowRoot type, hence the cast.
    const layer = (this.shadowRoot as any).createElement(layerConf.HTML_TAG)
    layer.conf = layerConf
    const tensorConf = this.tensorConfs?.get(layerConf.layerId)
    layer.tensor = tensorConf?.tensor
    layer.bias = tensorConf?.bias
    layer.weights = tensorConf?.weights
    layer.classList.add('layer')
    return layer
  }

  /** Creates a `c-layer-connection` element for `layerConnectionConf`. */
  getHTMLForLayerConnectionConf(
    layerConnectionConf: CLayerConnectionConf
  ): CLayerConnection {
    const layerConnection = <CLayerConnection>(
      (this.shadowRoot as any).createElement('c-layer-connection')
    )
    layerConnection.conf = layerConnectionConf
    return layerConnection
  }

  /**
   * Renders one layer element per layer configuration and one connection
   * element per connection configuration. Missing configuration arrays
   * render as empty.
   */
  render(): TemplateResult<1> {
    // NOTE(review): new element instances are created on every render, so
    // existing layer/connection elements are replaced rather than updated.
    return html`
      <div id="layers">
        ${(this.layerConfs ?? []).map((layerConf) =>
          this.getHTMLForLayerConf(layerConf)
        )}
      </div>
      <div id="layerConnections">
        ${(this.layerConnectionConfs ?? []).map((layerConnectionConf) =>
          this.getHTMLForLayerConnectionConf(layerConnectionConf)
        )}
      </div>
    `
  }
}