UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

462 lines (371 loc) 18.7 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.default = void 0; var _Layer = _interopRequireDefault(require("../../Layer")); var _Tensor = _interopRequireDefault(require("../../Tensor")); var activations = _interopRequireWildcard(require("../../activations")); var _WebGL = require("../../WebGL2"); var _createGLSLProgram = _interopRequireDefault(require("../../webgl/dynamic/createGLSLProgram")); var tensorUtils = _interopRequireWildcard(require("../../utils/tensorUtils")); var _ndarrayOps = _interopRequireDefault(require("ndarray-ops")); var _ndarrayGemm = _interopRequireDefault(require("ndarray-gemm")); var activationProgramSources = _interopRequireWildcard(require("../../activations/programSources")); function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) { var desc = Object.defineProperty && Object.getOwnPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : {}; if (desc.get || desc.set) { Object.defineProperty(newObj, key, desc); } else { newObj[key] = obj[key]; } } } } newObj.default = obj; return newObj; } } function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } const mapInputProgramSource = "#version 300 es\nprecision highp float;\nprecision highp isampler2D;\n\nin vec2 outTex;\nuniform sampler2D x;\nuniform isampler2D indexMap;\nuniform int inputCols;\nout vec4 outColor;\n\nvoid main() {\n ivec2 size = textureSize(indexMap, 0);\n int out_x = int(float(size[0]) * outTex.x);\n int out_y = int(float(size[1]) * outTex.y);\n\n int index = texelFetch(indexMap, ivec2(out_x, out_y), 0).r;\n\n if (index != -1) {\n int rowIndex = int(floor(float(index) / float(inputCols)));\n int colIndex = int(mod(float(index), float(inputCols)));\n float val = texelFetch(x, ivec2(colIndex, rowIndex), 0).r;\n outColor = vec4(val);\n } else {\n outColor = vec4(0.0);\n }\n}\n"; const mapInputFragmentsProgramSource = "#version 300 es\nprecision highp float;\nprecision highp isampler2D;\n\nin vec2 outTex;\nuniform sampler2D x;\nuniform isampler2D indexMap;\nuniform int inputCols;\nout vec4 outColor;\n\nvoid main() {\n ivec2 inputSize = textureSize(x, 0);\n ivec2 outputSize = textureSize(indexMap, 0);\n int out_x = int(float(outputSize[0]) * outTex.x);\n int out_y = int(float(outputSize[1]) * outTex.y);\n\n int index = texelFetch(indexMap, ivec2(out_x, out_y), 0).r;\n\n if (index != -1) {\n int rowIndex = int(floor(float(index) / float(inputCols)));\n int colIndex = int(mod(float(index), float(inputCols)));\n int fragmentIndex = int(floor(float(rowIndex) / float(inputSize[1])));\n rowIndex = int(mod(float(rowIndex), float(inputSize[1])));\n colIndex = fragmentIndex * inputCols + colIndex;\n float val = texelFetch(x, ivec2(colIndex, rowIndex), 0).r;\n outColor = vec4(val);\n } else {\n outColor = vec4(0.0);\n }\n}\n"; const matMulProgramSource = "#version 300 es\nprecision highp float;\n\nin vec2 outTex;\nuniform sampler2D A;\nuniform sampler2D B;\nuniform sampler2D C;\nuniform bool addC;\nout vec4 outColor;\n\nvoid main() {\n ivec2 A_size = textureSize(A, 0);\n ivec2 B_size = textureSize(B, 0);\n int out_x = int(float(B_size[0]) * outTex.x);\n int out_y = int(float(A_size[1]) * outTex.y);\n int commonDim = A_size[0];\n\n float sum = 0.;\n for (int i = 0; i < commonDim; ++i) {\n float a = texelFetch(A, ivec2(i, out_y), 0).r;\n float b = texelFetch(B, ivec2(out_x, i), 0).r;\n sum += a * b;\n }\n\n if (addC) {\n sum += texelFetch(C, ivec2(out_x, 0), 0).r;\n }\n\n outColor = vec4(sum);\n}\n"; class Conv2D extends _Layer.default { constructor(attrs = {}) { super(attrs); this.layerClass = 'Conv2D'; const { filters = 1, kernel_size = [3, 3], strides = [1, 1], padding = 'valid', data_format = 'channels_last', dilation_rate = [1, 1], activation = 'linear', use_bias = true } = attrs; if (Array.isArray(kernel_size)) { this.kernelShape = [filters, ...kernel_size]; } else { this.kernelShape = [filters, kernel_size, kernel_size]; } if (Array.isArray(strides)) { this.strides = strides; } else { this.strides = [strides, strides]; } if (padding === 'valid' || padding === 'same') { this.padding = padding; } else { this.throwError('Invalid padding.'); } if (data_format === 'channels_last' || data_format === 'channels_first') { this.dataFormat = data_format; } else { this.throwError('Only channels_last and channels_first data formats are allowed.'); } if (Array.isArray(dilation_rate)) { this.dilationRate = dilation_rate; } else { this.dilationRate = [dilation_rate, dilation_rate]; } if ((this.dilationRate[0] !== 1 || this.dilationRate[1] !== 1) && (this.strides[0] !== 1 || this.strides[1] !== 1)) { this.throwError(`Incompatible combination of dilation_rate with strides.`); } this.activation = activation; this.activationFunc = activations[activation]; this.useBias = use_bias; this.params = this.useBias ? ['kernel', 'bias'] : ['kernel']; this.description = `${this.kernelShape[0]} ${this.kernelShape.slice(1).join('x')} filters`; this.description += this.strides.some(s => s > 1) ? `, ${this.strides.join('x')} striding` : ''; this.description += this.padding === 'valid' ? `, no border padding` : ', pad to same borders'; this.description += this.dilationRate.some(r => r > 1) ? `, ${this.dilationRate.join('x')} dilation` : ''; this.description += this.activation !== 'linear' ? `, ${this.activation} activation` : ''; if (this.gpu) { this.mapInputProgram = _WebGL.webgl2.compileProgram(mapInputProgramSource); this.mapInputFragmentsProgram = _WebGL.webgl2.compileProgram(mapInputFragmentsProgramSource); this.matMulProgram = _WebGL.webgl2.compileProgram(matMulProgramSource); this.activationProgram = _WebGL.webgl2.compileProgram(activationProgramSources[this.activation]); } } setWeights(weightsArr) { if (this.dataFormat === 'channels_first') { weightsArr[0].tensor = weightsArr[0].tensor.transpose(2, 3, 1, 0); } super.setWeights(weightsArr, false); this._w2row(); if (this.gpu) { this.weights['kernel'] = this.wRowsMat; this.weights['kernel'].createGLTexture({ type: '2d', format: 'float' }); if (this.useBias) { this.weights['bias'].createGLTexture({ type: '2d', format: 'float' }); } } } call(x) { if (this.gpu) { this._callGPU(x); } else { this._callCPU(x); } return this.output; } _calcOutputShape(inputShape) { if (this.outputShape && this.inputPadding) { return; } const inputRows = inputShape[0]; const inputCols = inputShape[1]; const [nbFilter, nbRow, nbCol] = this.kernelShape; const nbRowDilated = nbRow + (nbRow - 1) * (this.dilationRate[0] - 1); const nbColDilated = nbCol + (nbCol - 1) * (this.dilationRate[1] - 1); const outputRows = this.padding === 'same' ? Math.floor((inputRows + this.strides[0] - 1) / this.strides[0]) : Math.floor((inputRows - nbRowDilated + this.strides[0]) / this.strides[0]); const outputCols = this.padding === 'same' ? Math.floor((inputCols + this.strides[1] - 1) / this.strides[1]) : Math.floor((inputCols - nbColDilated + this.strides[1]) / this.strides[1]); const outputChannels = nbFilter; const paddingRow = this.padding === 'same' ? Math.max(0, Math.floor((outputRows - 1) * this.strides[0] + nbRowDilated - inputRows)) : 0; const paddingCol = this.padding === 'same' ? Math.max(0, Math.floor((outputCols - 1) * this.strides[1] + nbColDilated - inputCols)) : 0; const paddingRowBefore = Math.floor(paddingRow / 2); const paddingRowAfter = paddingRow - paddingRowBefore; const paddingColBefore = Math.floor(paddingCol / 2); const paddingColAfter = paddingCol - paddingColBefore; this.outputShape = [outputRows, outputCols, outputChannels]; this.inputPadding = [paddingRowBefore, paddingRowAfter, paddingColBefore, paddingColAfter]; } _padInput(x, padValue = 0) { if (this.padding === 'same') { const [inputRows, inputCols, inputChannels] = x.tensor.shape; const [paddingRowBefore, paddingRowAfter, paddingColBefore, paddingColAfter] = this.inputPadding; const newRows = inputRows + paddingRowBefore + paddingRowAfter; const newCols = inputCols + paddingColBefore + paddingColAfter; const _x = new _Tensor.default([], [newRows, newCols, inputChannels]); if (padValue !== 0) { _ndarrayOps.default.assigns(_x.tensor, padValue); } _ndarrayOps.default.assign(_x.tensor.hi(inputRows + paddingRowBefore, inputCols + paddingColBefore, inputChannels).lo(paddingRowBefore, paddingColBefore, 0), x.tensor); return _x; } return x; } _im2col(x) { const [inputRows, inputCols, inputChannels] = x.tensor.shape; const nbRow = this.kernelShape[1]; const nbCol = this.kernelShape[2]; const outputRows = this.outputShape[0]; const outputCols = this.outputShape[1]; const nbPatches = outputRows * outputCols; const patchLen = nbRow * nbCol * inputChannels; const nbRowDilated = nbRow + (nbRow - 1) * (this.dilationRate[0] - 1); const nbColDilated = nbCol + (nbCol - 1) * (this.dilationRate[1] - 1); if (!this.imColsMat) { this.imColsMat = new _Tensor.default([], [nbPatches, patchLen]); } if (nbRowDilated === 1 && nbColDilated === 1 && this.strides[0] === 1 && this.strides[1] === 1) { this.imColsMat.replaceTensorData(x.tensor.data); return this.imColsMat; } const patch = new _Tensor.default([], [nbRow, nbCol, inputChannels]); let offset = 0; for (let i = 0, limit = inputRows - nbRowDilated; i <= limit; i += this.strides[0]) { for (let j = 0, limit = inputCols - nbColDilated; j <= limit; j += this.strides[1]) { _ndarrayOps.default.assign(patch.tensor, x.tensor.hi(i + nbRowDilated, j + nbColDilated, inputChannels).lo(i, j, 0).step(this.dilationRate[0], this.dilationRate[1], 1)); this.imColsMat.tensor.data.set(patch.tensor.data, offset); offset += patchLen; } } return this.imColsMat; } _w2row() { const inputChannels = this.weights['kernel'].tensor.shape[2]; const [nbFilter, nbRow, nbCol] = this.kernelShape; const patchLen = nbRow * nbCol * inputChannels; this.wRowsMat = new _Tensor.default([], [patchLen, nbFilter]); const patch = new _Tensor.default([], [nbRow, nbCol, inputChannels]); const patchRaveled = new _Tensor.default([], [patchLen]); for (let n = 0; n < nbFilter; n++) { _ndarrayOps.default.assign(patch.tensor, this.weights['kernel'].tensor.pick(null, null, null, n)); patchRaveled.replaceTensorData(patch.tensor.data); _ndarrayOps.default.assign(this.wRowsMat.tensor.pick(null, n), patchRaveled.tensor); } return this.wRowsMat; } _callCPU(x) { this.inputShape = x.tensor.shape; this._calcOutputShape(this.inputShape); x = this._padInput(x); this._im2col(x); const nbFilter = this.kernelShape[0]; const outputRows = this.outputShape[0]; const outputCols = this.outputShape[1]; const nbPatches = outputRows * outputCols; const matMul = new _Tensor.default([], [nbPatches, nbFilter]); if (this.useBias) { for (let n = 0; n < nbFilter; n++) { _ndarrayOps.default.assigns(matMul.tensor.pick(null, n), this.weights['bias'].tensor.get(n)); } } (0, _ndarrayGemm.default)(matMul.tensor, this.imColsMat.tensor, this.wRowsMat.tensor, 1, 1); this.output = new _Tensor.default([], this.outputShape); let outputChannelRaveled = new _Tensor.default([], [outputRows * outputCols]); let outputChannel = new _Tensor.default([], [outputRows, outputCols]); for (let n = 0; n < nbFilter; n++) { _ndarrayOps.default.assign(outputChannelRaveled.tensor, matMul.tensor.pick(null, n)); outputChannel.replaceTensorData(outputChannelRaveled.tensor.data); _ndarrayOps.default.assign(this.output.tensor.pick(null, null, n), outputChannel.tensor); } this.activationFunc(this.output); if (this.dataFormat === 'channels_first') { this.output.tensor = this.output.tensor.transpose(2, 0, 1); } } _createIndexMap(indicesForReshaped) { if (this.indexMap) { return; } let [inputRows, inputCols, inputChannels] = this.inputShape; let indices = new _Tensor.default(indicesForReshaped.data, indicesForReshaped.shape, { type: Int32Array }); if (this.padding === 'same') { const [paddingRowBefore, paddingRowAfter, paddingColBefore, paddingColAfter] = this.inputPadding; inputRows = inputRows + paddingRowBefore + paddingRowAfter; inputCols = inputCols + paddingColBefore + paddingColAfter; const padValue = -1; indices = this._padInput(indices, padValue); } const nbRow = this.kernelShape[1]; const nbCol = this.kernelShape[2]; const outputRows = this.outputShape[0]; const outputCols = this.outputShape[1]; const nbPatches = outputRows * outputCols; const patchLen = nbRow * nbCol * inputChannels; const nbRowDilated = nbRow + (nbRow - 1) * (this.dilationRate[0] - 1); const nbColDilated = nbCol + (nbCol - 1) * (this.dilationRate[1] - 1); this.indexMap = new _Tensor.default([], [nbPatches, patchLen], { type: Int32Array }); const indicesPatch = new _Tensor.default([], [nbRow, nbCol, inputChannels]); let offset = 0; for (let i = 0, limit = inputRows - nbRowDilated; i <= limit; i += this.strides[0]) { for (let j = 0, limit = inputCols - nbColDilated; j <= limit; j += this.strides[1]) { _ndarrayOps.default.assign(indicesPatch.tensor, indices.tensor.hi(i + nbRowDilated, j + nbColDilated, inputChannels).lo(i, j, 0).step(this.dilationRate[0], this.dilationRate[1], 1)); this.indexMap.tensor.data.set(indicesPatch.tensor.data, offset); offset += patchLen; } } this.indexMap.createGLTexture({ type: '2d', format: 'int', supportsTextureFragments: true }); } _callGPU(x) { let outputTextureShape; if (x.is2DReshaped || x.is2DSquareReshaped) { this.inputShape = x.originalShape; this._calcOutputShape(this.inputShape); this._createIndexMap(x.indicesForReshaped); outputTextureShape = [this.indexMap.glTextureShape[0], this.weights['kernel'].glTextureShape[1]]; } else { this.inputShape = x.tensor.shape; this._calcOutputShape(this.inputShape); x = this._padInput(x); this._im2col(x); this.imColsMat.createGLTexture({ type: '2d', format: 'float', supportsTextureFragments: true }); outputTextureShape = [this.imColsMat.glTextureShape[0], this.weights['kernel'].glTextureShape[1]]; } if (this.activation !== 'linear' && !this.outputPreactiv) { this.outputPreactiv = new _Tensor.default([], outputTextureShape); this.outputPreactiv.createGLTexture({ type: '2d', format: 'float', supportsTextureFragments: true }); this.outputPreactiv.is2DReshaped = true; this.outputPreactiv.originalShape = this.outputShape; this.outputPreactiv.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.outputShape, false, -1); } if (!this.output) { this.output = new _Tensor.default([], outputTextureShape); this.output.createGLTexture({ type: '2d', format: 'float', supportsTextureFragments: true }); this.output.is2DReshaped = true; this.output.originalShape = this.outputShape; this.output.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.outputShape, false, -1); } if (x.is2DReshaped || x.is2DSquareReshaped) { const hasFragments = Boolean(x.glTextureFragments); if (hasFragments) { x.convert2DRowFragmentedGLTextureToColStack(); } if (!this.convProgram) { const convProgramSource = (0, _createGLSLProgram.default)('conv2d', this.output.glTextureFragmentShape ? this.output.glTextureFragmentShape : this.output.glTextureShape, x.glTextureFragmentShape ? x.glTextureFragmentShape : x.glTextureShape, this.indexMap.glTextureFragmentShape ? this.indexMap.glTextureFragmentShape : this.indexMap.glTextureShape, this.useBias, hasFragments); this.convProgram = _WebGL.webgl2.compileProgram(convProgramSource); } _WebGL.webgl2.runProgram({ program: this.convProgram, output: this.activation === 'linear' ? this.output : this.outputPreactiv, inputs: [{ input: x, name: 'x' }, { input: this.indexMap, name: 'indexMap' }, { input: this.weights['kernel'], name: 'kernel' }, ...(this.useBias ? [{ input: this.weights['bias'], name: 'bias' }] : [])], supportsTextureFragments: true }); if (hasFragments) { x.removeGLTextureFragmentsAsColStack(); } } else { const matMulInputs = [{ input: this.imColsMat, name: 'A' }, { input: this.weights['kernel'], name: 'B' }]; if (this.useBias) { matMulInputs.push({ input: this.weights['bias'], name: 'C' }); } _WebGL.webgl2.runProgram({ program: this.matMulProgram, output: this.activation === 'linear' ? this.output : this.outputPreactiv, inputs: matMulInputs, uniforms: [{ value: this.useBias ? 1 : 0, type: 'bool', name: 'addC' }], supportsTextureFragments: true }); } if (this.activation !== 'linear') { _WebGL.webgl2.runProgram({ program: this.activationProgram, output: this.output, inputs: [{ input: this.outputPreactiv, name: 'x' }], supportsTextureFragments: true }); } if (this.outbound.length === 0) { this.output.transferFromGLTexture(); this.output.reshapeFrom2D(); if (this.dataFormat === 'channels_first') { this.output.tensor = this.output.tensor.transpose(2, 0, 1); } } } } exports.default = Conv2D;