keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
127 lines (102 loc) • 3.07 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _Layer = _interopRequireDefault(require("../../Layer"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var _WebGL = require("../../WebGL2");
var _cwise = _interopRequireDefault(require("cwise"));
/**
 * Normalises a `require()` result into an object exposing a `default` export.
 * Values carrying a truthy `__esModule` marker are returned unchanged; anything
 * else (including null, undefined and primitives) is wrapped as `{ default: obj }`.
 * @param {*} obj - value returned by `require()`
 * @returns {*} `obj` itself, or a fresh `{ default: obj }` wrapper
 */
function _interopRequireDefault(obj) {
  const isEsModule = Boolean(obj) && Boolean(obj.__esModule);
  return isEsModule ? obj : { default: obj };
}
// GLSL ES 3.00 fragment shader for the GPU path of ThresholdedReLU: each texel
// value passes through when it is greater than `theta` and is multiplied by 0
// otherwise.
// `greaterThan` returns a bvec4; `vec4(...)` turns it into a per-component
// 0.0/1.0 mask. A scalar `float(bvec4)` constructor keeps only the first
// component, which would gate every channel on `v.x` alone.
// `#version 300 es` must stay the very first line of the shader source.
const programSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D x;',
  'uniform float theta;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' vec4 v = texture(x, vec2(outTex.x, outTex.y));',
  ' outColor = v * vec4(greaterThan(v, vec4(theta)));',
  '}',
  '',
].join('\n');
/**
 * ThresholdedReLU activation layer.
 *
 * Elements greater than `theta` pass through unchanged; all other elements are
 * multiplied by 0. Runs on the CPU through the cwise kernel installed by
 * `_initialiseProps`, or on the GPU through a WebGL2 shader program, depending
 * on the current value of `this.gpu`.
 */
class ThresholdedReLU extends _Layer.default {
  /**
   * @param {Object} [attrs={}] - layer config, also forwarded to the base Layer.
   * @param {number} [attrs.theta=1] - threshold; only values strictly greater pass through.
   */
  constructor(attrs = {}) {
    super(attrs);
    _initialiseProps.call(this);
    this.layerClass = 'ThresholdedReLU';
    const {
      theta = 1
    } = attrs;
    this.description = `theta: ${theta}`;
    this.theta = theta;
    // Compile eagerly when starting in GPU mode so shader errors surface at
    // construction; `_callGPU` also compiles on demand if this was skipped.
    if (this.gpu) {
      this.program = _WebGL.webgl2.compileProgram(programSource);
    }
  }

  /**
   * Applies the activation, choosing the CPU or GPU path from `this.gpu` at call time.
   * @param {Tensor} x - input tensor
   * @returns {Tensor} `this.output`
   */
  call(x) {
    if (this.gpu) {
      this._callGPU(x);
    } else {
      this._callCPU(x);
    }
    return this.output;
  }

  /**
   * CPU path. The output aliases the input, and `_compute` rewrites its
   * `tensor` data in place, so `x` itself is modified.
   * @param {Tensor} x - input tensor
   */
  _callCPU(x) {
    this.output = x;
    this._compute(this.output.tensor, this.theta);
  }

  /**
   * GPU path: runs the threshold shader from `x`'s texture into `this.output`.
   * The output tensor and its texture are allocated on the first call and
   * reused on later calls.
   * @param {Tensor} x - input tensor; a GL texture is created for it if it has none
   */
  _callGPU(x) {
    // The constructor compiles the program only if `this.gpu` was truthy at
    // construction time, whereas `call()` reads `this.gpu` on every call.
    // Compile on demand so this path never runs with an undefined program.
    if (!this.program) {
      this.program = _WebGL.webgl2.compileProgram(programSource);
    }
    if (!x.glTexture && !x.glTextureFragments) {
      x.createGLTexture({
        type: '2d',
        format: 'float',
        supportsTextureFragments: true
      });
    }
    if (!this.output) {
      this.output = new _Tensor.default([], x.glTextureShape);
      this.output.createGLTexture({
        type: '2d',
        format: 'float',
        supportsTextureFragments: true
      });
      // Copy the input's reshape bookkeeping so the final output can be
      // reshaped back (see the transfer step below).
      if (x.is1D) {
        this.output.is1D = x.is1D;
      } else if (x.is2DReshaped || x.is2DSquareReshaped) {
        if (x.is2DReshaped) {
          this.output.is2DReshaped = x.is2DReshaped;
        } else {
          this.output.is2DSquareReshaped = x.is2DSquareReshaped;
        }
        this.output.originalShape = x.originalShape;
        this.output.indicesForReshaped = x.indicesForReshaped;
      }
    }
    _WebGL.webgl2.runProgram({
      program: this.program,
      output: this.output,
      inputs: [{
        input: x,
        name: 'x'
      }],
      uniforms: [{
        value: this.theta,
        type: 'float',
        name: 'theta'
      }],
      supportsTextureFragments: true
    });
    // With no outbound layers, move the result out of the GL texture and
    // undo any 2D reshape applied to the input.
    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
      if (this.output.is2DReshaped) {
        this.output.reshapeFrom2D();
      } else if (this.output.is2DSquareReshaped) {
        this.output.reshapeFrom2DSquare();
      }
    }
  }
}
exports.default = ThresholdedReLU;
/**
 * Installs per-instance properties on a ThresholdedReLU. It is called with the
 * layer as `this` from the constructor and appears to be Babel class-properties
 * output.
 *
 * Defines `_compute` as an own, enumerable, writable, configurable property
 * holding a cwise kernel over `(array, scalar)`. The kernel's body assigns
 * `_x * Number(_x > theta)` back to each array element.
 */
function _initialiseProps() {
  // cwise compiles the kernel by parsing `body`'s source, so keep that
  // function in cwise's supported subset.
  const thresholdKernel = (0, _cwise.default)({
    args: ['array', 'scalar'],
    body: function (_x, theta) {
      _x = _x * Number(_x > theta);
    }
  });
  Object.defineProperty(this, "_compute", {
    configurable: true,
    enumerable: true,
    writable: true,
    value: thresholdKernel
  });
}