UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

112 lines (89 loc) 3.18 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.default = void 0; var _Layer = _interopRequireDefault(require("../../Layer")); var _Tensor = _interopRequireDefault(require("../../Tensor")); var _WebGL = require("../../WebGL2"); var activations = _interopRequireWildcard(require("../../activations")); var activationProgramSources = _interopRequireWildcard(require("../../activations/programSources")); function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) { var desc = Object.defineProperty && Object.getOwnPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : {}; if (desc.get || desc.set) { Object.defineProperty(newObj, key, desc); } else { newObj[key] = obj[key]; } } } } newObj.default = obj; return newObj; } } function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } class Activation extends _Layer.default { constructor(attrs = {}) { super(attrs); this.layerClass = 'Activation'; const { activation = 'linear' } = attrs; this.activation = activation; this.activationFunc = activations[activation]; this.description = `${this.activation}`; if (this.gpu) { this.program = _WebGL.webgl2.compileProgram(activationProgramSources[this.activation]); } } call(x) { if (this.activation === 'linear') { this.output = x; return this.output; } if (this.gpu) { this._callGPU(x); } else { this._callCPU(x); } return this.output; } _callCPU(x) { this.output = new _Tensor.default(new x.arrayType(x.tensor.data), x.tensor.shape); this.activationFunc(this.output); } _callGPU(x) { if (!x.glTexture && !x.glTextureFragments) { x.createGLTexture({ type: '2d', format: 'float', supportsTextureFragments: true }); } if (!this.output) { this.output = new _Tensor.default([], x.glTextureShape); this.output.createGLTexture({ type: '2d', format: 'float', supportsTextureFragments: true }); if (x.is1D) { this.output.is1D = x.is1D; } else if (x.is2DReshaped || x.is2DSquareReshaped) { if (x.is2DReshaped) { this.output.is2DReshaped = x.is2DReshaped; } else if (x.is2DSquareReshaped) { this.output.is2DSquareReshaped = x.is2DSquareReshaped; } this.output.originalShape = x.originalShape; this.output.indicesForReshaped = x.indicesForReshaped; } } _WebGL.webgl2.runProgram({ program: this.program, output: this.output, inputs: [{ input: x, name: 'x' }], supportsTextureFragments: true }); if (this.outbound.length === 0) { this.output.transferFromGLTexture(); if (this.output.is2DReshaped) { this.output.reshapeFrom2D(); } else if (this.output.is2DSquareReshaped) { this.output.reshapeFrom2DSquare(); } } } } exports.default = Activation;