UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

95 lines (76 loc) 2.68 kB
"use strict";

Object.defineProperty(exports, "__esModule", { value: true });
exports.default = void 0;

const _Layer = _interopRequireDefault(require("../../Layer"));
const _Tensor = _interopRequireDefault(require("../../Tensor"));
const _WebGL = require("../../WebGL2");
const _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));

// Babel CommonJS interop: wrap a plain CommonJS export so `.default` is always available.
function _interopRequireDefault(obj) {
  return obj && obj.__esModule ? obj : { default: obj };
}

/**
 * GLSL ES 3.00 fragment shader for the GPU path.
 *
 * For each output texel it:
 * - derives the embedding column (`out_x`) from the embeddings texture width;
 * - derives the input position (`out_y`) from the `x` texture width;
 * - reads the index stored in the red channel of `x` at (out_y, 0), truncating it with `int()`;
 * - emits the embeddings texel at (out_x, index).
 */
const programSource = "#version 300 es\nprecision highp float;\n\nin vec2 outTex;\nuniform sampler2D x;\nuniform sampler2D embeddings;\nout vec4 outColor;\n\nvoid main() {\n ivec2 x_size = textureSize(x, 0);\n ivec2 embeddings_size = textureSize(embeddings, 0);\n int out_x = int(float(embeddings_size[0]) * outTex.x);\n int out_y = int(float(x_size[0]) * outTex.y);\n\n int index = int(texelFetch(x, ivec2(out_y, 0), 0).r);\n outColor = texelFetch(embeddings, ivec2(out_x, index), 0);\n}\n";

/**
 * Embedding layer: looks up one row of the `embeddings` weight matrix for
 * each index in the input tensor.
 */
class Embedding extends _Layer.default {
  /**
   * @param {Object} [attrs={}] - layer config (also forwarded to the base Layer constructor)
   * @param {number} [attrs.input_dim=1] - stored as `inputDim`
   * @param {number} [attrs.output_dim=1] - stored as `outputDim` and shown in `description`
   * @param {number} [attrs.input_length=0] - stored as `inputLength`
   * @param {boolean} [attrs.mask_zero=false] - stored as `maskZero`; not used by the computation in this file
   */
  constructor(attrs = {}) {
    super(attrs);
    this.layerClass = 'Embedding';

    const {
      input_dim = 1,
      output_dim = 1,
      input_length = 0,
      mask_zero = false
    } = attrs;

    this.description = `output dimensions: ${output_dim}`;
    this.inputDim = input_dim;
    this.outputDim = output_dim;
    this.inputLength = input_length;
    this.maskZero = mask_zero;

    // Names of the weight tensors this layer expects in `this.weights`.
    this.params = ['embeddings'];

    // `this.gpu` is presumably set by the base Layer from `attrs` (not visible here).
    if (this.gpu) {
      this.program = _WebGL.webgl2.compileProgram(programSource);
    }
  }

  /**
   * Runs the lookup on the GPU or CPU, depending on `this.gpu`.
   *
   * @param {Tensor} x - tensor of indices (treated as 1-D: one index per position)
   * @returns {Tensor} `this.output`, with one embedding row per input position
   */
  call(x) {
    if (this.gpu) {
      this._callGPU(x);
    } else {
      this._callCPU(x);
    }
    return this.output;
  }

  /**
   * CPU path: copies row `trunc(x[i])` of the embeddings matrix into output row `i`.
   *
   * @param {Tensor} x - tensor of indices
   * @throws {RangeError} if an index is NaN or outside [0, number of embedding rows)
   */
  _callCPU(x) {
    const embeddings = this.weights['embeddings'].tensor;
    const numRows = embeddings.shape[0];
    const length = x.tensor.shape[0];

    this.output = new _Tensor.default([], [length, embeddings.shape[1]]);

    for (let i = 0; i < length; i++) {
      const raw = x.tensor.get(i);
      // Truncate like the shader's int() cast, so CPU and GPU pick the same row;
      // a fractional index would otherwise become a fractional ndarray offset.
      const index = Math.trunc(raw);
      // The negated form also rejects NaN.
      if (!(index >= 0 && index < numRows)) {
        throw new RangeError(`Embedding: index ${raw} at position ${i} is out of range [0, ${numRows})`);
      }
      _ndarrayOps.default.assign(this.output.tensor.pick(i, null), embeddings.pick(index, null));
    }
  }

  /**
   * GPU path: runs `programSource` over the input and embeddings textures.
   *
   * @param {Tensor} x - tensor of indices; uploaded as a 2D float texture if needed
   */
  _callGPU(x) {
    if (!x.glTexture) {
      x.createGLTexture({ type: '2d', format: 'float' });
    }

    const embeddings = this.weights['embeddings'];
    // Number of output rows. Assumes a 1-D input is uploaded as a 1 x N texture,
    // so its width (glTextureShape[1]) is the index count. TODO confirm.
    const numIndices = x.glTextureShape[1];

    // Reallocate whenever the input length differs from the cached allocation.
    // Previously the output was sized once from the first call and then reused.
    // NOTE(review): the previous output texture is not explicitly released here.
    if (!this.output || this._gpuOutputLength !== numIndices) {
      this.output = new _Tensor.default([], [numIndices, embeddings.glTextureShape[1]]);
      this.output.createGLTexture({ type: '2d', format: 'float' });
      this._gpuOutputLength = numIndices;
    }

    _WebGL.webgl2.runProgram({
      program: this.program,
      output: this.output,
      inputs: [
        { input: x, name: 'x' },
        { input: embeddings, name: 'embeddings' }
      ]
    });

    // No outbound layers: presumably this layer's result is a model output,
    // so copy it from the GPU back to CPU memory.
    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
    }
  }
}

exports.default = Embedding;