UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

128 lines (109 loc) 4.05 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.default = void 0; var _Layer = _interopRequireDefault(require("../../Layer")); var _Tensor = _interopRequireDefault(require("../../Tensor")); var _WebGL = require("../../WebGL2"); function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } const flattenProgramSource = "#version 300 es\nprecision highp float;\n\nin vec2 outTex;\nuniform sampler2D x;\nuniform int outputSize;\nuniform int inputCols;\nout vec4 outColor;\n\nvoid main() {\n int out_x = int(float(outputSize) * outTex.x);\n int out_y = 0;\n\n int i = int(floor(float(out_x) / float(inputCols)));\n int j = int(mod(float(out_x), float(inputCols)));\n outColor = vec4(texelFetch(x, ivec2(j, i), 0).r);\n}\n"; const flattenFragmentsProgramSource = "#version 300 es\nprecision highp float;\n\nin vec2 outTex;\nuniform sampler2D x;\nuniform int outputSize;\nuniform int inputRows;\nuniform int inputCols;\nout vec4 outColor;\n\nvoid main() {\n int out_x = int(float(outputSize) * outTex.x);\n int out_y = 0;\n\n int rowIndex = int(mod(floor(float(out_x) / float(inputCols)), float(inputRows)));\n int colIndex = int(mod(float(out_x), float(inputCols)));\n int fragmentIndex = int(floor(float(out_x) / (float(inputRows) * float(inputCols))));\n colIndex += fragmentIndex * inputCols;\n outColor = vec4(texelFetch(x, ivec2(colIndex, rowIndex), 0).r);\n}\n"; class Flatten extends _Layer.default { constructor(attrs = {}) { super(attrs); this.layerClass = 'Flatten'; if (this.gpu) { this.flattenProgram = _WebGL.webgl2.compileProgram(flattenProgramSource); this.flattenFragmentsProgram = _WebGL.webgl2.compileProgram(flattenFragmentsProgramSource); } } call(x) { if (this.gpu) { this._callGPU(x); } else { this._callCPU(x); } return this.output; } _callCPU(x) { if (x.tensor.shape.length <= 1) { this.output = x; } else { this.output = new _Tensor.default([], [x.tensor.shape.reduce((a, b) => a * b, 1)]); this.output.replaceTensorData(x.tensor.data); } } _callGPU(x) { if (!x.glTexture && !x.glTextureFragments) { if (x.tensor.shape.length <= 2) { x.createGLTexture({ type: '2d', format: 'float' }); } else if (x.tensor.shape.length > 2 && !x.is2DReshaped) { x.reshapeTo2D(); x.createGLTexture({ type: '2d', format: 'float' }); } } if (!this.output) { this.output = new _Tensor.default([], [x.glTextureShape.reduce((a, b) => a * b, 1)]); this.output.createGLTexture({ type: '2d', format: 'float' }); } if (x.glTextureFragments) { x.convert2DRowFragmentedGLTextureToColStack(); _WebGL.webgl2.runProgram({ program: this.flattenFragmentsProgram, output: this.output, inputs: [{ input: x, name: 'x' }], uniforms: [{ value: this.output.glTextureShape[1], type: 'int', name: 'outputSize' }, { value: x.glTextureShape[0], type: 'int', name: 'inputRows' }, { value: x.glTextureShape[1], type: 'int', name: 'inputCols' }], supportsTextureFragments: true }); x.removeGLTextureFragmentsAsColStack(); } else { _WebGL.webgl2.runProgram({ program: this.flattenProgram, output: this.output, inputs: [{ input: x, name: 'x' }], uniforms: [{ value: this.output.glTextureShape[1], type: 'int', name: 'outputSize' }, { value: x.glTextureShape[1], type: 'int', name: 'inputCols' }], supportsTextureFragments: true }); } if (this.outbound.length === 0) { this.output.transferFromGLTexture(); } } } exports.default = Flatten;