UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

159 lines (130 loc) 5.54 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.default = void 0; var _Merge2 = _interopRequireDefault(require("./_Merge")); var _Tensor = _interopRequireDefault(require("../../Tensor")); var _WebGL = require("../../WebGL2"); var _ndarrayGemm = _interopRequireDefault(require("ndarray-gemm")); var _ndarrayOps = _interopRequireDefault(require("ndarray-ops")); function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } const mergeProgramSource = "#version 300 es\nprecision highp float;\n\nin vec2 outTex;\nuniform sampler2D input1;\nuniform sampler2D input2;\nuniform int rows;\nuniform int cols;\nuniform int dotAxis1;\nuniform int dotAxis2;\nuniform int commonDim;\nuniform bool normalize;\nout vec4 outColor;\n\nvoid main() {\n int out_x = int(float(cols) * outTex.x);\n int out_y = int(float(rows) * outTex.y);\n\n float sum = 0.;\n float a = 0.;\n float b = 0.;\n float norm1 = 0.;\n float norm2 = 0.;\n\n for (int i = 0; i < commonDim; ++i) {\n if (dotAxis1 == 0 && dotAxis2 == 0) {\n a = texelFetch(input1, ivec2(out_y, i), 0).r;\n b = texelFetch(input2, ivec2(out_x, i), 0).r;\n } else if (dotAxis1 == 1 && dotAxis2 == 1) {\n a = texelFetch(input1, ivec2(i, out_y), 0).r;\n b = texelFetch(input2, ivec2(i, out_x), 0).r;\n }\n\n sum += a * b;\n\n if (normalize) {\n norm1 += a * a;\n norm2 += b * b;\n }\n }\n\n if (normalize) {\n sum /= sqrt(norm1) * sqrt(norm2);\n }\n\n outColor = vec4(sum);\n}\n"; class Dot extends _Merge2.default { constructor(attrs = {}) { super(attrs); this.layerClass = 'Dot'; this.mode = 'dot'; const { axes = -1, normalize = false } = attrs; if (Array.isArray(axes)) { this.dotAxes = [axes[0] <= 0 ? axes[0] : axes[0] - 1, axes[1] <= 0 ? axes[1] : axes[1] - 1]; } else { this.dotAxes = [axes <= 0 ? axes : axes - 1, axes <= 0 ? axes : axes - 1]; } this.normalize = normalize; if (this.gpu) { this.mergeProgram = _WebGL.webgl2.compileProgram(mergeProgramSource); } } _calcOutputShape(inputShapes) { let shape1 = inputShapes[0].slice(); let shape2 = inputShapes[1].slice(); shape1.splice(this.dotAxes[0], 1); shape2.splice(this.dotAxes[1], 1); this.outputShape = shape1.concat(shape2); if (this.outputShape.length === 1) { this.outputShape.push(1); } } _callCPU(inputs) { this._calcOutputShape([inputs[0].tensor.shape, inputs[1].tensor.shape]); this.output = new _Tensor.default([], this.outputShape); if (inputs[0].tensor.shape.length === 2 && inputs[1].tensor.shape.length === 2) { if (this.dotAxes[0] === 0 && this.dotAxes[1] === 0) { if (this.normalize) { for (let i = 0; i < inputs[0].tensor.shape[1]; i++) { _ndarrayOps.default.divseq(inputs[0].tensor.pick(null, i), _ndarrayOps.default.norm2(inputs[0].tensor.pick(null, i))); } for (let i = 0; i < inputs[1].tensor.shape[1]; i++) { _ndarrayOps.default.divseq(inputs[1].tensor.pick(null, i), _ndarrayOps.default.norm2(inputs[1].tensor.pick(null, i))); } } (0, _ndarrayGemm.default)(this.output.tensor, inputs[0].tensor.transpose(1, 0), inputs[1].tensor); } else if (this.dotAxes[0] === 1 && this.dotAxes[1] === 1) { if (this.normalize) { for (let i = 0; i < inputs[0].tensor.shape[0]; i++) { _ndarrayOps.default.divseq(inputs[0].tensor.pick(i, null), _ndarrayOps.default.norm2(inputs[0].tensor.pick(i, null))); } for (let i = 0; i < inputs[1].tensor.shape[0]; i++) { _ndarrayOps.default.divseq(inputs[1].tensor.pick(i, null), _ndarrayOps.default.norm2(inputs[1].tensor.pick(i, null))); } } (0, _ndarrayGemm.default)(this.output.tensor, inputs[0].tensor, inputs[1].tensor.transpose(1, 0)); } } else { this.throwError('dot mode for 3+ dim tensors not yet implemented.'); } } _callGPU(inputs) { inputs.forEach(input => { if (!input.glTexture && !input.glTextureFragments) { input.createGLTexture({ type: '2d', format: 'float' }); } }); this._calcOutputShape([inputs[0].glTextureShape, inputs[1].glTextureShape]); if (!this.output) { this.output = new _Tensor.default([], this.outputShape); this.output.createGLTexture({ type: '2d', format: 'float' }); } const commonDim = inputs[0].glTextureShape[this.dotAxes[0]]; _WebGL.webgl2.runProgram({ program: this.mergeProgram, output: this.output, inputs: [{ input: inputs[0], name: 'input1' }, { input: inputs[1], name: 'input2' }], uniforms: [{ value: this.output.glTextureShape[0], type: 'int', name: 'rows' }, { value: this.output.glTextureShape[1], type: 'int', name: 'cols' }, { value: this.dotAxes[0], type: 'int', name: 'dotAxis1' }, { value: this.dotAxes[1], type: 'int', name: 'dotAxis2' }, { value: commonDim, type: 'int', name: 'commonDim' }, { value: +this.normalize, type: 'bool', name: 'normalize' }] }); if (this.outbound.length === 0) { this.output.transferFromGLTexture(); } } } exports.default = Dot;