keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
143 lines (112 loc) • 5.03 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _sum2 = _interopRequireDefault(require("lodash/sum"));
var _Merge2 = _interopRequireDefault(require("./_Merge"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var _WebGL = require("../../WebGL2");
var _createGLSLProgram = _interopRequireDefault(require("../../webgl/dynamic/createGLSLProgram"));
var tensorUtils = _interopRequireWildcard(require("../../utils/tensorUtils"));
var _ndarrayConcatRows = _interopRequireDefault(require("ndarray-concat-rows"));
function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) { var desc = Object.defineProperty && Object.getOwnPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : {}; if (desc.get || desc.set) { Object.defineProperty(newObj, key, desc); } else { newObj[key] = obj[key]; } } } } newObj.default = obj; return newObj; } }
function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }
class Concatenate extends _Merge2.default {
constructor(attrs = {}) {
super(attrs);
this.layerClass = 'Concatenate';
this.mode = 'concat';
const {
axis = -1
} = attrs;
this.concatAxis = axis <= 0 ? axis : axis - 1;
}
_callCPU(inputs) {
const outputShape = inputs[0].tensor.shape.slice();
const _concatAxis = this.concatAxis < 0 ? outputShape.length + this.concatAxis : this.concatAxis;
inputs.slice(1, inputs.length).forEach(x => {
const d = x.tensor.shape.slice()[_concatAxis];
outputShape[_concatAxis] += d;
});
this.output = new _Tensor.default([], outputShape);
if (_concatAxis === 0) {
(0, _ndarrayConcatRows.default)(this.output.tensor, inputs.map(x => x.tensor));
} else {
let dimsAxisSwap = [_concatAxis];
for (let i = 0; i < inputs[0].tensor.shape.length; i++) {
if (i !== _concatAxis) dimsAxisSwap.push(i);
}
(0, _ndarrayConcatRows.default)(this.output.tensor.transpose(...dimsAxisSwap), inputs.map(x => x.tensor.transpose(...dimsAxisSwap)));
}
}
_callGPU(inputs) {
inputs.forEach(input => {
if (!input.glTexture && !input.glTextureFragments) {
input.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: true
});
}
});
const outputShape = inputs[0].glTextureShape.slice();
let _concatAxis = 1;
if (inputs[0].is2DReshaped) {
if (this.concatAxis === -1 || this.concatAxis === inputs[0].originalShape.length - 1) {
_concatAxis = 1;
} else {
this.throwError('specified axis not supported for now.');
}
} else {
if (this.concatAxis === -1 || this.concatAxis === 1) {
_concatAxis = 1;
} else if (this.concatAxis === -2 || this.concatAxis === 0) {
_concatAxis = 0;
} else {
this.throwError('specified axis not supported for now.');
}
}
outputShape[_concatAxis] = (0, _sum2.default)(inputs.map(input => input.glTextureShape[_concatAxis]));
if (!this.output) {
this.output = new _Tensor.default([], outputShape);
this.output.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: _concatAxis === 1
});
if (inputs[0].is1D) {
this.output.is1D = inputs[0].is1D;
} else if (inputs[0].is2DReshaped) {
this.output.is2DReshaped = inputs[0].is2DReshaped;
this.output.originalShape = inputs[0].originalShape.slice();
const _concatAxis = this.concatAxis < 0 ? this.output.originalShape.length + this.concatAxis : this.concatAxis;
this.output.originalShape[_concatAxis] = (0, _sum2.default)(inputs.map(input => input.originalShape[_concatAxis]));
this.output.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.output.originalShape, false, _concatAxis);
}
}
if (!this.mergeProgram) {
const outputShape = this.output.glTextureFragments ? this.output.glTextureFragmentShape : this.output.glTextureShape;
const mergeProgramSource = (0, _createGLSLProgram.default)('concatenate', inputs.length, inputs.map(input => input.glTextureShape), outputShape, _concatAxis);
this.mergeProgram = _WebGL.webgl2.compileProgram(mergeProgramSource);
}
_WebGL.webgl2.runProgram({
program: this.mergeProgram,
output: this.output,
inputs: inputs.map((input, i) => ({
input,
name: `inputs[${i}]`
})),
supportsTextureFragments: true
});
if (this.outbound.length === 0) {
this.output.transferFromGLTexture();
if (this.output.is2DReshaped) {
this.output.reshapeFrom2D();
} else if (this.output.is2DSquareReshaped) {
this.output.reshapeFrom2DSquare();
}
}
}
}
exports.default = Concatenate;