@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
156 lines (138 loc) • 5.17 kB
text/typescript
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE} from './engine';
import {inferShape} from './tensor_util_env';
import {RecursiveArray, TensorLike, TypedArray} from './types';
import {arraysEqual, flatten, isString, isTypedArray} from './util';
/**
 * Default absolute tolerance returned by `testEpsilon()` when the active
 * backend reports 32-bit float precision.
 */
const TEST_EPSILON_FLOAT32 = 0.001;
/**
 * Default absolute tolerance returned by `testEpsilon()` for any other
 * float precision. It is deliberately looser than the 32-bit tolerance.
 */
export const TEST_EPSILON_FLOAT16 = 0.1;
/**
 * Asserts that `actual` and `expected` contain the same values element-wise,
 * up to an absolute tolerance.
 *
 * @param actual Values produced by the code under test.
 * @param expected Reference values.
 * @param epsilon Maximum allowed absolute difference per element. When
 *     omitted (or null), `testEpsilon()` is used instead, and its value
 *     depends on the active backend's float precision.
 * @throws Error if the inputs differ in container type, shape, length, or
 *     any element.
 */
export function expectArraysClose(
    actual: TypedArray|number|RecursiveArray<number>,
    expected: TypedArray|number|RecursiveArray<number>, epsilon?: number) {
  const tolerance = epsilon == null ? testEpsilon() : epsilon;
  return expectArraysPredicate(
      actual, expected,
      (a, b) => areClose(a as number, b as number, tolerance));
}
/**
 * Returns the default comparison tolerance for the active backend.
 *
 * Backends that report 32-bit float precision get the tighter
 * `TEST_EPSILON_FLOAT32`. All others get `TEST_EPSILON_FLOAT16`.
 */
export function testEpsilon() {
  if (ENGINE.backend.floatPrecision() === 32) {
    return TEST_EPSILON_FLOAT32;
  }
  return TEST_EPSILON_FLOAT16;
}
/**
 * Core assertion shared by the `expectArrays*` helpers. It compares `actual`
 * against `expected` in four stages and throws on the first mismatch:
 *
 * 1. Container type. Constructor names must agree. This stage is skipped
 *    when exactly one side is a typed array, so a typed array can be compared
 *    against a plain (possibly nested) array or a number.
 * 2. Shape. When both sides are plain arrays, their inferred shapes must be
 *    equal.
 * 3. Length. The flattened element counts must be equal.
 * 4. Values. `predicate(actualElement, expectedElement)` must return true at
 *    every index.
 *
 * @param actual Values produced by the code under test.
 * @param expected Reference values.
 * @param predicate Element-wise comparison. Returning false fails the check.
 * @throws Error describing the first stage that fails.
 */
function expectArraysPredicate(
    actual: TensorLike, expected: TensorLike,
    predicate: (a: {}, b: {}) => boolean) {
  // Class names are compared when both or neither side is a typed array.
  // When exactly one side is typed, the class check is skipped.
  if (isTypedArray(actual) === isTypedArray(expected)) {
    const actualType = actual.constructor.name;
    const expectedType = expected.constructor.name;
    if (actualType !== expectedType) {
      throw new Error(
          `Arrays are of different type. Actual: ${actualType}. ` +
          `Expected: ${expectedType}`);
    }
  }

  if (Array.isArray(actual) && Array.isArray(expected)) {
    const shapeOfActual = inferShape(actual);
    const shapeOfExpected = inferShape(expected);
    if (!arraysEqual(shapeOfActual, shapeOfExpected)) {
      throw new Error(
          `Arrays have different shapes. ` +
          `Actual: [${shapeOfActual}]. Expected: [${shapeOfExpected}]`);
    }
  }

  const actualFlat =
      isTypedArray(actual) ? actual : flatten(actual as RecursiveArray<number>);
  const expectedFlat = isTypedArray(expected) ?
      expected :
      flatten(expected as RecursiveArray<number>);
  // Full dump of both sides, appended to every length/value mismatch message.
  const dumpBoth = () =>
      `Actual: ${actualFlat}.\n` +
      `Expected: ${expectedFlat}.`;

  const actualLength = actualFlat.length;
  const expectedLength = expectedFlat.length;
  if (actualLength !== expectedLength) {
    throw new Error(
        `Arrays have different lengths actual: ${actualLength} vs ` +
        `expected: ${expectedLength}.\n` + dumpBoth());
  }

  for (let idx = 0; idx < expectedLength; idx++) {
    const actualValue = actualFlat[idx];
    const expectedValue = expectedFlat[idx];
    if (!predicate(actualValue, expectedValue)) {
      throw new Error(
          `Arrays differ: actual[${idx}] = ${actualValue}, ` +
          `expected[${idx}] = ${expectedValue}.\n` + dumpBoth());
    }
  }
}
/**
 * Completion callback accepted by `expectPromiseToFail`.
 *
 * Calling it directly signals that the asynchronous check passed. Calling
 * `fail` reports a failure, optionally with a message or an Error.
 * NOTE(review): the shape appears to mirror Jasmine's async `done`
 * callback. TODO confirm.
 */
export interface DoneFn {
// Signals success.
(): void;
// Signals failure, with an optional description.
fail: (message?: Error|string) => void;
}
/**
 * Asserts that the promise returned by `fn` rejects.
 *
 * A rejection is the expected outcome and calls `done()`. A resolution is a
 * failure and calls `done.fail()`.
 *
 * @param fn Produces the promise under test.
 * @param done Completion callback notified with the outcome.
 */
export function expectPromiseToFail(fn: () => Promise<{}>, done: DoneFn): void {
  const onResolved = () => done.fail();
  const onRejected = () => done();
  fn().then(onResolved, onRejected);
}
/**
 * Asserts that `actual` and `expected` hold exactly the same values.
 *
 * The comparison mode depends on whether strings are involved:
 * - If either side, or either side's first element, is a string, elements
 *   are compared with loose equality (`==`). A scalar string, number, or
 *   boolean `expected` is first wrapped in a one-element array.
 * - Otherwise elements must satisfy `areClose` with zero tolerance. Here
 *   `expected` is passed through unchanged, without scalar wrapping.
 *
 * @throws Error if the inputs differ in container type, shape, length, or
 *     any element.
 */
export function expectArraysEqual(actual: TensorLike, expected: TensorLike) {
  const hasStrings = isString(actual) || isString((actual as string[])[0]) ||
      isString(expected) || isString((expected as string[])[0]);
  if (!hasStrings) {
    return expectArraysPredicate(
        actual, expected, (a, b) => areClose(a as number, b as number, 0));
  }
  // Wrap a scalar `expected` so it can be compared with a one-element
  // array `actual`.
  const exp = typeof expected === 'string' || typeof expected === 'number' ||
          typeof expected === 'boolean' ?
      [expected] as number[] :
      expected as number[];
  // tslint:disable-next-line: triple-equals
  return expectArraysPredicate(actual, exp, (a, b) => a == b);
}
/**
 * Asserts that two numbers are within an absolute tolerance of each other.
 *
 * @param a Actual value.
 * @param e Expected value.
 * @param epsilon Maximum allowed absolute difference. When omitted (or
 *     null), `testEpsilon()` is used instead.
 * @throws Error if `areClose` reports the values as different.
 */
export function expectNumbersClose(a: number, e: number, epsilon?: number) {
  const tolerance = epsilon == null ? testEpsilon() : epsilon;
  if (areClose(a, e, tolerance)) {
    return;
  }
  throw new Error(`Numbers differ: actual === ${a}, expected === ${e}`);
}
/**
 * Returns whether `a` and `e` are within `epsilon` of each other.
 *
 * Non-finite values are handled explicitly:
 * - two NaNs are considered close;
 * - an infinity is close only to an infinity of the same sign;
 * - a non-finite value is never close to a finite one.
 *
 * The global `isFinite`/`isNaN` are used on purpose. They coerce their
 * argument, which matches how non-number elements were treated before.
 *
 * @param a Actual value.
 * @param e Expected value.
 * @param epsilon Maximum allowed absolute difference (inclusive).
 */
function areClose(a: number, e: number, epsilon: number): boolean {
  if (!isFinite(a) || !isFinite(e)) {
    // Previously any two non-finite values compared as close. As a result,
    // Infinity matched -Infinity and NaN matched Infinity.
    return a === e || (isNaN(a) && isNaN(e));
  }
  return Math.abs(a - e) <= epsilon;
}
/**
 * Asserts that every element of `actual` lies in the closed range
 * [`low`, `high`].
 *
 * @param actual Values to check.
 * @param low Inclusive lower bound.
 * @param high Inclusive upper bound.
 * @throws Error naming the first out-of-range value. NaN counts as out of
 *     range.
 */
export function expectValuesInRange(
    actual: TypedArray|number[], low: number, high: number) {
  for (let i = 0; i < actual.length; i++) {
    const value = actual[i];
    // The check is written as a negated in-range test. NaN fails every
    // comparison, so it is reported here instead of silently passing (as it
    // did with `value < low || value > high`).
    if (!(value >= low && value <= high)) {
      throw new Error(
          `Value out of range:${value} low: ${low}, high: ${high}`);
    }
  }
}
/**
 * Asserts that two ArrayBuffers have identical contents, byte for byte.
 *
 * Safari and Jasmine don't like comparing ArrayBuffers directly, so both
 * buffers are wrapped in typed-array views first. Uint8Array views are used
 * for two reasons:
 * - they accept any byte length, whereas a Float32Array view throws a
 *   RangeError unless the length is a multiple of 4;
 * - they compare raw bytes, so the result does not depend on how the bytes
 *   decode as float32 values.
 *
 * @param actual Buffer produced by the code under test.
 * @param expected Reference buffer.
 */
export function expectArrayBuffersEqual(
    actual: ArrayBuffer, expected: ArrayBuffer) {
  expect(new Uint8Array(actual)).toEqual(new Uint8Array(expected));
}