keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
159 lines (130 loc) • 5.54 kB
JavaScript
"use strict";
// Transpiled-module preamble (Babel-style CommonJS output). The `__esModule`
// flag tells interop helpers such as `_interopRequireDefault` (defined below)
// to use this `exports` object as-is rather than wrapping it in `{ default }`.
Object.defineProperty(exports, "__esModule", {
value: true
});
// Placeholder for the default export; the `Dot` class is assigned to it at the
// end of the file.
exports.default = void 0;
var _Merge2 = _interopRequireDefault(require("./_Merge"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var _WebGL = require("../../WebGL2");
var _ndarrayGemm = _interopRequireDefault(require("ndarray-gemm"));
var _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));
/**
 * Normalizes a `require()` result so its default export is always at `.default`.
 * Modules flagged with `__esModule` are already in that form and are returned
 * unchanged; anything else (including falsy values) is wrapped as `{ default }`.
 *
 * @param {*} obj - value returned by `require()`
 * @returns {Object} module object exposing a `default` property
 */
function _interopRequireDefault(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  return { default: obj };
}
// GLSL ES 3.00 fragment shader for the GPU path of the `Dot` layer. Each
// fragment writes one output element (row out_y, column out_x): the sum over
// `commonDim` of input1 * input2, read along `dotAxis1` / `dotAxis2`. Only the
// axis pairs (0, 0) and (1, 1) are read; any other combination leaves a = b = 0.
// With `normalize`, the sum is divided by the product of the two L2 norms. Each
// squared norm is floored at 1e-12, as in Keras' K.l2_normalize, so an all-zero
// vector yields 0 instead of NaN/Inf.
const mergeProgramSource = `#version 300 es
precision highp float;

in vec2 outTex;
uniform sampler2D input1;
uniform sampler2D input2;
uniform int rows;
uniform int cols;
uniform int dotAxis1;
uniform int dotAxis2;
uniform int commonDim;
uniform bool normalize;
out vec4 outColor;

void main() {
  int out_x = int(float(cols) * outTex.x);
  int out_y = int(float(rows) * outTex.y);

  float sum = 0.;
  float a = 0.;
  float b = 0.;
  float norm1 = 0.;
  float norm2 = 0.;

  for (int i = 0; i < commonDim; ++i) {
    if (dotAxis1 == 0 && dotAxis2 == 0) {
      a = texelFetch(input1, ivec2(out_y, i), 0).r;
      b = texelFetch(input2, ivec2(out_x, i), 0).r;
    } else if (dotAxis1 == 1 && dotAxis2 == 1) {
      a = texelFetch(input1, ivec2(i, out_y), 0).r;
      b = texelFetch(input2, ivec2(i, out_x), 0).r;
    }

    sum += a * b;

    if (normalize) {
      norm1 += a * a;
      norm2 += b * b;
    }
  }

  if (normalize) {
    // Floor the squared norms like K.l2_normalize to avoid 0 / 0.
    sum /= sqrt(max(norm1, 1e-12)) * sqrt(max(norm2, 1e-12));
  }

  outColor = vec4(sum);
}
`;
// Keras' K.l2_normalize floors the squared L2 norm at this value so that
// all-zero vectors normalize to zeros instead of NaN.
const L2_NORM_EPSILON = 1e-12;

/**
 * Maps a possibly negative axis index onto [0, rank), counting from the end
 * (-1 -> rank - 1). Non-negative axes are returned unchanged.
 *
 * @param {number} axis
 * @param {number} rank
 * @returns {number}
 */
function resolveAxis(axis, rank) {
  return axis < 0 ? axis + rank : axis;
}

/**
 * Returns an L2-normalized copy of a 2D ndarray. Every 1D slice running along
 * `axis` is normalized: each column for axis 0, each row for axis 1. The input
 * ndarray is not modified.
 *
 * @param {Object} x - 2D ndarray
 * @param {number} axis - 0 or 1
 * @returns {Object} new 2D ndarray with the same shape as `x`
 */
function l2NormalizeAlongAxis(x, axis) {
  const normalized = new _Tensor.default([], x.shape.slice()).tensor;
  _ndarrayOps.default.assign(normalized, x);
  const minNorm = Math.sqrt(L2_NORM_EPSILON);
  const sliceCount = normalized.shape[1 - axis];
  for (let i = 0; i < sliceCount; i++) {
    const slice = axis === 0 ? normalized.pick(null, i) : normalized.pick(i, null);
    _ndarrayOps.default.divseq(slice, Math.max(_ndarrayOps.default.norm2(slice), minNorm));
  }
  return normalized;
}

/**
 * Dot merge layer (port of Keras' `Dot`). It contracts two inputs along one
 * axis each. When `normalize` is set, it first L2-normalizes them along those
 * axes, which gives cosine similarity.
 *
 * Inputs here carry no mini-batch axis. Only 2D inputs are supported. The GPU
 * path also requires both inputs to be contracted along the same axis.
 */
class Dot extends _Merge2.default {
  /**
   * @param {Object} [attrs] - layer config
   * @param {number|number[]} [attrs.axes=-1] - Keras-style axis, or [axis1, axis2],
   *   counted including the batch axis. Positive values are shifted down by one.
   *   Zero and negative values are kept as-is; negative ones are resolved
   *   against each input's rank at call time.
   * @param {boolean} [attrs.normalize=false] - L2-normalize along the dot axes
   *   before taking the product.
   */
  constructor(attrs = {}) {
    super(attrs);
    this.layerClass = 'Dot';
    this.mode = 'dot';
    const { axes = -1, normalize = false } = attrs;
    // No mini-batch axis here, so positive Keras axes move down by one.
    const toLocalAxis = axis => (axis <= 0 ? axis : axis - 1);
    this.dotAxes = Array.isArray(axes)
      ? [toLocalAxis(axes[0]), toLocalAxis(axes[1])]
      : [toLocalAxis(axes), toLocalAxis(axes)];
    this.normalize = normalize;
    if (this.gpu) {
      this.mergeProgram = _WebGL.webgl2.compileProgram(mergeProgramSource);
    }
  }

  /**
   * Resolves and validates `this.dotAxes` against the two input shapes, then
   * sets `this.outputShape`. The output shape is each shape with its dot axis
   * removed, concatenated; a rank-1 result is padded to [n, 1].
   *
   * @param {number[][]} inputShapes - [shape1, shape2]
   * @returns {number[]} the resolved, non-negative [axis1, axis2]
   */
  _calcOutputShape(inputShapes) {
    const [shape1, shape2] = inputShapes;
    const axis1 = resolveAxis(this.dotAxes[0], shape1.length);
    const axis2 = resolveAxis(this.dotAxes[1], shape2.length);
    if (axis1 < 0 || axis1 >= shape1.length || axis2 < 0 || axis2 >= shape2.length) {
      this.throwError(
        `dot axes [${this.dotAxes.join(', ')}] out of range for input shapes ` +
          `[${shape1.join(', ')}] and [${shape2.join(', ')}].`
      );
    }
    if (shape1[axis1] !== shape2[axis2]) {
      this.throwError(`dimension mismatch along dot axes: ${shape1[axis1]} != ${shape2[axis2]}.`);
    }
    this.outputShape = [...shape1.filter((_, i) => i !== axis1), ...shape2.filter((_, i) => i !== axis2)];
    if (this.outputShape.length === 1) {
      this.outputShape.push(1);
    }
    return [axis1, axis2];
  }

  /**
   * CPU forward pass via ndarray-gemm. Supports any pair of axes on 2D inputs.
   * The input tensors are never modified; when `normalize` is set, normalized
   * copies are multiplied instead.
   */
  _callCPU(inputs) {
    const x1 = inputs[0].tensor;
    const x2 = inputs[1].tensor;
    if (x1.shape.length !== 2 || x2.shape.length !== 2) {
      this.throwError('dot mode for 3+ dim tensors not yet implemented.');
    }
    const [axis1, axis2] = this._calcOutputShape([x1.shape, x2.shape]);
    this.output = new _Tensor.default([], this.outputShape);

    const left = this.normalize ? l2NormalizeAlongAxis(x1, axis1) : x1;
    const right = this.normalize ? l2NormalizeAlongAxis(x2, axis2) : x2;
    // A matrix product contracts the columns of the left factor with the rows of
    // the right factor, so transpose whichever input is contracted along the
    // other axis: (0, 0) -> x1^T . x2, (1, 1) -> x1 . x2^T.
    const lhs = axis1 === 0 ? left.transpose(1, 0) : left;
    const rhs = axis2 === 0 ? right : right.transpose(1, 0);
    (0, _ndarrayGemm.default)(this.output.tensor, lhs, rhs);
  }

  /**
   * WebGL2 forward pass using `mergeProgramSource`. That shader only reads the
   * axis pairs (0, 0) and (1, 1), so mixed axes are rejected explicitly here
   * instead of silently producing zeros.
   */
  _callGPU(inputs) {
    inputs.forEach(input => {
      if (!input.glTexture && !input.glTextureFragments) {
        input.createGLTexture({
          type: '2d',
          format: 'float'
        });
      }
    });
    const [axis1, axis2] = this._calcOutputShape([inputs[0].glTextureShape, inputs[1].glTextureShape]);
    if (axis1 !== axis2) {
      this.throwError('dot mode on GPU requires both inputs to be contracted along the same axis.');
    }
    // The output texture is allocated once and reused on later calls.
    if (!this.output) {
      this.output = new _Tensor.default([], this.outputShape);
      this.output.createGLTexture({
        type: '2d',
        format: 'float'
      });
    }
    const commonDim = inputs[0].glTextureShape[axis1];
    _WebGL.webgl2.runProgram({
      program: this.mergeProgram,
      output: this.output,
      inputs: [{
        input: inputs[0],
        name: 'input1'
      }, {
        input: inputs[1],
        name: 'input2'
      }],
      uniforms: [{
        value: this.output.glTextureShape[0],
        type: 'int',
        name: 'rows'
      }, {
        value: this.output.glTextureShape[1],
        type: 'int',
        name: 'cols'
      }, {
        value: axis1,
        type: 'int',
        name: 'dotAxis1'
      }, {
        value: axis2,
        type: 'int',
        name: 'dotAxis2'
      }, {
        value: commonDim,
        type: 'int',
        name: 'commonDim'
      }, {
        value: +this.normalize,
        type: 'bool',
        name: 'normalize'
      }]
    });
    // Terminal layer (no outbound consumers): read the result back to the CPU.
    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
    }
  }
}
// CommonJS default export; Babel-style importers read the class as `.default`.
exports.default = Dot;