keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
84 lines (67 loc) • 2.12 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _isEqual2 = _interopRequireDefault(require("lodash/isEqual"));
var _Layer = _interopRequireDefault(require("../Layer"));
var _Tensor = _interopRequireDefault(require("../Tensor"));
function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }
class InputLayer extends _Layer.default {
constructor(attrs = {}) {
super(attrs);
this.layerClass = 'InputLayer';
const {
shape = []
} = attrs;
this.shape = attrs.batch_input_shape && attrs.batch_input_shape.length ? attrs.batch_input_shape.slice(1) : shape;
this.description = `shape: ${JSON.stringify(this.shape)}`;
}
call(x) {
if (this.gpu) {
this._callGPU(x);
} else {
this._callCPU(x);
}
return this.output;
}
_callCPU(x) {
this.inputShape = x.tensor.shape;
if (!(0, _isEqual2.default)(this.inputShape, this.shape)) {
this.throwError(`input tensor shape ${x.tensor.shape} does not match specified shape ${this.shape}.`);
}
this.output = new _Tensor.default(x.tensor.data, x.tensor.shape);
}
_callGPU(x) {
if (!x.glTexture && !x.glTextureFragments) {
this.inputShape = x.tensor.shape;
} else {
if (x.is2DReshaped || x.is2DSquareReshaped) {
this.inputShape = x.originalShape;
} else {
this.inputShape = x.tensor.shape;
}
}
if (!(0, _isEqual2.default)(this.inputShape, this.shape)) {
this.throwError(`input tensor shape ${x.tensor.shape} does not match specified shape ${this.shape}.`);
}
if (!x.glTexture) {
if (x.tensor.shape.length <= 2) {
x.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: true
});
} else if (x.tensor.shape.length > 2) {
x.reshapeTo2D();
x.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: true
});
}
}
this.output = x;
}
}
exports.default = InputLayer;