@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
103 lines • 17.4 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';
/**
 * Tests for `tf.booleanMaskAsync(tensor, mask, axis?)`: selects the entries of
 * `tensor` where `mask` is true, starting at `axis` (leading axis by default),
 * and rejects invalid masks with descriptive errors.
 */
describeWithFlags('booleanMaskAsync', ALL_ENVS, () => {
  /**
   * Invokes `fn` (which may throw synchronously or return a rejecting promise)
   * and resolves to the message of the error it produced. Resolves to the
   * sentinel 'No error thrown.' when `fn` completes normally, so a missing
   * error surfaces as a readable expectation failure rather than a pass.
   */
  const errorMessageOf = async (fn) => {
    try {
      await fn();
    } catch (error) {
      return error.message;
    }
    return 'No error thrown.';
  };

  it('1d array, 1d mask, default axis', async () => {
    const array = tf.tensor1d([1, 2, 3]);
    const mask = tf.tensor1d([1, 0, 1], 'bool');
    const result = await tf.booleanMaskAsync(array, mask);
    expect(result.shape).toEqual([2]);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), [1, 3]);
  });

  it('2d array, 1d mask, default axis', async () => {
    // A 1d mask over the leading axis keeps whole rows.
    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
    const mask = tf.tensor1d([1, 0, 1], 'bool');
    const result = await tf.booleanMaskAsync(array, mask);
    expect(result.shape).toEqual([2, 2]);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), [1, 2, 5, 6]);
  });

  it('2d array, 2d mask, default axis', async () => {
    // A mask covering every dimension selects individual elements, so the
    // result is flattened to 1d.
    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
    const mask = tf.tensor2d([1, 0, 1, 0, 1, 0], [3, 2], 'bool');
    const result = await tf.booleanMaskAsync(array, mask);
    expect(result.shape).toEqual([3]);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), [1, 3, 5]);
  });

  it('2d array, 1d mask, axis=1', async () => {
    // Masking along axis 1 keeps the selected column of every row.
    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
    const mask = tf.tensor1d([0, 1], 'bool');
    const axis = 1;
    const result = await tf.booleanMaskAsync(array, mask, axis);
    expect(result.shape).toEqual([3, 1]);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), [2, 4, 6]);
  });

  it('accepts tensor-like object as array or mask', async () => {
    const array = [[1, 2], [3, 4], [5, 6]];
    const mask = [1, 0, 1];
    const result = await tf.booleanMaskAsync(array, mask);
    expect(result.shape).toEqual([2, 2]);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), [1, 2, 5, 6]);
  });

  it('ensure no memory leak', async () => {
    const numTensorsBefore = tf.memory().numTensors;

    const array = tf.tensor1d([1, 2, 3]);
    const mask = tf.tensor1d([1, 0, 1], 'bool');

    const result = await tf.booleanMaskAsync(array, mask);
    expect(result.shape).toEqual([2]);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), [1, 3]);
    array.dispose();
    mask.dispose();
    result.dispose();

    // Only the inputs and the result were allocated by this test; once they
    // are disposed, the live tensor count must be back where it started.
    const numTensorsAfter = tf.memory().numTensors;
    expect(numTensorsAfter).toBe(numTensorsBefore);
  });

  it('should throw if mask is scalar', async () => {
    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
    const mask = tf.scalar(1, 'bool');
    const errorMessage =
        await errorMessageOf(() => tf.booleanMaskAsync(array, mask));
    expect(errorMessage).toBe('mask cannot be scalar');
  });

  it('should throw if array and mask shapes mismatch', async () => {
    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
    const mask = tf.tensor2d([1, 0], [1, 2], 'bool');
    const errorMessage =
        await errorMessageOf(() => tf.booleanMaskAsync(array, mask));
    expect(errorMessage)
        .toBe(
            `mask's shape must match the first K ` +
            `dimensions of tensor's shape, Shapes 3,2 and 1,2 must match`);
  });
});
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"boolean_mask_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/boolean_mask_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAC,MAAM,cAAc,CAAC;AAE/C,iBAAiB,CAAC,kBAAkB,EAAE,QAAQ,EAAE,GAAG,EAAE;IACnD,EAAE,CAAC,iCAAiC,EAAE,KAAK,IAAI,EAAE;QAC/C,MAAM,KAAK,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrC,MAAM,IAAI,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QAC5C,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;QACtD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACrC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iCAAiC,EAAE,KAAK,IAAI,EAAE;QAC/C,MAAM,KAAK,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,IAAI,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QAC5C,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;QACtD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACrC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACvD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iCAAiC,EAAE,KAAK,IAAI,EAAE;QAC/C,MAAM,KAAK,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,IAAI,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QAC7D,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,CAAC,C
AAC;QACtD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACrC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,2BAA2B,EAAE,KAAK,IAAI,EAAE;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,IAAI,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QACzC,MAAM,IAAI,GAAG,CAAC,CAAC;QACf,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC;QAC5D,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACrC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QAC3D,MAAM,KAAK,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACvC,MAAM,IAAI,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QACvB,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;QACtD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACrC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACvD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,uBAAuB,EAAE,KAAK,IAAI,EAAE;QACrC,MAAM,gBAAgB,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,UAAU,CAAC;QAEhD,MAAM,KAAK,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrC,MAAM,IAAI,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QAE5C,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;QACtD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACrC,iBAAiB,CAAC
,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC/C,KAAK,CAAC,OAAO,EAAE,CAAC;QAChB,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,MAAM,CAAC,OAAO,EAAE,CAAC;QAEjB,MAAM,eAAe,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,UAAU,CAAC;QAC/C,MAAM,CAAC,eAAe,CAAC,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,gCAAgC,EAAE,KAAK,IAAI,EAAE;QAC9C,MAAM,KAAK,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,IAAI,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QAClC,IAAI,YAAY,GAAG,kBAAkB,CAAC;QACtC,IAAI;YACF,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SACxC;QAAC,OAAO,KAAK,EAAE;YACd,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC;SAC9B;QACD,MAAM,CAAC,YAAY,CAAC,CAAC,IAAI,CAAC,uBAAuB,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iDAAiD,EAAE,KAAK,IAAI,EAAE;QAC/D,MAAM,KAAK,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,IAAI,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QACjD,IAAI,YAAY,GAAG,kBAAkB,CAAC;QACtC,IAAI;YACF,MAAM,EAAE,CAAC,gBAAgB,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SACxC;QAAC,OAAO,KAAK,EAAE;YACd,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC;SAC9B;QACD,MAAM,CAAC,YAAY,CAAC;aACf,IAAI,CACD,sCAAsC;YACtC,6DAA6D,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysClose} from '../test_util';\n\ndescribeWithFlags('booleanMaskAsync', ALL_ENVS, () => {\n  it('1d array, 1d mask, default axis', async () => {\n    const array = tf.tensor1d([1, 2, 3]);\n    const mask = tf.tensor1d([1, 0, 1], 'bool');\n    const result = await tf.booleanMaskAsync(array, mask);\n    expect(result.shape).toEqual([2]);\n    expect(result.dtype).toBe('float32');\n    expectArraysClose(await result.data(), [1, 3]);\n  });\n\n  it('2d array, 1d mask, default axis', async () => {\n    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);\n    const mask = tf.tensor1d([1, 0, 1], 'bool');\n    const result = await tf.booleanMaskAsync(array, mask);\n    expect(result.shape).toEqual([2, 2]);\n    expect(result.dtype).toBe('float32');\n    expectArraysClose(await result.data(), [1, 2, 5, 6]);\n  });\n\n  it('2d array, 2d mask, default axis', async () => {\n    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);\n    const mask = tf.tensor2d([1, 0, 1, 0, 1, 0], [3, 2], 'bool');\n    const result = await tf.booleanMaskAsync(array, mask);\n    expect(result.shape).toEqual([3]);\n    expect(result.dtype).toBe('float32');\n    expectArraysClose(await result.data(), [1, 3, 5]);\n  
});\n\n  it('2d array, 1d mask, axis=1', async () => {\n    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);\n    const mask = tf.tensor1d([0, 1], 'bool');\n    const axis = 1;\n    const result = await tf.booleanMaskAsync(array, mask, axis);\n    expect(result.shape).toEqual([3, 1]);\n    expect(result.dtype).toBe('float32');\n    expectArraysClose(await result.data(), [2, 4, 6]);\n  });\n\n  it('accepts tensor-like object as array or mask', async () => {\n    const array = [[1, 2], [3, 4], [5, 6]];\n    const mask = [1, 0, 1];\n    const result = await tf.booleanMaskAsync(array, mask);\n    expect(result.shape).toEqual([2, 2]);\n    expect(result.dtype).toBe('float32');\n    expectArraysClose(await result.data(), [1, 2, 5, 6]);\n  });\n\n  it('ensure no memory leak', async () => {\n    const numTensorsBefore = tf.memory().numTensors;\n\n    const array = tf.tensor1d([1, 2, 3]);\n    const mask = tf.tensor1d([1, 0, 1], 'bool');\n\n    const result = await tf.booleanMaskAsync(array, mask);\n    expect(result.shape).toEqual([2]);\n    expect(result.dtype).toBe('float32');\n    expectArraysClose(await result.data(), [1, 3]);\n    array.dispose();\n    mask.dispose();\n    result.dispose();\n\n    const numTensorsAfter = tf.memory().numTensors;\n    expect(numTensorsAfter).toBe(numTensorsBefore);\n  });\n\n  it('should throw if mask is scalar', async () => {\n    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);\n    const mask = tf.scalar(1, 'bool');\n    let errorMessage = 'No error thrown.';\n    try {\n      await tf.booleanMaskAsync(array, mask);\n    } catch (error) {\n      errorMessage = error.message;\n    }\n    expect(errorMessage).toBe('mask cannot be scalar');\n  });\n\n  it('should throw if array and mask shape miss match', async () => {\n    const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);\n    const mask = tf.tensor2d([1, 0], [1, 2], 'bool');\n    let errorMessage = 'No error thrown.';\n    try {\n      await tf.booleanMaskAsync(array, 
mask);\n    } catch (error) {\n      errorMessage = error.message;\n    }\n    expect(errorMessage)\n        .toBe(\n            `mask's shape must match the first K ` +\n            `dimensions of tensor's shape, Shapes 3,2 and 1,2 must match`);\n  });\n});\n"]}