/*
 * keras-js
 * Version:
 * Run Keras models in the browser, with GPU support using WebGL
 * 88 lines (68 loc) • 2.07 kB
 * JavaScript
 */
// Strict-mode directive: only comments may precede it, otherwise it stops being a directive.
"use strict";
// Flag this CommonJS module as a transpiled ES module, so interop helpers that
// test `__esModule` (like _interopRequireDefault in this file) use it as-is
// instead of wrapping it in `{ default: ... }`.
Object.defineProperty(exports, "__esModule", {
value: true
});
// Placeholder for the default export; the RepeatVector class is assigned to it
// at the end of this file.
exports.default = void 0;
var _Layer = _interopRequireDefault(require("../../Layer"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var _WebGL = require("../../WebGL2");
var _ndarrayUnsqueeze = _interopRequireDefault(require("ndarray-unsqueeze"));
var _ndarrayTile = _interopRequireDefault(require("ndarray-tile"));
/**
 * Normalizes a required module so its default export is always reachable
 * through `.default`.
 *
 * @param {*} obj - Value returned by `require`.
 * @returns {*} `obj` itself when it is truthy and carries a truthy
 *   `__esModule` flag; otherwise a new wrapper object `{ default: obj }`.
 */
function _interopRequireDefault(obj) {
  const isEsModule = Boolean(obj) && Boolean(obj.__esModule);
  if (isEsModule) {
    return obj;
  }
  return { default: obj };
}
// GLSL ES 3.00 fragment shader used by the GPU path of RepeatVector.
// Each output fragment reads the input texture `x` at the fragment's own
// horizontal coordinate and at vertical coordinate 0, so the output's vertical
// position never affects which texel is fetched.
// Kept as one string per shader line; the trailing empty entry yields the
// final newline, and `#version` stays the very first line of the source.
const programSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D x;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' outColor = texture(x, vec2(outTex.x, 0));',
  '}',
  ''
].join('\n');
/**
 * RepeatVector layer: repeats a 1D input vector `n` times, producing a 2D
 * output with `n` rows.
 *
 * The CPU path tiles the input with ndarray helpers; the GPU path runs a
 * WebGL2 fragment shader (`programSource`) that copies row 0 of the input
 * texture into every row of the output texture.
 */
class RepeatVector extends _Layer.default {
  /**
   * @param {Object} [attrs] - Layer attributes; also forwarded to the Layer
   *   base constructor.
   * @param {number} [attrs.n=1] - Number of times the input is repeated.
   */
  constructor(attrs = {}) {
    super(attrs);
    this.layerClass = 'RepeatVector';

    const { n = 1 } = attrs;
    this.n = n;
    this.description = `n = ${n}`;

    // `this.gpu` is not assigned in this class; presumably the Layer base
    // constructor sets it — confirm against Layer.
    if (this.gpu) {
      this.program = _WebGL.webgl2.compileProgram(programSource);
    }
  }

  /**
   * Runs the layer on `x`, using the GPU or CPU implementation depending on
   * `this.gpu`.
   *
   * @param {Tensor} x - Input tensor.
   * @returns {Tensor} `this.output` as populated by the selected path.
   */
  call(x) {
    if (this.gpu) {
      this._callGPU(x);
    } else {
      this._callCPU(x);
    }
    return this.output;
  }

  /**
   * CPU implementation. Stores the result in `this.output`.
   *
   * Rejects non-1D inputs through `this.throwError` (inherited; expected to
   * throw — confirm against Layer).
   *
   * @param {Tensor} x - Input tensor whose `tensor` ndarray must be 1D.
   */
  _callCPU(x) {
    if (x.tensor.shape.length !== 1) {
      this.throwError('Only 1D tensor inputs allowed.');
    }

    // The input is guaranteed to be 1D here, so its only dimension is
    // shape[0]. (shape[1] would be undefined and yield an invalid
    // [n, undefined] output shape.)
    const inputLength = x.tensor.shape[0];
    this.output = new _Tensor.default([], [this.n, inputLength]);

    // Add a leading axis, then tile n times along it.
    this.output.tensor = (0, _ndarrayTile.default)((0, _ndarrayUnsqueeze.default)(x.tensor, 0), [this.n, 1]);
  }

  /**
   * GPU implementation. Stores the result in `this.output`.
   *
   * The output tensor and its texture are created on the first call only and
   * reused afterwards.
   *
   * @param {Tensor} x - Input tensor; a 2D float GL texture is created for it
   *   if it does not already have one.
   */
  _callGPU(x) {
    if (!x.glTexture) {
      x.createGLTexture({
        type: '2d',
        format: 'float'
      });
    }

    if (!this.output) {
      // Output has n rows and as many columns as the input texture.
      this.output = new _Tensor.default([], [this.n, x.glTextureShape[1]]);
      this.output.createGLTexture({
        type: '2d',
        format: 'float'
      });
    }

    _WebGL.webgl2.runProgram({
      program: this.program,
      output: this.output,
      inputs: [{
        input: x,
        name: 'x'
      }]
    });

    // With no outbound layers, read the result back from the GL texture.
    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
    }
  }
}
// Publish the RepeatVector class as this module's default export.
exports.default = RepeatVector;