UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

145 lines (116 loc) 4.63 kB
"use strict";

/*
 * Dense (fully-connected) layer — compiled CommonJS build.
 *
 * Exports a single class, `Dense`, as `exports.default`. The layer computes
 * `activation(x · kernel + bias)` either on the CPU (ndarray + BLAS-style gemv)
 * or on the GPU (a WebGL2 fragment-shader matmul followed by an optional
 * activation pass).
 */
Object.defineProperty(exports, "__esModule", { value: true });
exports.default = void 0;

// Module loads are kept in their original order in case any module has load-time side effects.
const _Layer = _interopRequireDefault(require("../../Layer"));
const _Tensor = _interopRequireDefault(require("../../Tensor"));
const activations = _interopRequireWildcard(require("../../activations"));
const _WebGL = require("../../WebGL2");
const _ndarrayBlasLevel = require("ndarray-blas-level2");
const _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));
const activationProgramSources = _interopRequireWildcard(require("../../activations/programSources"));

/**
 * Babel interop helper: present a required module as an ES-module namespace.
 *
 * ES modules (flagged with `__esModule`) are returned unchanged. Any other value
 * has its own properties copied onto a fresh object (getter/setter properties are
 * re-defined with their descriptors), and the original value is exposed as `.default`.
 *
 * @param {*} obj - the value returned by `require(...)`
 * @returns {Object} a namespace-like object
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  const newObj = {};
  if (obj != null) {
    for (const key in obj) {
      if (Object.prototype.hasOwnProperty.call(obj, key)) {
        const desc = Object.defineProperty && Object.getOwnPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : {};
        if (desc.get || desc.set) {
          Object.defineProperty(newObj, key, desc);
        } else {
          newObj[key] = obj[key];
        }
      }
    }
  }
  newObj.default = obj;
  return newObj;
}

/**
 * Babel interop helper: make `.default` access work for CommonJS modules.
 *
 * @param {*} obj - the value returned by `require(...)`
 * @returns {Object} `obj` itself if it is an ES module, otherwise `{ default: obj }`
 */
function _interopRequireDefault(obj) {
  return obj && obj.__esModule ? obj : { default: obj };
}

/*
 * GLSL ES 3.00 fragment shader for the matrix product.
 *
 * For the output texel at (out_x, out_y) it computes
 *   sum_i A[i, out_y] * B[out_x, i]   (texel coordinates are (x, y)),
 * where the shared dimension is A's texture width. When `addC` is true it also adds
 * C[out_x, 0], which is the bias row. Only the red channel of each input is read,
 * and the result is broadcast to all four output channels.
 */
const matMulProgramSource = "#version 300 es\nprecision highp float;\n\nin vec2 outTex;\nuniform sampler2D A;\nuniform sampler2D B;\nuniform sampler2D C;\nuniform bool addC;\nout vec4 outColor;\n\nvoid main() {\n ivec2 A_size = textureSize(A, 0);\n ivec2 B_size = textureSize(B, 0);\n int out_x = int(float(B_size[0]) * outTex.x);\n int out_y = int(float(A_size[1]) * outTex.y);\n int commonDim = A_size[0];\n\n float sum = 0.;\n for (int i = 0; i < commonDim; ++i) {\n float a = texelFetch(A, ivec2(i, out_y), 0).r;\n float b = texelFetch(B, ivec2(out_x, i), 0).r;\n sum += a * b;\n }\n\n if (addC) {\n sum += texelFetch(C, ivec2(out_x, 0), 0).r;\n }\n\n outColor = vec4(sum);\n}\n";

/**
 * Fully-connected layer: output = activation(x · kernel + bias).
 *
 * Weights are read from `this.weights['kernel']` and, when `use_bias` is true,
 * `this.weights['bias']`. Those weights, along with `this.gpu` and `this.outbound`,
 * are presumably populated by the `Layer` base class or its callers; they are not
 * set in this file.
 */
class Dense extends _Layer.default {
  /**
   * @param {Object} [attrs={}] - layer configuration, also forwarded to `Layer`
   * @param {number} [attrs.units=1] - output dimension
   * @param {string} [attrs.activation='linear'] - activation name. It is used to look
   *   up the CPU function in `activations` and, when on the GPU, the shader source
   *   in `activationProgramSources`.
   * @param {?number} [attrs.input_dim=null] - when truthy, fixes `inputShape` to `[input_dim]`
   * @param {boolean} [attrs.use_bias=true] - whether a `bias` weight is expected and added
   */
  constructor(attrs = {}) {
    super(attrs);
    this.layerClass = 'Dense';

    const { units = 1, activation = 'linear', input_dim = null, use_bias = true } = attrs;

    this.description = `${activation} activation, output dimensions: ${units}`;
    this.activation = activation;
    this.activationFunc = activations[this.activation];
    this.units = units;
    this.input_dim = input_dim;
    this.use_bias = use_bias;

    // Names of the weight tensors this layer expects to be loaded.
    this.params = this.use_bias ? ['kernel', 'bias'] : ['kernel'];

    if (this.input_dim) {
      this.inputShape = [this.input_dim];
    }

    if (this.gpu) {
      this.matMulProgram = _WebGL.webgl2.compileProgram(matMulProgramSource);
      // _callGPU skips the activation pass entirely for 'linear', writing the matmul
      // straight into `this.output`. Compiling a shader that would never run is
      // wasted GPU work, so only non-linear activations get a program.
      if (this.activation !== 'linear') {
        this.activationProgram = _WebGL.webgl2.compileProgram(activationProgramSources[this.activation]);
      }
    }
  }

  /**
   * Forward pass. Runs on the GPU when `this.gpu` is truthy and on the CPU otherwise.
   *
   * @param {Tensor} x - input tensor (presumably 1-D of length input_dim; not checked here)
   * @returns {Tensor} `this.output`
   */
  call(x) {
    if (this.gpu) {
      this._callGPU(x);
    } else {
      this._callCPU(x);
    }
    return this.output;
  }

  /**
   * CPU forward pass. Allocates a fresh `this.output` of shape [units] on every call.
   *
   * @param {Tensor} x - input tensor
   */
  _callCPU(x) {
    this.output = new _Tensor.default([], [this.units]);

    // Seed the output with the bias. The gemv below uses beta = 1, so this value is
    // accumulated into rather than overwritten (BLAS convention y = alpha*A*x + beta*y).
    // Without a bias, this relies on a freshly constructed Tensor being zero-filled.
    if (this.use_bias) {
      _ndarrayOps.default.assign(this.output.tensor, this.weights['bias'].tensor);
    }

    // The kernel is transposed before gemv, so output[j] = sum_i kernel[i][j] * x[i].
    (0, _ndarrayBlasLevel.gemv)(1, this.weights['kernel'].tensor.transpose(1, 0), x.tensor, 1, this.output.tensor);

    this.activationFunc(this.output);
  }

  /**
   * GPU forward pass.
   *
   * Output textures are allocated lazily on the first call and reused afterwards.
   * For a non-linear activation, the matmul writes into `outputPreactiv` and a second
   * shader pass writes the activated result into `output`. For 'linear', the matmul
   * writes directly into `output`.
   *
   * @param {Tensor} x - input tensor
   */
  _callGPU(x) {
    if (!x.glTexture) {
      x.createGLTexture({ type: '2d', format: 'float' });
    }

    if (this.activation !== 'linear' && !this.outputPreactiv) {
      this.outputPreactiv = new _Tensor.default([], [this.units]);
      this.outputPreactiv.createGLTexture({ type: '2d', format: 'float' });
    }

    if (!this.output) {
      this.output = new _Tensor.default([], [this.units]);
      this.output.createGLTexture({ type: '2d', format: 'float' });
    }

    // Sampler names must match the uniforms declared in matMulProgramSource.
    const matMulInputs = [{ input: x, name: 'A' }, { input: this.weights['kernel'], name: 'B' }];
    if (this.use_bias) {
      matMulInputs.push({ input: this.weights['bias'], name: 'C' });
    }

    _WebGL.webgl2.runProgram({
      program: this.matMulProgram,
      output: this.activation === 'linear' ? this.output : this.outputPreactiv,
      inputs: matMulInputs,
      uniforms: [{ value: this.use_bias ? 1 : 0, type: 'bool', name: 'addC' }]
    });

    if (this.activation !== 'linear') {
      _WebGL.webgl2.runProgram({
        program: this.activationProgram,
        output: this.output,
        inputs: [{ input: this.outputPreactiv, name: 'x' }]
      });
    }

    // Only a layer with no outbound connections copies its result back from the
    // GL texture. Intermediate layers leave their data on the GPU.
    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
    }
  }
}

exports.default = Dense;