@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
65 lines • 2.87 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var engine_1 = require("../engine");
var tensor_util_env_1 = require("../tensor_util_env");
var operation_1 = require("./operation");
/**
* Gather slices from input tensor into a Tensor with shape specified by
* `indices`.
*
* `indices` is a K-dimensional integer tensor, best thought of as a
* (K-1)-dimensional tensor of indices into input, where each element defines a
* slice of input:
* output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
*
* Whereas in `tf.gather`, `indices` defines slices into the first dimension of
* input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
* of input, where N = indices.shape[-1].
*
* The last dimension of indices can be at most the rank of input:
* indices.shape[-1] <= input.rank
*
* The last dimension of `indices` corresponds to elements
* (if indices.shape[-1] == input.rank) or slices
* (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
* input.
* The output tensor has shape
* indices.shape[:-1] + input.shape[indices.shape[-1]:]
*
* Note that on CPU, if an out of bound index is found, an error is returned. On
* GPU, if an out of bound index is found, a 0 is stored in the corresponding
* output value.
*
* ```js
* const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
* const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
* tf.gatherND(input, indices).print() // [10, 11]
* ```
*
* @param x The tensor from which to gather values.
* @param indices Index tensor, must be of type int32.
*/
/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */
function gatherND_(x, indices) {
    // Normalize both arguments to tensors. `indices` is converted first and is
    // requested with dtype 'int32'; `x` is converted without a dtype argument.
    var indexTensor = tensor_util_env_1.convertToTensor(indices, 'indices', 'gatherND', 'int32');
    var inputTensor = tensor_util_env_1.convertToTensor(x, 'x', 'gatherND');
    // Named inputs handed to the engine alongside the forward function.
    var kernelInputs = { x: inputTensor, indices: indexTensor };
    // Forward pass: delegate the actual gather to whichever backend the engine
    // supplies.
    var forward = function (backend) {
        return backend.gatherND(inputTensor, indexTensor);
    };
    // No backward (gradient) function is supplied for this op.
    return engine_1.ENGINE.runKernelFunc(forward, kernelInputs, null /* backward */, 'GatherNd');
}
// Public export: `gatherND` is the `gatherND_` implementation passed through
// `op` from ./operation, keyed by the implementation's own name.
exports.gatherND = operation_1.op({ gatherND_: gatherND_ });
//# sourceMappingURL=gather_nd.js.map