keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
111 lines (90 loc) • 2.72 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _Layer = _interopRequireDefault(require("../../Layer"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var _WebGL = require("../../WebGL2");
var _activations = require("../../activations");
/**
 * Normalizes a `require()` result so callers can always read `.default`.
 *
 * Objects flagged with `__esModule` are returned as-is; every other value
 * (including falsy ones such as `null` or `0`) is wrapped as `{ default: value }`.
 *
 * @param {*} obj - value returned by `require()`.
 * @returns {Object} `obj` itself, or a wrapper whose `default` is `obj`.
 */
function _interopRequireDefault(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  return { default: obj };
}
// GLSL ES 3.00 shader used by the GPU path (compiled in the LeakyReLU constructor).
// It samples the input texture `x` at `outTex` and writes
// max(v, 0.0) + alpha * min(v, 0.0), computed component-wise on the sampled vec4:
// non-negative components pass through, and negative components are scaled by `alpha`.
const programSource = `#version 300 es
precision highp float;

in vec2 outTex;
uniform sampler2D x;
uniform float alpha;
out vec4 outColor;

void main() {
 vec4 v = texture(x, vec2(outTex.x, outTex.y));
 outColor = max(v, 0.0) + alpha * min(v, 0.0);
}
`;
/**
 * LeakyReLU activation layer.
 *
 * The GPU shader (`programSource`) computes max(v, 0) + alpha * min(v, 0):
 * non-negative values pass through and negative values are scaled by `alpha`.
 * The CPU path delegates to the `relu` activation with an `alpha` option.
 */
class LeakyReLU extends _Layer.default {
  /**
   * @param {Object} [attrs] - layer attributes, forwarded unchanged to the base Layer.
   * @param {number} [attrs.alpha=0.3] - slope applied to negative inputs.
   */
  constructor(attrs = {}) {
    super(attrs);
    this.layerClass = 'LeakyReLU';

    let alpha = attrs.alpha;
    if (alpha === undefined) {
      alpha = 0.3;
    }
    this.alpha = alpha;
    this.description = `alpha: ${alpha}`;

    // The shader is compiled only for GPU layers. `this.gpu` is presumably
    // set by the base Layer constructor (TODO confirm).
    if (this.gpu) {
      this.program = _WebGL.webgl2.compileProgram(programSource);
    }
  }

  /**
   * Runs the activation on the backend this layer was configured for.
   *
   * @param {Tensor} x - input tensor.
   * @returns {Tensor} `this.output`, as populated by `_callCPU` or `_callGPU`.
   */
  call(x) {
    if (!this.gpu) {
      this._callCPU(x);
      return this.output;
    }
    this._callGPU(x);
    return this.output;
  }

  /**
   * CPU implementation. The input tensor object itself becomes `this.output`
   * and is passed to `relu` with this layer's `alpha`.
   *
   * NOTE(review): `relu` presumably modifies the tensor in place, so `x` is
   * mutated. Confirm that no other consumer relies on the unactivated input.
   *
   * @param {Tensor} x - input tensor.
   */
  _callCPU(x) {
    const { relu } = _activations;
    const options = { alpha: this.alpha };
    this.output = x;
    relu(this.output, options);
  }

  /**
   * WebGL implementation.
   *
   * Steps:
   * 1. Uploads `x` to a texture if it has neither `glTexture` nor `glTextureFragments`.
   * 2. On the first call only, allocates the output tensor and texture and copies
   *    the input's 1D / 2D-reshape metadata onto it.
   * 3. Runs the shader.
   * 4. If this layer has no outbound layers, transfers the result back from the
   *    GL texture and calls `reshapeFrom2D` / `reshapeFrom2DSquare` as flagged.
   *
   * @param {Tensor} x - input tensor.
   */
  _callGPU(x) {
    // Each createGLTexture call receives its own options object.
    const textureOptions = () => ({
      type: '2d',
      format: 'float',
      supportsTextureFragments: true,
    });

    const inputOnGPU = x.glTexture || x.glTextureFragments;
    if (!inputOnGPU) {
      x.createGLTexture(textureOptions());
    }

    if (!this.output) {
      // Allocated once and reused afterwards. Shape and reshape metadata come
      // from the first input seen.
      const out = new _Tensor.default([], x.glTextureShape);
      this.output = out;
      out.createGLTexture(textureOptions());

      const copyReshapeInfo = () => {
        out.originalShape = x.originalShape;
        out.indicesForReshaped = x.indicesForReshaped;
      };
      if (x.is1D) {
        out.is1D = x.is1D;
      } else if (x.is2DReshaped) {
        out.is2DReshaped = x.is2DReshaped;
        copyReshapeInfo();
      } else if (x.is2DSquareReshaped) {
        out.is2DSquareReshaped = x.is2DSquareReshaped;
        copyReshapeInfo();
      }
    }

    const inputs = [{ input: x, name: 'x' }];
    const uniforms = [{ value: this.alpha, type: 'float', name: 'alpha' }];
    _WebGL.webgl2.runProgram({
      program: this.program,
      output: this.output,
      inputs,
      uniforms,
      supportsTextureFragments: true,
    });

    // Layers with outbound connections leave their output in the GL texture.
    // Only a layer with no outbound layers transfers it back.
    if (this.outbound.length !== 0) {
      return;
    }
    const result = this.output;
    result.transferFromGLTexture();
    if (result.is2DReshaped) {
      result.reshapeFrom2D();
    } else if (result.is2DSquareReshaped) {
      result.reshapeFrom2DSquare();
    }
  }
}
exports.default = LeakyReLU;