@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
68 lines • 3.19 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
var environment_1 = require("../environment");
var tensor_util_env_1 = require("../tensor_util_env");
var util_1 = require("../util");
var util_2 = require("../util");
var concat_util_1 = require("./concat_util");
var operation_1 = require("./operation");
var tensor_ops_1 = require("./tensor_ops");
function concat1d_(tensors) {
return exports.concat(tensors, 0);
}
function concat2d_(tensors, axis) {
return exports.concat(tensors, axis);
}
function concat3d_(tensors, axis) {
return exports.concat(tensors, axis);
}
function concat4d_(tensors, axis) {
return exports.concat(tensors, axis);
}
function concat_(tensors, axis) {
if (axis === void 0) { axis = 0; }
util_1.assert(tensors.length >= 1, 'Pass at least one tensor to concat');
var $tensors = tensor_util_env_1.convertToTensorArray(tensors, 'tensors', 'concat');
axis = util_2.parseAxisParam(axis, $tensors[0].shape)[0];
var outShape = concat_util_1.computeOutShape($tensors.map(function (t) { return t.shape; }), axis);
if (util_1.sizeFromShape(outShape) === 0) {
return tensor_ops_1.tensor([], outShape);
}
$tensors = $tensors.filter(function (t) { return t.size > 0; });
if ($tensors.length === 1) {
return $tensors[0];
}
var shapes = $tensors.map(function (t) { return t.shape; });
concat_util_1.assertParamsConsistent(shapes, axis);
var der = function (dy) {
var sizeSplits = shapes.map(function (s) { return s[axis]; });
var derTensors = exports.split(dy, sizeSplits, axis);
return derTensors.map(function (t) { return function () { return t; }; });
};
var inputs = $tensors;
return environment_1.ENV.engine.runKernel(function (backend) { return backend.concat($tensors, axis); }, inputs, der);
}
function split_(x, numOrSizeSplits, axis) {
if (axis === void 0) { axis = 0; }
var $x = tensor_util_env_1.convertToTensor(x, 'x', 'split');
axis = util_2.parseAxisParam(axis, $x.shape)[0];
var splitSizes;
if (typeof (numOrSizeSplits) === 'number') {
util_1.assert($x.shape[axis] % numOrSizeSplits === 0, 'Number of splits must evenly divide the axis.');
splitSizes =
new Array(numOrSizeSplits).fill($x.shape[axis] / numOrSizeSplits);
}
else {
util_1.assert($x.shape[axis] === numOrSizeSplits.reduce(function (a, b) { return a + b; }), 'The sum of sizes must match the size of the axis dimension.');
splitSizes = numOrSizeSplits;
}
var der = function (dy) { return ({ $x: function () { return exports.concat(dy, axis); } }); };
return environment_1.ENV.engine.runKernel(function (backend) { return backend.split($x, splitSizes, axis); }, { $x: $x }, der);
}
exports.concat = operation_1.op({ concat_: concat_ });
exports.concat1d = operation_1.op({ concat1d_: concat1d_ });
exports.concat2d = operation_1.op({ concat2d_: concat2d_ });
exports.concat3d = operation_1.op({ concat3d_: concat3d_ });
exports.concat4d = operation_1.op({ concat4d_: concat4d_ });
exports.split = operation_1.op({ split_: split_ });
//# sourceMappingURL=concat_split.js.map