UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

53 lines 11.8 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Jasmine specs for tf.broadcastTo: forward values, gradients, and rejection of
// a non-integer target shape.
//
// This is compiled output. Its former inline source map recorded
// tfjs-core/src/ops/broadcast_to_test.ts as the source. That map was removed
// because its base64 payload had been wrapped across several physical lines,
// leaving bare base64 text outside any comment (a syntax error). Rebuilding
// from the TypeScript source regenerates it.
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';
// describeWithFlags and ALL_ENVS come from ../jasmine_util. Presumably they run
// this suite once per registered environment/backend (TODO: confirm there).
describeWithFlags('broadcastTo', ALL_ENVS, () => {
    it('[] -> [3,2]', async () => {
        const a = tf.scalar(4.2);
        const A = tf.tensor2d([[4.2, 4.2], [4.2, 4.2], [4.2, 4.2]]);
        expectArraysClose(await A.array(), await tf.broadcastTo(a, A.shape).array());
        // Gradient check: d/da of mean(broadcastTo(a, A.shape) * w) must match
        // d/da of mean(a * w), which relies on mul's implicit broadcasting.
        const w = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]);
        const f = (x) => tf.broadcastTo(x, A.shape).mul(w).mean().asScalar();
        const h = (x) => x.mul(w).mean().asScalar();
        const df = tf.grad(f);
        const dh = tf.grad(h);
        expectArraysClose(await df(a).array(), await dh(a).array());
    });
    it('[2] -> [3,2]', async () => {
        const a = tf.tensor1d([1, 2]);
        const A = tf.tensor2d([[1, 2], [1, 2], [1, 2]]);
        expectArraysClose(await A.array(), await tf.broadcastTo(a, A.shape).array());
        // Gradient check: explicit broadcastTo vs. mul's implicit broadcasting
        // must yield the same gradient with respect to `a`.
        const w = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]);
        const f = (x) => tf.broadcastTo(x, A.shape).mul(w).mean().asScalar();
        const h = (x) => x.mul(w).mean().asScalar();
        const df = tf.grad(f);
        const dh = tf.grad(h);
        expectArraysClose(await df(a).array(), await dh(a).array());
    });
    it('[3,1] -> [3,2]', async () => {
        const a = tf.tensor2d([[1], [2], [3]]);
        const A = tf.tensor2d([[1, 1], [2, 2], [3, 3]]);
        expectArraysClose(await A.array(), await tf.broadcastTo(a, A.shape).array());
        // Gradient check: explicit broadcastTo vs. mul's implicit broadcasting
        // must yield the same gradient with respect to `a`.
        const w = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]);
        const f = (x) => tf.broadcastTo(x, A.shape).mul(w).mean().asScalar();
        const h = (x) => x.mul(w).mean().asScalar();
        const df = tf.grad(f);
        const dh = tf.grad(h);
        expectArraysClose(await df(a).array(), await dh(a).array());
    });
    it('should throw error when shape is not integer', () => {
        const a = tf.scalar(4.2);
        // 2.22 and 3.33 are not valid dimension sizes, so broadcastTo must throw.
        expect(() => tf.broadcastTo(a, [2, 2.22, 3.33])).toThrow();
    });
});