UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

156 lines (138 loc) 5.17 kB
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {ENGINE} from './engine';
import {inferShape} from './tensor_util_env';
import {RecursiveArray, TensorLike, TypedArray} from './types';
import {arraysEqual, flatten, isString, isTypedArray} from './util';

/** Default comparison tolerance when the backend reports 32-bit floats. */
const TEST_EPSILON_FLOAT32 = 1e-3;
/**
 * Default comparison tolerance when the backend reports any float precision
 * other than 32 bits.
 */
export const TEST_EPSILON_FLOAT16 = 1e-1;

/**
 * Asserts that two (possibly nested) numeric arrays are element-wise close.
 *
 * @param actual The computed values.
 * @param expected The reference values.
 * @param epsilon Maximum allowed absolute difference per element. Defaults to
 *     `testEpsilon()` for the active backend.
 * @throws Error if the container types, shapes, or lengths differ, or if any
 *     element pair is not within `epsilon` of each other.
 */
export function expectArraysClose(
    actual: TypedArray|number|RecursiveArray<number>,
    expected: TypedArray|number|RecursiveArray<number>, epsilon?: number) {
  if (epsilon == null) {
    epsilon = testEpsilon();
  }
  return expectArraysPredicate(
      actual, expected, (a, b) => areClose(a as number, b as number, epsilon));
}

/**
 * Returns the default comparison tolerance for the active backend:
 * `TEST_EPSILON_FLOAT32` when `ENGINE.backend.floatPrecision()` is 32,
 * otherwise `TEST_EPSILON_FLOAT16`.
 */
export function testEpsilon() {
  return ENGINE.backend.floatPrecision() === 32 ? TEST_EPSILON_FLOAT32 :
                                                  TEST_EPSILON_FLOAT16;
}

/**
 * Shared driver for the array assertions. It checks, in order:
 *  1. container class (skipped when exactly one side is a typed array),
 *  2. nested shape (only when both sides are plain JS arrays),
 *  3. flattened length,
 *  4. `predicate` on every flattened element pair.
 *
 * @throws Error describing the first mismatch found.
 */
function expectArraysPredicate(
    actual: TensorLike, expected: TensorLike,
    predicate: (a: {}, b: {}) => boolean) {
  // Comparing a typed array against a plain array is allowed; any other
  // combination must use the same container class.
  let checkClassType = true;
  if (isTypedArray(actual) || isTypedArray(expected)) {
    checkClassType = false;
  }
  if (isTypedArray(actual) && isTypedArray(expected)) {
    checkClassType = true;
  }
  if (checkClassType) {
    const aType = actual.constructor.name;
    const bType = expected.constructor.name;

    if (aType !== bType) {
      throw new Error(
          `Arrays are of different type. Actual: ${aType}. ` +
          `Expected: ${bType}`);
    }
  }

  if (Array.isArray(actual) && Array.isArray(expected)) {
    const actualShape = inferShape(actual);
    const expectedShape = inferShape(expected);
    if (!arraysEqual(actualShape, expectedShape)) {
      throw new Error(
          `Arrays have different shapes. ` +
          `Actual: [${actualShape}]. Expected: [${expectedShape}]`);
    }
  }

  const actualFlat =
      isTypedArray(actual) ? actual : flatten(actual as RecursiveArray<number>);
  const expectedFlat = isTypedArray(expected) ?
      expected :
      flatten(expected as RecursiveArray<number>);

  if (actualFlat.length !== expectedFlat.length) {
    throw new Error(
        `Arrays have different lengths actual: ${actualFlat.length} vs ` +
        `expected: ${expectedFlat.length}.\n` +
        `Actual:   ${actualFlat}.\n` +
        `Expected: ${expectedFlat}.`);
  }
  for (let i = 0; i < expectedFlat.length; ++i) {
    const a = actualFlat[i];
    const e = expectedFlat[i];

    if (!predicate(a, e)) {
      throw new Error(
          `Arrays differ: actual[${i}] = ${a}, expected[${i}] = ${e}.\n` +
          `Actual:   ${actualFlat}.\n` +
          `Expected: ${expectedFlat}.`);
    }
  }
}

/** Minimal shape of a Jasmine-style async `done` callback. */
export interface DoneFn {
  (): void;
  fail: (message?: Error|string) => void;
}

/**
 * Calls `done()` if the promise returned by `fn` rejects, and `done.fail()`
 * if it resolves.
 */
export function expectPromiseToFail(fn: () => Promise<{}>, done: DoneFn): void {
  fn().then(() => done.fail(), () => done());
}

/**
 * Asserts that two arrays (or scalars) are exactly equal.
 *
 * If either side is a string, or an array whose first element is a string,
 * elements are compared with loose equality (`==`), and a scalar `expected`
 * is first wrapped in a one-element array. Otherwise elements are compared
 * numerically with zero tolerance.
 *
 * @throws Error describing the first mismatch found.
 */
export function expectArraysEqual(actual: TensorLike, expected: TensorLike) {
  const exp = typeof expected === 'string' || typeof expected === 'number' ||
          typeof expected === 'boolean' ?
      [expected] as number[] :
      expected as number[];
  if (isString(actual) || isString((actual as string[])[0]) ||
      isString(expected) || isString((expected as string[])[0])) {
    // tslint:disable-next-line: triple-equals
    return expectArraysPredicate(actual, exp, (a, b) => a == b);
  }
  // NOTE(review): unlike the string path, this passes the unwrapped
  // `expected`. A scalar `expected` therefore passes the class-type check in
  // expectArraysPredicate only when `actual` is a typed array or a scalar of
  // the same primitive type.
  return expectArraysPredicate(
      actual, expected, (a, b) => areClose(a as number, b as number, 0));
}

/**
 * Asserts that two numbers are within `epsilon` of each other.
 *
 * @param epsilon Defaults to `testEpsilon()` for the active backend.
 * @throws Error if the numbers are not close.
 */
export function expectNumbersClose(a: number, e: number, epsilon?: number) {
  if (epsilon == null) {
    epsilon = testEpsilon();
  }
  if (!areClose(a, e, epsilon)) {
    throw new Error(`Numbers differ: actual === ${a}, expected === ${e}`);
  }
}

/**
 * Returns whether `a` and `e` are equal within `epsilon`.
 *
 * Non-finite values match only themselves: NaN matches NaN, and an infinity
 * matches only an infinity of the same sign. A non-finite value never matches
 * a finite one.
 */
function areClose(a: number, e: number, epsilon: number): boolean {
  if (!isFinite(a) && !isFinite(e)) {
    // Both are NaN or ±Infinity. Previously any two non-finite values were
    // treated as close, which made +Infinity equal -Infinity and NaN equal
    // Infinity.
    return (isNaN(a) && isNaN(e)) || a === e;
  }
  if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {
    return false;
  }
  return true;
}

/**
 * Asserts that every element of `actual` lies in the closed range
 * `[low, high]`.
 *
 * @throws Error for the first value outside the range.
 */
export function expectValuesInRange(
    actual: TypedArray|number[], low: number, high: number) {
  for (let i = 0; i < actual.length; i++) {
    if (actual[i] < low || actual[i] > high) {
      throw new Error(
          `Value out of range:${actual[i]} low: ${low}, high: ${high}`);
    }
  }
}

/**
 * Asserts, via Jasmine's `expect(...).toEqual(...)`, that two ArrayBuffers
 * hold the same bytes.
 */
export function expectArrayBuffersEqual(
    actual: ArrayBuffer, expected: ArrayBuffer) {
  // Safari & Jasmine don't like comparing ArrayBuffers directly, so compare
  // typed-array views instead. A byte view is used rather than a Float32Array
  // for two reasons: the Float32Array constructor throws a RangeError for
  // buffers whose byteLength is not a multiple of 4, and a byte view compares
  // the buffer contents exactly.
  expect(new Uint8Array(actual)).toEqual(new Uint8Array(expected));
}