keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
112 lines (89 loc) • 3.18 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _Layer = _interopRequireDefault(require("../../Layer"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var _WebGL = require("../../WebGL2");
var activations = _interopRequireWildcard(require("../../activations"));
var activationProgramSources = _interopRequireWildcard(require("../../activations/programSources"));
function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) { var desc = Object.defineProperty && Object.getOwnPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : {}; if (desc.get || desc.set) { Object.defineProperty(newObj, key, desc); } else { newObj[key] = obj[key]; } } } } newObj.default = obj; return newObj; } }
function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }
class Activation extends _Layer.default {
constructor(attrs = {}) {
super(attrs);
this.layerClass = 'Activation';
const {
activation = 'linear'
} = attrs;
this.activation = activation;
this.activationFunc = activations[activation];
this.description = `${this.activation}`;
if (this.gpu) {
this.program = _WebGL.webgl2.compileProgram(activationProgramSources[this.activation]);
}
}
call(x) {
if (this.activation === 'linear') {
this.output = x;
return this.output;
}
if (this.gpu) {
this._callGPU(x);
} else {
this._callCPU(x);
}
return this.output;
}
_callCPU(x) {
this.output = new _Tensor.default(new x.arrayType(x.tensor.data), x.tensor.shape);
this.activationFunc(this.output);
}
_callGPU(x) {
if (!x.glTexture && !x.glTextureFragments) {
x.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: true
});
}
if (!this.output) {
this.output = new _Tensor.default([], x.glTextureShape);
this.output.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: true
});
if (x.is1D) {
this.output.is1D = x.is1D;
} else if (x.is2DReshaped || x.is2DSquareReshaped) {
if (x.is2DReshaped) {
this.output.is2DReshaped = x.is2DReshaped;
} else if (x.is2DSquareReshaped) {
this.output.is2DSquareReshaped = x.is2DSquareReshaped;
}
this.output.originalShape = x.originalShape;
this.output.indicesForReshaped = x.indicesForReshaped;
}
}
_WebGL.webgl2.runProgram({
program: this.program,
output: this.output,
inputs: [{
input: x,
name: 'x'
}],
supportsTextureFragments: true
});
if (this.outbound.length === 0) {
this.output.transferFromGLTexture();
if (this.output.is2DReshaped) {
this.output.reshapeFrom2D();
} else if (this.output.is2DSquareReshaped) {
this.output.reshapeFrom2DSquare();
}
}
}
}
exports.default = Activation;