UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

146 lines 16.1 kB
/** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Testing utilities. */ import { memory, Tensor, test_util, util } from '@tensorflow/tfjs-core'; // tslint:disable-next-line: no-imports-from-dist import { ALL_ENVS, describeWithFlags } from '@tensorflow/tfjs-core/dist/jasmine_util'; import { ValueError } from '../errors'; /** * Expect values are close between a Tensor or number array. * @param actual * @param expected */ export function expectTensorsClose(actual, expected, epsilon) { if (actual == null) { throw new ValueError('First argument to expectTensorsClose() is not defined.'); } if (expected == null) { throw new ValueError('Second argument to expectTensorsClose() is not defined.'); } if (actual instanceof Tensor && expected instanceof Tensor) { if (actual.dtype !== expected.dtype) { throw new Error(`Data types do not match. Actual: '${actual.dtype}'. ` + `Expected: '${expected.dtype}'`); } if (!util.arraysEqual(actual.shape, expected.shape)) { throw new Error(`Shapes do not match. Actual: [${actual.shape}]. ` + `Expected: [${expected.shape}].`); } } const actualData = actual instanceof Tensor ? actual.dataSync() : actual; const expectedData = expected instanceof Tensor ? expected.dataSync() : expected; test_util.expectArraysClose(actualData, expectedData, epsilon); } /** * Expect values are not close between a Tensor or number array. * @param t1 * @param t2 */ export function expectTensorsNotClose(t1, t2, epsilon) { try { expectTensorsClose(t1, t2, epsilon); } catch (error) { return; } throw new Error('The two values are close at all elements.'); } /** * Expect values in array are within a specified range, boundaries inclusive. * @param actual * @param expected */ export function expectTensorsValuesInRange(actual, low, high) { if (actual == null) { throw new ValueError('First argument to expectTensorsClose() is not defined.'); } test_util.expectValuesInRange(actual.dataSync(), low, high); } /** * Describe tests to be run on CPU and GPU. * @param testName * @param tests */ export function describeMathCPUAndGPU(testName, tests) { describeWithFlags(testName, ALL_ENVS, () => { tests(); }); } /** * Describe tests to be run on CPU and GPU WebGL2. * @param testName * @param tests */ export function describeMathCPUAndWebGL2(testName, tests) { describeWithFlags(testName, { predicate: testEnv => (testEnv.flags == null || testEnv.flags['WEBGL_VERSION'] === 2) }, () => { tests(); }); } /** * Describe tests to be run on CPU only. * @param testName * @param tests */ export function describeMathCPU(testName, tests) { describeWithFlags(testName, { predicate: testEnv => testEnv.backendName === 'cpu' }, () => { tests(); }); } /** * Describe tests to be run on GPU only. * @param testName * @param tests */ export function describeMathGPU(testName, tests) { describeWithFlags(testName, { predicate: testEnv => testEnv.backendName === 'webgl' }, () => { tests(); }); } /** * Describe tests to be run on WebGL2 GPU only. * @param testName * @param tests */ export function describeMathWebGL2(testName, tests) { describeWithFlags(testName, { predicate: testEnv => testEnv.backendName === 'webgl' && (testEnv.flags == null || testEnv.flags['WEBGL_VERSION'] === 2) }, () => { tests(); }); } /** * Check that a function only generates the expected number of new Tensors. * * The test function is called twice, once to prime any regular constants and * once to ensure that additional copies aren't created/tensors aren't leaked. * * @param testFunc A fully curried (zero arg) version of the function to test. * @param numNewTensors The expected number of new Tensors that should exist. */ export function expectNoLeakedTensors( // tslint:disable-next-line:no-any testFunc, numNewTensors) { testFunc(); const numTensorsBefore = memory().numTensors; testFunc(); const numTensorsAfter = memory().numTensors; const actualNewTensors = numTensorsAfter - numTensorsBefore; if (actualNewTensors !== numNewTensors) { throw new ValueError(`Created an unexpected number of new ` + `Tensors. Expected: ${numNewTensors}, created : ${actualNewTensors}. ` + `Please investigate the discrepency and/or use tidy.`); } } //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"test_utils.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/utils/test_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,EAAC,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,IAAI,EAAC,MAAM,uBAAuB,CAAC;AACtE,iDAAiD;AACjD,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,yCAAyC,CAAC;AAEpF,OAAO,EAAC,UAAU,EAAC,MAAM,WAAW,CAAC;AAErC;;;;GAIG;AACH,MAAM,UAAU,kBAAkB,CAC9B,MAAuB,EAAE,QAAyB,EAAE,OAAgB;IACtE,IAAI,MAAM,IAAI,IAAI,EAAE;QAClB,MAAM,IAAI,UAAU,CAChB,wDAAwD,CAAC,CAAC;KAC/D;IACD,IAAI,QAAQ,IAAI,IAAI,EAAE;QACpB,MAAM,IAAI,UAAU,CAChB,yDAAyD,CAAC,CAAC;KAChE;IACD,IAAI,MAAM,YAAY,MAAM,IAAI,QAAQ,YAAY,MAAM,EAAE;QAC1D,IAAI,MAAM,CAAC,KAAK,KAAK,QAAQ,CAAC,KAAK,EAAE;YACnC,MAAM,IAAI,KAAK,CACX,qCAAqC,MAAM,CAAC,KAAK,KAAK;gBACtD,cAAc,QAAQ,CAAC,KAAK,GAAG,CAAC,CAAC;SACtC;QACD,IAAI,CAAC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,KAAK,EAAE,QAAQ,CAAC,KAAK,CAAC,EAAE;YACnD,MAAM,IAAI,KAAK,CACX,iCAAiC,MAAM,CAAC,KAAK,KAAK;gBAClD,cAAc,QAAQ,CAAC,KAAK,IAAI,CAAC,CAAC;SACvC;KACF;IACD,MAAM,UAAU,GAAG,MAAM,YAAY,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,MAAM,CAAC;IACzE,MAAM,YAAY,GACd,QAAQ,YAAY,MAAM,CAAC,CAAC,CAAC,QAAQ,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,QAAQ,CAAC;IAChE,SAAS,CAAC,iBAAiB,CAAC,UAAU,EAAE,YAAY,EAAE,OAAO,CAAC,CAAC;AACjE,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,qBAAqB,CACnC,EAAmB,EAAE,EAAmB,EAAE,OAAgB;IAC5D,IAAI;QACF,kBAAkB,CAAC,EAAE,EAAE,EAAE,EAAE,OAAO,CAAC,CAAC;KACrC;IAAC,OAAO,KAAK,EAAE;QACd,OAAO;KACR;IACD,MAAM,IAAI,KAAK,CAAC,2CAA2C,CAAC,CAAC;AAC7D,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,0BAA0B,CACtC,MAAc,EAAE,GAAW,EAAE,IAAY;IAC3C,IAAI,MAAM,IAAI,IAAI,EAAE;QAClB,MAAM,IAAI,UAAU,CAChB,wDAAwD,CAAC,CAAC;KAC/D;IACD,SAAS,CAAC,mBAAmB,CAAC,MAAM,CAAC,QAAQ,EAAE,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC;AAC9D,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,qBAAqB,CAAC,QAAgB,EAAE,KAAiB;IACvE,iBAAiB,CAAC,QAAQ,EAAE,QAAQ,EAAE,GAAG,EAAE;QACzC,KAAK,EAAE,CAAC;IACV,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,wBAAwB,CAAC,QAAgB,EAAE,KAAiB;IAC1E,iBAAiB,CACb,QAAQ,EAAE;QACR,SAAS,EAAE,OAAO,CAAC,EAAE,CACjB,CAAC,OAAO,CAAC,KAAK,IAAI,IAAI,IAAI,OAAO,CAAC,KAAK,CAAC,eAAe,CAAC,KAAK,CAAC,CAAC;KACpE,EACD,GAAG,EAAE;QACH,KAAK,EAAE,CAAC;IACV,CAAC,CAAC,CAAC;AACT,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,eAAe,CAAC,QAAgB,EAAE,KAAiB;IACjE,iBAAiB,CACb,QAAQ,EAAE,EAAC,SAAS,EAAE,OAAO,CAAC,EAAE,CAAC,OAAO,CAAC,WAAW,KAAK,KAAK,EAAC,EAAE,GAAG,EAAE;QACpE,KAAK,EAAE,CAAC;IACV,CAAC,CAAC,CAAC;AACT,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,eAAe,CAAC,QAAgB,EAAE,KAAiB;IACjE,iBAAiB,CACb,QAAQ,EAAE,EAAC,SAAS,EAAE,OAAO,CAAC,EAAE,CAAC,OAAO,CAAC,WAAW,KAAK,OAAO,EAAC,EAAE,GAAG,EAAE;QACtE,KAAK,EAAE,CAAC;IACV,CAAC,CAAC,CAAC;AACT,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,kBAAkB,CAAC,QAAgB,EAAE,KAAiB;IACpE,iBAAiB,CACb,QAAQ,EAAE;QACR,SAAS,EAAE,OAAO,CAAC,EAAE,CAAC,OAAO,CAAC,WAAW,KAAK,OAAO;YACjD,CAAC,OAAO,CAAC,KAAK,IAAI,IAAI,IAAI,OAAO,CAAC,KAAK,CAAC,eAAe,CAAC,KAAK,CAAC,CAAC;KAEpE,EACD,GAAG,EAAE;QACH,KAAK,EAAE,CAAC;IACV,CAAC,CAAC,CAAC;AACT,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,qBAAqB;AACjC,kCAAkC;AAClC,QAAmB,EAAE,aAAqB;IAC5C,QAAQ,EAAE,CAAC;IACX,MAAM,gBAAgB,GAAG,MAAM,EAAE,CAAC,UAAU,CAAC;IAC7C,QAAQ,EAAE,CAAC;IACX,MAAM,eAAe,GAAG,MAAM,EAAE,CAAC,UAAU,CAAC;IAC5C,MAAM,gBAAgB,GAAG,eAAe,GAAG,gBAAgB,CAAC;IAC5D,IAAI,gBAAgB,KAAK,aAAa,EAAE;QACtC,MAAM,IAAI,UAAU,CAChB,sCAAsC;YACtC,uBAAuB,aAAa,eAChC,gBAAgB,IAAI;YACxB,qDAAqD,CAAC,CAAC;KAC5D;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * Testing utilities.\n */\n\nimport {memory, Tensor, test_util, util} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';\n\nimport {ValueError} from '../errors';\n\n/**\n * Expect values are close between a Tensor or number array.\n * @param actual\n * @param expected\n */\nexport function expectTensorsClose(\n    actual: Tensor|number[], expected: Tensor|number[], epsilon?: number) {\n  if (actual == null) {\n    throw new ValueError(\n        'First argument to expectTensorsClose() is not defined.');\n  }\n  if (expected == null) {\n    throw new ValueError(\n        'Second argument to expectTensorsClose() is not defined.');\n  }\n  if (actual instanceof Tensor && expected instanceof Tensor) {\n    if (actual.dtype !== expected.dtype) {\n      throw new Error(\n          `Data types do not match. Actual: '${actual.dtype}'. ` +\n          `Expected: '${expected.dtype}'`);\n    }\n    if (!util.arraysEqual(actual.shape, expected.shape)) {\n      throw new Error(\n          `Shapes do not match. Actual: [${actual.shape}]. ` +\n          `Expected: [${expected.shape}].`);\n    }\n  }\n  const actualData = actual instanceof Tensor ? actual.dataSync() : actual;\n  const expectedData =\n      expected instanceof Tensor ? expected.dataSync() : expected;\n  test_util.expectArraysClose(actualData, expectedData, epsilon);\n}\n\n/**\n * Expect values are not close between a Tensor or number array.\n * @param t1\n * @param t2\n */\nexport function expectTensorsNotClose(\n  t1: Tensor|number[], t2: Tensor|number[], epsilon?: number) {\ntry {\n  expectTensorsClose(t1, t2, epsilon);\n} catch (error) {\n  return;\n}\nthrow new Error('The two values are close at all elements.');\n}\n\n/**\n * Expect values in array are within a specified range, boundaries inclusive.\n * @param actual\n * @param expected\n */\nexport function expectTensorsValuesInRange(\n    actual: Tensor, low: number, high: number) {\n  if (actual == null) {\n    throw new ValueError(\n        'First argument to expectTensorsClose() is not defined.');\n  }\n  test_util.expectValuesInRange(actual.dataSync(), low, high);\n}\n\n/**\n * Describe tests to be run on CPU and GPU.\n * @param testName\n * @param tests\n */\nexport function describeMathCPUAndGPU(testName: string, tests: () => void) {\n  describeWithFlags(testName, ALL_ENVS, () => {\n    tests();\n  });\n}\n\n/**\n * Describe tests to be run on CPU and GPU WebGL2.\n * @param testName\n * @param tests\n */\nexport function describeMathCPUAndWebGL2(testName: string, tests: () => void) {\n  describeWithFlags(\n      testName, {\n        predicate: testEnv =>\n            (testEnv.flags == null || testEnv.flags['WEBGL_VERSION'] === 2)\n      },\n      () => {\n        tests();\n      });\n}\n\n/**\n * Describe tests to be run on CPU only.\n * @param testName\n * @param tests\n */\nexport function describeMathCPU(testName: string, tests: () => void) {\n  describeWithFlags(\n      testName, {predicate: testEnv => testEnv.backendName === 'cpu'}, () => {\n        tests();\n      });\n}\n\n/**\n * Describe tests to be run on GPU only.\n * @param testName\n * @param tests\n */\nexport function describeMathGPU(testName: string, tests: () => void) {\n  describeWithFlags(\n      testName, {predicate: testEnv => testEnv.backendName === 'webgl'}, () => {\n        tests();\n      });\n}\n\n/**\n * Describe tests to be run on WebGL2 GPU only.\n * @param testName\n * @param tests\n */\nexport function describeMathWebGL2(testName: string, tests: () => void) {\n  describeWithFlags(\n      testName, {\n        predicate: testEnv => testEnv.backendName === 'webgl' &&\n            (testEnv.flags == null || testEnv.flags['WEBGL_VERSION'] === 2)\n\n      },\n      () => {\n        tests();\n      });\n}\n\n/**\n * Check that a function only generates the expected number of new Tensors.\n *\n * The test  function is called twice, once to prime any regular constants and\n * once to ensure that additional copies aren't created/tensors aren't leaked.\n *\n * @param testFunc A fully curried (zero arg) version of the function to test.\n * @param numNewTensors The expected number of new Tensors that should exist.\n */\nexport function expectNoLeakedTensors(\n    // tslint:disable-next-line:no-any\n    testFunc: () => any, numNewTensors: number) {\n  testFunc();\n  const numTensorsBefore = memory().numTensors;\n  testFunc();\n  const numTensorsAfter = memory().numTensors;\n  const actualNewTensors = numTensorsAfter - numTensorsBefore;\n  if (actualNewTensors !== numNewTensors) {\n    throw new ValueError(\n        `Created an unexpected number of new ` +\n        `Tensors.  Expected: ${numNewTensors}, created : ${\n            actualNewTensors}. ` +\n        `Please investigate the discrepency and/or use tidy.`);\n  }\n}\n"]}