@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
130 lines • 6.69 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
var tf = require("../index");
var jasmine_util_1 = require("../jasmine_util");
var test_util_1 = require("../test_util");
// scatterND tests that run on every registered backend (ALL_ENVS).
// Each case scatters `updates` into an output of `shape` at the positions
// selected by `indices`; positions that receive no update are expected to be 0.
// Expected arrays are written in row-major (flattened) order.
jasmine_util_1.describeWithFlags('scatterND', test_util_1.ALL_ENVS, function () {
    it('should work for 2d', function () {
        // Rank-1 indices select whole rows (of length 3) of the [5, 3] output.
        var indices = tf.tensor1d([0, 4, 2], 'int32');
        var updates = tf.tensor2d([100, 101, 102, 777, 778, 779, 1000, 1001, 1002], [3, 3], 'int32');
        var shape = [5, 3];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual(updates.dtype);
        test_util_1.expectArraysClose(result, [100, 101, 102, 0, 0, 0, 1000, 1001, 1002, 0, 0, 0, 777, 778, 779]);
    });
    it('should work for simple 1d', function () {
        var indices = tf.tensor1d([3], 'int32');
        var updates = tf.tensor1d([101], 'float32');
        var shape = [5];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual(updates.dtype);
        test_util_1.expectArraysClose(result, [0, 0, 0, 101, 0]);
    });
    it('should work for multiple 1d', function () {
        var indices = tf.tensor1d([0, 4, 2], 'int32');
        var updates = tf.tensor1d([100, 101, 102], 'float32');
        var shape = [5];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual(updates.dtype);
        test_util_1.expectArraysClose(result, [100, 0, 102, 0, 101]);
    });
    it('should work for high rank updates', function () {
        // Indices [[0], [2]] select slices 0 and 2 along the first axis of the
        // [4, 4, 4] output; each update is a whole [4, 4] slice.
        var indices = tf.tensor2d([0, 2], [2, 1], 'int32');
        var updates = tf.tensor3d([
            5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
            5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8
        ], [2, 4, 4], 'float32');
        var shape = [4, 4, 4];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual(updates.dtype);
        test_util_1.expectArraysClose(result, [
            5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7,
            8, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ]);
    });
    it('should work for high rank indices', function () {
        // Each index row is a full (row, col) coordinate into the [3, 3] output:
        // (0, 2) <- 10 and (0, 1) <- 20.
        var indices = tf.tensor2d([0, 2, 0, 1], [2, 2], 'int32');
        var updates = tf.tensor1d([10, 20], 'float32');
        var shape = [3, 3];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual(updates.dtype);
        test_util_1.expectArraysClose(result, [0, 20, 10, 0, 0, 0, 0, 0, 0]);
    });
    it('should work for high rank indices and update', function () {
        // Index rows (1, 0), (0, 1), (0, 1) each select a 256-wide slice of the
        // [2, 2, 256] output and receive a row of ones. (0, 1) is hit twice, so
        // its slice is expected to hold 2 (duplicates are summed, as asserted
        // by 'should sum the duplicated indices'); (0, 0) and (1, 1) stay 0.
        var indices = tf.tensor2d([1, 0, 0, 1, 0, 1], [3, 2], 'int32');
        var updates = tf.ones([3, 256], 'float32');
        var shape = [2, 2, 256];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual(updates.dtype);
        // Per-slice values in row-major order: (0,0), (0,1), (1,0), (1,1).
        var sliceValues = [0, 2, 1, 0];
        var expected = [];
        for (var slice = 0; slice < sliceValues.length; slice++) {
            for (var k = 0; k < 256; k++) {
                expected.push(sliceValues[slice]);
            }
        }
        test_util_1.expectArraysClose(result, expected);
    });
    it('should sum the duplicated indices', function () {
        // Index 0 appears twice, so it is expected to hold 10 + 60 = 70.
        var indices = tf.tensor1d([0, 4, 2, 1, 3, 0], 'int32');
        var updates = tf.tensor1d([10, 20, 30, 40, 50, 60], 'float32');
        var shape = [8];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual(updates.dtype);
        test_util_1.expectArraysClose(result, [70, 40, 30, 50, 20, 0, 0, 0]);
    });
    it('should work for tensorLike input', function () {
        // Plain arrays are accepted; the result dtype asserted here is float32.
        var indices = [0, 4, 2];
        var updates = [100, 101, 102];
        var shape = [5];
        var result = tf.scatterND(indices, updates, shape);
        expect(result.shape).toEqual(shape);
        expect(result.dtype).toEqual('float32');
        test_util_1.expectArraysClose(result, [100, 0, 102, 0, 101]);
    });
    it('should throw error when indices type is not int32', function () {
        // Indices are float32.
        var indices = tf.tensor2d([0, 2, 0, 1], [2, 2], 'float32');
        var updates = tf.tensor1d([10, 20], 'float32');
        var shape = [3, 3];
        expect(function () { return tf.scatterND(indices, updates, shape); }).toThrow();
    });
    it('should throw error when indices and update mismatch', function () {
        // Each update row has 4 values, but the output's trailing dimension is 3.
        var indices = tf.tensor2d([0, 4, 2], [3, 1], 'int32');
        var updates = tf.tensor2d([100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004], [3, 4], 'float32');
        var shape = [5, 3];
        expect(function () { return tf.scatterND(indices, updates, shape); }).toThrow();
    });
    it('should throw error when indices and update count mismatch', function () {
        // 3 index rows but only 2 update rows.
        var indices = tf.tensor2d([0, 4, 2], [3, 1], 'int32');
        var updates = tf.tensor2d([100, 101, 102, 10000, 10001, 10002], [2, 3], 'float32');
        var shape = [5, 3];
        expect(function () { return tf.scatterND(indices, updates, shape); }).toThrow();
    });
    it('should throw error when indices are scalar', function () {
        var indices = tf.scalar(1, 'int32');
        var updates = tf.tensor2d([100, 101, 102, 10000, 10001, 10002], [2, 3], 'float32');
        var shape = [5, 3];
        expect(function () { return tf.scatterND(indices, updates, shape); }).toThrow();
    });
    it('should throw error when update is scalar', function () {
        var indices = tf.tensor2d([0, 4, 2], [3, 1], 'int32');
        var updates = tf.scalar(1, 'float32');
        var shape = [5, 3];
        expect(function () { return tf.scatterND(indices, updates, shape); }).toThrow();
    });
});
// Index-value bounds checks for scatterND, run only on CPU backends (CPU_ENVS).
// In both cases the tensor shapes are mutually consistent, so any throw must
// come from an index value falling outside the output's leading dimension.
jasmine_util_1.describeWithFlags('scatterND CPU', test_util_1.CPU_ENVS, function () {
    it('should throw error when index out of range', function () {
        // 99 is far outside dimension 0 (size 5) of the [5, 3] output.
        var indices = tf.tensor2d([0, 4, 99], [3, 1], 'int32');
        var updates = tf.tensor2d([100, 101, 102, 777, 778, 779, 10000, 10001, 10002], [3, 3], 'float32');
        var shape = [5, 3];
        expect(function () { return tf.scatterND(indices, updates, shape); }).toThrow();
    });
    it('should throw error when indices has wrong dimension', function () {
        // The output's leading dimension is 2. Indices 0 and 1 are valid; 2 is
        // exactly one past the end, so it is the only bad value. This pins the
        // exclusive upper bound. A huge value such as 99 would be rejected for
        // any shape and would just duplicate the test above.
        var indices = tf.tensor2d([0, 1, 2], [3, 1], 'int32');
        var updates = tf.tensor2d([100, 101, 102, 777, 778, 779, 10000, 10001, 10002], [3, 3], 'float32');
        var shape = [2, 3];
        expect(function () { return tf.scatterND(indices, updates, shape); }).toThrow();
    });
});
//# sourceMappingURL=scatter_nd_test.js.map