@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
53 lines • 2.17 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
var util_1 = require("../util");
/**
 * Checks the arguments of a gatherND call and derives the shape information
 * needed to carry out the gather.
 *
 * An Error is thrown when the source tensor or the indices have rank below 1,
 * when the indices are not of dtype 'int32', when the innermost dimension of
 * the indices is larger than the rank of the source tensor, or when the
 * source tensor has size 0.
 *
 * @param tensor The tensor that holds the source values.
 * @param indices The tensor that holds the indices used to slice the source.
 *
 * @returns [resultShape, numUpdates, sliceSize, strides]
 */
function prepareAndValidate(tensor, indices) {
    // Validation: the checks run in a fixed order and the first failure wins.
    if (tensor.rank < 1) {
        throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
            ' but the rank was ' + tensor.rank + '.');
    }
    if (indices.rank < 1) {
        throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
            ' but the rank was ' + indices.rank + '.');
    }
    if (indices.dtype !== 'int32') {
        throw new Error('tf.gatherND() expects the indices to be int32 type,' +
            ' but the dtype was ' + indices.dtype + '.');
    }
    var innermostLength = indices.shape[indices.rank - 1];
    if (innermostLength > tensor.rank) {
        throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
            innermostLength + ' vs. ' + tensor.rank);
    }
    if (tensor.size === 0) {
        throw new Error('Requested more than 0 entries, but input is empty.' +
            ' Input shape: ' + tensor.shape + '.');
    }

    // The last entry of the indices shape says how many leading axes of the
    // source each index row addresses.
    var sliceRank = indices.shape[indices.shape.length - 1];

    // The result shape is
    //   indices.shape[:-1] + params.shape[indices.shape[-1]:]
    // Start from the leading index dimensions; their product is the number of
    // slices that will be gathered.
    var resultShape = indices.shape.slice(0, -1);
    var numSlices = resultShape.reduce(function (product, dim) {
        return product * dim;
    }, 1);

    // Append the source dimensions that are not addressed by an index row,
    // accumulating the element count of a single gathered slice.
    var sourceShape = tensor.shape;
    var sliceSize = 1;
    for (var axis = sliceRank; axis < tensor.rank; ++axis) {
        var dim = sourceShape[axis];
        sliceSize *= dim;
        resultShape.push(dim);
    }

    // Express the source strides in units of whole slices, append a trailing
    // 1, and keep one stride per indexed axis.
    var sliceStrides = util_1.computeStrides(tensor.shape).map(function (stride) {
        return stride / sliceSize;
    });
    sliceStrides.push(1);
    var strides = sliceStrides.slice(0, sliceRank);

    return [resultShape, numSlices, sliceSize, strides];
}
// Publish prepareAndValidate on this module's exports object.
exports.prepareAndValidate = prepareAndValidate;
//# sourceMappingURL=gather_nd_util.js.map