UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

839 lines (728 loc) 29.3 kB
"use strict";

Object.defineProperty(exports, "__esModule", { value: true });
exports.default = void 0;

const activations = _interopRequireWildcard(require("../../activations"));
const _Tensor = _interopRequireDefault(require("../../Tensor"));
const _Layer = _interopRequireDefault(require("../../Layer"));
const _WebGL = require("../../WebGL2");
const _ndarrayBlasLevel = require("ndarray-blas-level2");
const _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));
const _cwise = _interopRequireDefault(require("cwise"));
const activationProgramSources = _interopRequireWildcard(require("../../activations/programSources"));

// Babel interop helper: wrap a CommonJS export so `.default` points at it.
function _interopRequireDefault(obj) {
  return obj && obj.__esModule ? obj : { default: obj };
}

// Babel interop helper: build a namespace object (own properties + `.default`) for a CommonJS export.
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  const newObj = {};
  if (obj != null) {
    for (const key in obj) {
      if (Object.prototype.hasOwnProperty.call(obj, key)) {
        const desc = Object.defineProperty && Object.getOwnPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : {};
        if (desc.get || desc.set) {
          Object.defineProperty(newObj, key, desc);
        } else {
          newObj[key] = obj[key];
        }
      }
    }
  }
  newObj.default = obj;
  return newObj;
}

// GLSL ES 3.00 fragment shaders. Each source is the listed lines joined with '\n'
// (the trailing '' yields the final newline). Indices below are (column x, row y),
// as passed to texelFetch; textureSize returns (width, height).

// Copies texture `source` to the output unchanged.
const copyTextureProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D source;',
  'out vec4 outColor;',
  '',
  'void main(void) {',
  ' outColor = texture(source, vec2(outTex.x, outTex.y));',
  '}',
  ''
].join('\n');

// out[y][x] = sum_i A[y][i] * B[i][x], plus C[0][x] when `addC` is set.
const matMulProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D A;',
  'uniform sampler2D B;',
  'uniform sampler2D C;',
  'uniform bool addC;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' ivec2 A_size = textureSize(A, 0);',
  ' ivec2 B_size = textureSize(B, 0);',
  ' int out_x = int(float(B_size[0]) * outTex.x);',
  ' int out_y = int(float(A_size[1]) * outTex.y);',
  ' int commonDim = A_size[0];',
  '',
  ' float sum = 0.;',
  ' for (int i = 0; i < commonDim; ++i) {',
  ' float a = texelFetch(A, ivec2(i, out_y), 0).r;',
  ' float b = texelFetch(B, ivec2(out_x, i), 0).r;',
  ' sum += a * b;',
  ' }',
  '',
  ' if (addC) {',
  ' sum += texelFetch(C, ivec2(out_x, 0), 0).r;',
  ' }',
  '',
  ' outColor = vec4(sum);',
  '}',
  ''
].join('\n');

// out = t1 + t2 + bias, element-wise (output size taken from `bias`).
const gateSummationProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D t1;',
  'uniform sampler2D t2;',
  'uniform sampler2D bias;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' ivec2 size = textureSize(bias, 0);',
  ' int out_x = int(float(size[0]) * outTex.x);',
  ' int out_y = int(float(size[1]) * outTex.y);',
  '',
  ' float t1_val = texelFetch(t1, ivec2(out_x, out_y), 0).r;',
  ' float t2_val = texelFetch(t2, ivec2(out_x, out_y), 0).r;',
  ' float bias_val = texelFetch(bias, ivec2(out_x, out_y), 0).r;',
  '',
  ' outColor = vec4(t1_val + t2_val + bias_val);',
  '}',
  ''
].join('\n');

// out = t1 * t2, element-wise.
const gateProductProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D t1;',
  'uniform sampler2D t2;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' ivec2 size = textureSize(t1, 0);',
  ' int out_x = int(float(size[0]) * outTex.x);',
  ' int out_y = int(float(size[1]) * outTex.y);',
  '',
  ' float t1_val = texelFetch(t1, ivec2(out_x, out_y), 0).r;',
  ' float t2_val = texelFetch(t2, ivec2(out_x, out_y), 0).r;',
  '',
  ' outColor = vec4(t1_val * t2_val);',
  '}',
  ''
].join('\n');

// Reads row `index` of `x` into a single-row output.
const timestepReadProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D x;',
  'uniform int index;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' ivec2 size = textureSize(x, 0);',
  ' int out_x = int(float(size[0]) * outTex.x);',
  '',
  ' outColor = vec4(texelFetch(x, ivec2(out_x, index), 0).r);',
  '}',
  ''
].join('\n');

// Output equals `y`, except row `index`, which is replaced by row 0 of `x`.
const timestepWriteProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D x;',
  'uniform sampler2D y;',
  'uniform int index;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' ivec2 size = textureSize(y, 0);',
  ' int out_x = int(float(size[0]) * outTex.x);',
  ' int out_y = int(float(size[1]) * outTex.y);',
  '',
  ' if (out_y == index) {',
  ' outColor = vec4(texelFetch(x, ivec2(out_x, 0), 0).r);',
  ' } else {',
  ' outColor = vec4(texelFetch(y, ivec2(out_x, out_y), 0).r);',
  ' }',
  '}',
  ''
].join('\n');

// Cell-state update: out = c * i + ctm1 * f, element-wise.
const updateProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D c;',
  'uniform sampler2D ctm1;',
  'uniform sampler2D i;',
  'uniform sampler2D f;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' ivec2 size = textureSize(c, 0);',
  ' int out_x = int(float(size[0]) * outTex.x);',
  ' int out_y = int(float(size[1]) * outTex.y);',
  '',
  ' float c_val = texelFetch(c, ivec2(out_x, out_y), 0).r;',
  ' float ctm1_val = texelFetch(ctm1, ivec2(out_x, out_y), 0).r;',
  ' float i_val = texelFetch(i, ivec2(out_x, out_y), 0).r;',
  ' float f_val = texelFetch(f, ivec2(out_x, out_y), 0).r;',
  '',
  ' outColor = vec4(c_val * i_val + ctm1_val * f_val);',
  '}',
  ''
].join('\n');

/**
 * LSTM recurrent layer with a CPU (ndarray) path and a WebGL2 path.
 *
 * Per timestep t, with x_t the current input row and h_{t-1} the previous hidden state:
 *   i_t  = recurrentActivation(x_t W_i + h_{t-1} U_i + b_i)
 *   f_t  = recurrentActivation(x_t W_f + h_{t-1} U_f + b_f)
 *   o_t  = recurrentActivation(x_t W_o + h_{t-1} U_o + b_o)
 *   c~_t = activation(x_t W_c + h_{t-1} U_c + b_c)
 *   c_t  = c~_t * i_t + c_{t-1} * f_t
 *   h_t  = o_t * activation(c_t)
 *
 * The output is h for the last timestep, or every h_t stacked row-wise when
 * `return_sequences` is set.
 */
class LSTM extends _Layer.default {
  /**
   * @param {Object} [attrs] - layer configuration (Keras-style snake_case keys)
   * @param {number} [attrs.units=1] - hidden/cell state size
   * @param {string} [attrs.activation='tanh'] - key into the activations module, used for c~ and c
   * @param {boolean} [attrs.use_bias=true] - when false, gate biases stay as freshly allocated tensors
   * @param {string} [attrs.recurrent_activation='hard_sigmoid'] - activation for the i/f/o gates
   * @param {boolean} [attrs.return_sequences=false] - output the full hidden-state sequence
   * @param {boolean} [attrs.go_backwards=false] - consume input timesteps in reverse order
   * @param {boolean} [attrs.stateful=false] - carry cell/hidden state across calls
   */
  constructor(attrs = {}) {
    super(attrs);

    // cwise element-wise kernels. NOTE: cwise builds the kernel from the `body`
    // function's source text, so these must remain plain `function` expressions;
    // assigning to a parameter writes the corresponding array element.
    // _combine(y, x1, x2, b): y = x1 + x2 + b
    Object.defineProperty(this, "_combine", {
      configurable: true,
      enumerable: true,
      writable: true,
      value: (0, _cwise.default)({
        args: ['array', 'array', 'array', 'array'],
        body: function (_y, _x1, _x2, _b) {
          _y = _x1 + _x2 + _b;
        }
      })
    });
    // _update(c, ctm1, i, f): c = c * i + ctm1 * f
    Object.defineProperty(this, "_update", {
      configurable: true,
      enumerable: true,
      writable: true,
      value: (0, _cwise.default)({
        args: ['array', 'array', 'array', 'array'],
        body: function (_c, _ctm1, _i, _f) {
          _c = _c * _i + _ctm1 * _f;
        }
      })
    });

    this.layerClass = 'LSTM';

    const {
      units = 1,
      activation = 'tanh',
      use_bias = true,
      recurrent_activation = 'hard_sigmoid',
      return_sequences = false,
      go_backwards = false,
      stateful = false
    } = attrs;

    this.units = units;
    this.activation = activation;
    this.recurrentActivation = recurrent_activation;
    this.activationFunc = activations[activation];
    this.recurrentActivationFunc = activations[recurrent_activation];
    this.use_bias = use_bias;
    this.returnSequences = return_sequences;
    this.goBackwards = go_backwards;
    this.stateful = stateful;

    this.params = this.use_bias ? ['kernel', 'recurrent_kernel', 'bias'] : ['kernel', 'recurrent_kernel'];

    this.description = `output dimensions: ${this.units}`;
    this.description += this.activation !== 'linear' ? `, ${this.activation} activation` : '';
    this.description += this.recurrentActivation !== 'linear' ? `, ${this.recurrentActivation} recurrent activation` : '';
    this.description += this.returnSequences ? `, return sequences` : '';
    this.description += this.goBackwards ? `, backward direction` : '';
    this.description += this.stateful ? `, stateful` : '';

    if (this.gpu) {
      const { webgl2 } = _WebGL;
      this.copyTextureProgram = webgl2.compileProgram(copyTextureProgramSource);
      this.matMulProgram = webgl2.compileProgram(matMulProgramSource);
      // Every activation pass in the GPU path is guarded by `!== 'linear'`, so a
      // linear activation program would never run; don't compile it.
      if (this.activation !== 'linear') {
        this.activationProgram = webgl2.compileProgram(activationProgramSources[this.activation]);
      }
      if (this.recurrentActivation !== 'linear') {
        this.recurrentActivationProgram = webgl2.compileProgram(activationProgramSources[this.recurrentActivation]);
      }
      this.gateSummationProgram = webgl2.compileProgram(gateSummationProgramSource);
      this.gateProductProgram = webgl2.compileProgram(gateProductProgramSource);
      this.timestepReadProgram = webgl2.compileProgram(timestepReadProgramSource);
      this.timestepWriteProgram = webgl2.compileProgram(timestepWriteProgramSource);
      this.updateProgram = webgl2.compileProgram(updateProgramSource);
    }
  }

  /**
   * Loads the layer weights and splits the fused matrices into per-gate tensors.
   *
   * `kernel` and `recurrent_kernel` are sliced into four consecutive column
   * blocks of width `units`, in gate order i, f, c, o, giving W_* and U_*.
   * `bias` is sliced the same way into b_* when `use_bias` is set. Otherwise
   * b_* are left as freshly allocated `[units]` tensors. On the GPU path every
   * per-gate tensor also gets a GL texture.
   *
   * @param {Array} weightsArr - forwarded to Layer#setWeights
   */
  setWeights(weightsArr) {
    super.setWeights(weightsArr);

    const gates = ['i', 'f', 'c', 'o'];
    const units = this.units;

    // Slice column block k (columns k*units .. (k+1)*units) of weights[sourceKey] into `${prefix}_${gate}`.
    const splitColumns = (sourceKey, prefix) => {
      const source = this.weights[sourceKey].tensor;
      const rows = source.shape[0];
      gates.forEach((gate, k) => {
        const target = new _Tensor.default([], [rows, units]);
        _ndarrayOps.default.assign(target.tensor, source.hi(rows, (k + 1) * units).lo(0, k * units));
        this.weights[`${prefix}_${gate}`] = target;
      });
    };

    splitColumns('kernel', 'W');
    splitColumns('recurrent_kernel', 'U');

    gates.forEach((gate, k) => {
      const target = new _Tensor.default([], [units]);
      if (this.use_bias) {
        _ndarrayOps.default.assign(target.tensor, this.weights['bias'].tensor.hi((k + 1) * units).lo(k * units));
      }
      this.weights[`b_${gate}`] = target;
    });

    if (this.gpu) {
      ['W', 'U', 'b'].forEach(prefix => {
        gates.forEach(gate => {
          this.weights[`${prefix}_${gate}`].createGLTexture({ type: '2d', format: 'float' });
        });
      });
    }
  }

  /**
   * Runs the layer on `x` using the GPU or CPU path.
   *
   * @param {Tensor} x - input sequence, indexed [timestep, feature]
   * @returns {Tensor} this.output
   */
  call(x) {
    if (this.gpu) {
      this._callGPU(x);
    } else {
      this._callCPU(x);
    }
    return this.output;
  }

  /**
   * CPU forward pass over all timesteps of `x` (`x.tensor` indexed [timestep, feature]).
   *
   * Sets `this.hiddenStateSequence` (one row per processed timestep, in processing
   * order) and `this.output`. When `stateful`, stores the final cell and hidden
   * state on the instance, and the next call resumes from them.
   *
   * @param {Tensor} x
   */
  _callCPU(x) {
    const dimInputGate = this.weights['b_i'].tensor.shape[0];
    const dimCandidate = this.weights['b_c'].tensor.shape[0];
    const dimForgetGate = this.weights['b_f'].tensor.shape[0];
    const dimOutputGate = this.weights['b_o'].tensor.shape[0];
    const vector = length => new _Tensor.default([], [length]);

    const currentInputGateState = vector(dimInputGate);
    const tempXI = vector(dimInputGate);
    const tempHI = vector(dimInputGate);
    const currentForgetGateState = vector(dimForgetGate);
    const tempXF = vector(dimForgetGate);
    const tempHF = vector(dimForgetGate);
    const currentOutputGateState = vector(dimOutputGate);
    const tempXO = vector(dimOutputGate);
    const tempHO = vector(dimOutputGate);
    const currentCandidate = vector(dimCandidate);
    const tempXC = vector(dimCandidate);
    const tempHC = vector(dimCandidate);
    const tempTensors = [tempXI, tempHI, tempXF, tempHF, tempXO, tempHO, tempXC, tempHC];

    // Stateful layers resume from the cell/hidden state saved by the previous call.
    const previousCandidate = this.stateful && this.previousCandidate ? this.previousCandidate : vector(dimCandidate);
    const currentHiddenState = this.stateful && this.currentHiddenState ? this.currentHiddenState : vector(dimCandidate);
    const previousHiddenState = vector(dimCandidate);

    const numTimesteps = x.tensor.shape[0];
    this.hiddenStateSequence = new _Tensor.default([], [numTimesteps, dimCandidate]);
    const currentX = vector(x.tensor.shape[1]);

    const gemv = _ndarrayBlasLevel.gemv;
    const activate = tensor => this.activationFunc(tensor);
    const recurrentActivate = tensor => this.recurrentActivationFunc(tensor);

    // state = activateFn(x_t W_gate + h_{t-1} U_gate + b_gate), using tempX/tempH as scratch.
    const computeGate = (gate, tempX, tempH, state, activateFn) => {
      gemv(1, this.weights[`W_${gate}`].tensor.transpose(1, 0), currentX.tensor, 1, tempX.tensor);
      gemv(1, this.weights[`U_${gate}`].tensor.transpose(1, 0), previousHiddenState.tensor, 1, tempH.tensor);
      this._combine(state.tensor, tempX.tensor, tempH.tensor, this.weights[`b_${gate}`].tensor);
      activateFn(state);
    };

    // One LSTM timestep; reads currentX and updates previousCandidate / currentHiddenState in place.
    const step = () => {
      _ndarrayOps.default.assign(previousHiddenState.tensor, currentHiddenState.tensor);

      computeGate('i', tempXI, tempHI, currentInputGateState, recurrentActivate);
      computeGate('f', tempXF, tempHF, currentForgetGateState, recurrentActivate);
      computeGate('o', tempXO, tempHO, currentOutputGateState, recurrentActivate);
      computeGate('c', tempXC, tempHC, currentCandidate, activate);

      // c_t = c~_t * i_t + c_{t-1} * f_t, kept for the next timestep.
      this._update(currentCandidate.tensor, previousCandidate.tensor, currentInputGateState.tensor, currentForgetGateState.tensor);
      _ndarrayOps.default.assign(previousCandidate.tensor, currentCandidate.tensor);

      // h_t = o_t * activation(c_t)
      this.activationFunc(currentCandidate);
      _ndarrayOps.default.mul(currentHiddenState.tensor, currentOutputGateState.tensor, currentCandidate.tensor);
    };

    for (let i = 0; i < numTimesteps; i++) {
      const inputIndex = this.goBackwards ? numTimesteps - i - 1 : i;
      _ndarrayOps.default.assign(currentX.tensor, x.tensor.pick(inputIndex, null));
      // gemv is called with beta = 1 (y <- A·x + y), so the scratch vectors must start each step at zero.
      tempTensors.forEach(temp => _ndarrayOps.default.assigns(temp.tensor, 0));
      step();
      _ndarrayOps.default.assign(this.hiddenStateSequence.tensor.pick(i, null), currentHiddenState.tensor);
    }

    this.output = this.returnSequences ? this.hiddenStateSequence : currentHiddenState;

    if (this.stateful) {
      this.previousCandidate = previousCandidate;
      this.currentHiddenState = currentHiddenState;
    }
  }

  /**
   * GPU helper: copy texture `source` into `output`.
   */
  _copyTextureGPU(output, source) {
    _WebGL.webgl2.runProgram({
      program: this.copyTextureProgram,
      output,
      inputs: [{ input: source, name: 'source' }]
    });
  }

  /**
   * GPU helper: output = a × b (matMulProgram with addC disabled).
   */
  _matMulGPU(output, a, b) {
    _WebGL.webgl2.runProgram({
      program: this.matMulProgram,
      output,
      inputs: [{ input: a, name: 'A' }, { input: b, name: 'B' }],
      uniforms: [{ value: 0, type: 'bool', name: 'addC' }]
    });
  }

  /**
   * GPU helper: computes one gate for the current timestep.
   *
   * this[preactivKey] = x_t W_gate + h_{t-1} U_gate + b_gate, then this[stateKey] =
   * program(this[preactivKey]). For a 'linear' activation no pass is run, and
   * this[stateKey] becomes an alias of this[preactivKey].
   *
   * @param {Object} spec
   * @param {string} spec.gate - 'i' | 'f' | 'o' | 'c' (selects W_*, U_*, b_*)
   * @param {Tensor} spec.tempX - scratch for x_t W_gate
   * @param {Tensor} spec.tempH - scratch for h_{t-1} U_gate
   * @param {string} spec.preactivKey - instance key of the pre-activation tensor
   * @param {string} spec.stateKey - instance key of the activated gate tensor
   * @param {string} spec.activation - activation name
   * @param {*} spec.program - compiled activation program (unused when linear)
   */
  _gateGPU({ gate, tempX, tempH, preactivKey, stateKey, activation, program }) {
    const { webgl2 } = _WebGL;
    this._matMulGPU(tempX, this.currentX, this.weights[`W_${gate}`]);
    this._matMulGPU(tempH, this.previousHiddenState, this.weights[`U_${gate}`]);
    webgl2.runProgram({
      program: this.gateSummationProgram,
      output: this[preactivKey],
      inputs: [
        { input: tempX, name: 't1' },
        { input: tempH, name: 't2' },
        { input: this.weights[`b_${gate}`], name: 'bias' }
      ]
    });
    if (activation !== 'linear') {
      webgl2.runProgram({
        program,
        output: this[stateKey],
        inputs: [{ input: this[preactivKey], name: 'x' }]
      });
    } else {
      this[stateKey] = this[preactivKey];
    }
  }

  /**
   * One LSTM timestep on the GPU.
   *
   * Reads this.currentX and updates this.previousCandidate (c_t) and
   * this.currentHiddenState (h_t) in place. All buffers are allocated by _callGPU.
   */
  _stepGPU() {
    const { webgl2 } = _WebGL;

    // h_{t-1}: snapshot the hidden state before this step overwrites it.
    this._copyTextureGPU(this.previousHiddenState, this.currentHiddenState);

    this._gateGPU({
      gate: 'i',
      tempX: this.tempXI,
      tempH: this.tempHI,
      preactivKey: 'currentInputGateStatePreactiv',
      stateKey: 'currentInputGateState',
      activation: this.recurrentActivation,
      program: this.recurrentActivationProgram
    });
    this._gateGPU({
      gate: 'f',
      tempX: this.tempXF,
      tempH: this.tempHF,
      preactivKey: 'currentForgetGateStatePreactiv',
      stateKey: 'currentForgetGateState',
      activation: this.recurrentActivation,
      program: this.recurrentActivationProgram
    });
    this._gateGPU({
      gate: 'o',
      tempX: this.tempXO,
      tempH: this.tempHO,
      preactivKey: 'currentOutputGateStatePreactiv',
      stateKey: 'currentOutputGateState',
      activation: this.recurrentActivation,
      program: this.recurrentActivationProgram
    });
    this._gateGPU({
      gate: 'c',
      tempX: this.tempXC,
      tempH: this.tempHC,
      preactivKey: 'currentCandidatePreactiv',
      stateKey: 'currentCandidate',
      activation: this.activation,
      program: this.activationProgram
    });

    // c_t = c~_t * i_t + c_{t-1} * f_t. The result is written into currentCandidate,
    // so c~_t is read from a copy to keep input and output textures distinct.
    this._copyTextureGPU(this.currentCandidateCopy, this.currentCandidate);
    webgl2.runProgram({
      program: this.updateProgram,
      output: this.currentCandidate,
      inputs: [
        { input: this.currentCandidateCopy, name: 'c' },
        { input: this.previousCandidate, name: 'ctm1' },
        { input: this.currentInputGateState, name: 'i' },
        { input: this.currentForgetGateState, name: 'f' }
      ]
    });
    // Keep c_t for the next timestep (and for the next call when stateful).
    this._copyTextureGPU(this.previousCandidate, this.currentCandidate);

    // activation(c_t): read c_t from previousCandidate so the pass never samples
    // its own render target. With a linear activation, currentCandidate is an
    // alias of currentCandidatePreactiv (see _gateGPU) and already holds c_t;
    // copying it onto currentCandidatePreactiv would read and write one texture.
    if (this.activation !== 'linear') {
      webgl2.runProgram({
        program: this.activationProgram,
        output: this.currentCandidate,
        inputs: [{ input: this.previousCandidate, name: 'x' }]
      });
    }

    // h_t = o_t * activation(c_t)
    webgl2.runProgram({
      program: this.gateProductProgram,
      output: this.currentHiddenState,
      inputs: [
        { input: this.currentOutputGateState, name: 't1' },
        { input: this.currentCandidate, name: 't2' }
      ]
    });
  }

  /**
   * GPU forward pass over all timesteps of `x`.
   *
   * `x.glTextureShape` is read as [timesteps, features]. Intermediate textures
   * are cached on the instance and reused across calls. The sequence buffers
   * are reallocated when the timestep count changes. Non-stateful layers get
   * fresh cell/hidden state tensors on every call. When this layer has no
   * outbound layers, the output is transferred back from its GL texture.
   *
   * @param {Tensor} x
   */
  _callGPU(x) {
    if (!x.glTexture) {
      x.createGLTexture({ type: '2d', format: 'float' });
    }

    const dimInputGate = this.weights['b_i'].glTextureShape[1];
    const dimCandidate = this.weights['b_c'].glTextureShape[1];
    const dimForgetGate = this.weights['b_f'].glTextureShape[1];
    const dimOutputGate = this.weights['b_o'].glTextureShape[1];
    const numTimesteps = x.glTextureShape[0];

    // Create this[key] as a new tensor of `shape` with a GL texture when it is missing or `reallocate` is true.
    const ensureGLTensor = (key, shape, reallocate = false) => {
      if (reallocate || !this[key]) {
        this[key] = new _Tensor.default([], shape);
        this[key].createGLTexture({ type: '2d', format: 'float' });
      }
    };
    // True when this[key] already has a texture with one row per timestep of `x`.
    const fitsSequence = key => {
      const existing = this[key];
      return Boolean(existing && existing.glTextureShape && existing.glTextureShape[0] === numTimesteps);
    };

    ensureGLTensor('currentInputGateState', [dimInputGate]);
    ensureGLTensor('currentInputGateStatePreactiv', [dimInputGate]);
    ensureGLTensor('tempXI', [dimInputGate]);
    ensureGLTensor('tempHI', [dimInputGate]);

    ensureGLTensor('currentForgetGateState', [dimForgetGate]);
    ensureGLTensor('currentForgetGateStatePreactiv', [dimForgetGate]);
    ensureGLTensor('tempXF', [dimForgetGate]);
    ensureGLTensor('tempHF', [dimForgetGate]);

    ensureGLTensor('currentOutputGateState', [dimOutputGate]);
    ensureGLTensor('currentOutputGateStatePreactiv', [dimOutputGate]);
    ensureGLTensor('tempXO', [dimOutputGate]);
    ensureGLTensor('tempHO', [dimOutputGate]);

    ensureGLTensor('currentCandidate', [dimCandidate]);
    ensureGLTensor('currentCandidateCopy', [dimCandidate]);
    ensureGLTensor('currentCandidatePreactiv', [dimCandidate]);
    ensureGLTensor('tempXC', [dimCandidate]);
    ensureGLTensor('tempHC', [dimCandidate]);

    // Non-stateful layers start every call from freshly allocated cell/hidden states.
    ensureGLTensor('previousCandidate', [dimCandidate], !this.stateful);
    ensureGLTensor('currentHiddenState', [dimCandidate], !this.stateful);
    ensureGLTensor('previousHiddenState', [dimCandidate]);

    // Sized by the timestep count, so reallocate when the input length changes between calls.
    ensureGLTensor('hiddenStateSequence', [numTimesteps, dimCandidate], !fitsSequence('hiddenStateSequence'));
    ensureGLTensor('hiddenStateSequenceCopy', [numTimesteps, dimCandidate], !fitsSequence('hiddenStateSequenceCopy'));

    ensureGLTensor('currentX', [x.glTextureShape[1]]);

    const { webgl2 } = _WebGL;
    for (let i = 0; i < numTimesteps; i++) {
      const inputIndex = this.goBackwards ? numTimesteps - i - 1 : i;
      webgl2.runProgram({
        program: this.timestepReadProgram,
        output: this.currentX,
        inputs: [{ input: x, name: 'x' }],
        uniforms: [{ value: inputIndex, type: 'int', name: 'index' }]
      });

      this._stepGPU();

      if (this.returnSequences) {
        // Writing row i needs the other rows' current values; read them from a copy
        // so the pass does not sample its own render target.
        this._copyTextureGPU(this.hiddenStateSequenceCopy, this.hiddenStateSequence);
        webgl2.runProgram({
          program: this.timestepWriteProgram,
          output: this.hiddenStateSequence,
          inputs: [
            { input: this.currentHiddenState, name: 'x' },
            { input: this.hiddenStateSequenceCopy, name: 'y' }
          ],
          uniforms: [{ value: i, type: 'int', name: 'index' }]
        });
      }
    }

    this.output = this.returnSequences ? this.hiddenStateSequence : this.currentHiddenState;

    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
    }
  }
}

exports.default = LSTM;