@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
72 lines • 10.1 kB
JavaScript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { ENGINE } from '../engine';
import { Tile } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import { assertNonNegativeIntegerDimensions } from '../util_base';
import { clone } from './clone';
import { op } from './operation';
import { reshape } from './reshape';
/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If x.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If x.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile). N may be
 * 0, in which case the result has a zero-sized axis.
 *
 * @param x The tensor that is to be broadcasted.
 * @param shape The input is to be broadcast to this shape.
 * @returns A clone of the (rank-padded) input when no axis needs tiling,
 *     otherwise the output of the `Tile` kernel.
 * @throws {Error} If `shape` has fewer dimensions than `x`, or if some axis
 *     of `x` is neither 1 nor equal to the matching entry of `shape`.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastTo_(x, shape) {
    let input = convertToTensor(x, 'broadcastTo', 'x');
    // Keep the caller's original shape for error messages; `input` may be
    // reshaped below.
    const xShape = input.shape;
    assertNonNegativeIntegerDimensions(shape);
    if (shape.length < input.rank) {
        throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
    }
    if (shape.length > input.rank) {
        // Left-pad the input shape with ones so both shapes have equal rank.
        const newShape = input.shape.slice();
        while (newShape.length < shape.length) {
            newShape.unshift(1);
        }
        input = reshape(input, newShape);
    }
    const inputShape = input.shape;
    // reps[i] is the tiling factor for axis i: 1 where the axis already
    // matches, otherwise the requested size (the input axis must then be 1).
    const reps = Array.from(shape);
    for (let i = shape.length - 1; i >= 0; i--) {
        if (inputShape[i] === shape[i]) {
            reps[i] = 1;
        }
        else if (inputShape[i] !== 1) {
            throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
        }
    }
    // Only skip tiling when every factor is exactly 1. A factor of 0 (size-1
    // axis broadcast to a zero-sized axis) must still go through Tile so the
    // result has the requested shape rather than the input's.
    if (reps.every((n) => n === 1)) {
        return clone(input);
    }
    // TODO call broadcastTo kernel directly once backends implement broadcastTo
    const inputs = { x: input };
    const attrs = { reps };
    return ENGINE.runKernel(Tile, inputs, attrs);
}
// Public `broadcastTo` export: the result of passing `broadcastTo_` to the
// `op` helper imported from './operation'. The `@__PURE__` annotation marks
// the call as side-effect free so bundlers may drop it when the export is unused.
export const broadcastTo = /* @__PURE__ */ op({ broadcastTo_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYnJvYWRjYXN0X3RvLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYnJvYWRjYXN0X3RvLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLElBQUksRUFBd0IsTUFBTSxpQkFBaUIsQ0FBQztBQUk1RCxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFFbkQsT0FBTyxFQUFDLGtDQUFrQyxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBRWhFLE9BQU8sRUFBQyxLQUFLLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFDOUIsT0FBTyxFQUFDLEVBQUUsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUMvQixPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRWxDOzs7Ozs7Ozs7Ozs7O0dBYUc7QUFDSCxTQUFTLFlBQVksQ0FDakIsQ0FBb0IsRUFBRSxLQUFrQjtJQUMxQyxJQUFJLEtBQUssR0FBRyxlQUFlLENBQUMsQ0FBQyxFQUFFLGFBQWEsRUFBRSxHQUFHLENBQUMsQ0FBQztJQUNuRCxNQUFNLE1BQU0sR0FBRyxLQUFLLENBQUMsS0FBSyxDQUFDO0lBRTNCLGtDQUFrQyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBRTFDLElBQUksS0FBSyxDQUFDLE1BQU0sR0FBRyxLQUFLLENBQUMsSUFBSSxFQUFFO1FBQzdCLE1BQU0sSUFBSSxLQUFLLENBQUMsK0JBQStCLEtBQUssQ0FBQyxNQUFNLGlCQUN2RCxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsQ0FBQztLQUNwQjtJQUVELElBQUksS0FBSyxDQUFDLE1BQU0sR0FBRyxLQUFLLENBQUMsSUFBSSxFQUFFO1FBQzdCLE1BQU0sUUFBUSxHQUFHLEtBQUssQ0FBQyxLQUFLLENBQUMsS0FBSyxFQUFFLENBQUM7UUFDckMsT0FBTyxRQUFRLENBQUMsTUFBTSxHQUFHLEtBQUssQ0FBQyxNQUFNLEVBQUU7WUFDckMsUUFBUSxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsQ0FBQztTQUNyQjtRQUNELEtBQUssR0FBRyxPQUFPLENBQUMsS0FBSyxFQUFFLFFBQVEsQ0FBQyxDQUFDO0tBQ2xDO0lBRUQsTUFBTSxVQUFVLEdBQUcsS0FBSyxDQUFDLEtBQUssQ0FBQztJQUMvQixNQUFNLElBQUksR0FBYSxLQUFLLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBQ3pDLEtBQUssSUFBSSxDQUFDLEdBQUcsS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLEVBQUUsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRTtRQUMxQyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUU7WUFDOUIsSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQztTQUNiO2FBQU0sSUFBSSxLQUFLLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsRUFBRTtZQUMvQixNQUFNLElBQUksS0FBSyxDQUNYLG1CQUFtQixNQUFNLDZCQUE2QixLQUFLLE
lBQUksQ0FBQyxDQUFDO1NBQ3RFO0tBQ0Y7SUFDRCxNQUFNLElBQUksR0FBRyxJQUFJLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQztJQUVwRSxJQUFJLElBQUksQ0FBQyxNQUFNLEtBQUssQ0FBQyxFQUFFO1FBQ3JCLE9BQU8sS0FBSyxDQUFDLEtBQUssQ0FBYyxDQUFDO0tBQ2xDO0lBRUQsMkVBQTJFO0lBQzNFLE1BQU0sTUFBTSxHQUFlLEVBQUMsQ0FBQyxFQUFFLEtBQUssRUFBQyxDQUFDO0lBQ3RDLE1BQU0sS0FBSyxHQUFjLEVBQUMsSUFBSSxFQUFDLENBQUM7SUFDaEMsT0FBTyxNQUFNLENBQUMsU0FBUyxDQUNuQixJQUFJLEVBQUUsTUFBbUMsRUFDekMsS0FBZ0MsQ0FBQyxDQUFDO0FBQ3hDLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxXQUFXLEdBQUcsZUFBZSxDQUFDLEVBQUUsQ0FBQyxFQUFDLFlBQVksRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RU5HSU5FfSBmcm9tICcuLi9lbmdpbmUnO1xuaW1wb3J0IHtUaWxlLCBUaWxlQXR0cnMsIFRpbGVJbnB1dHN9IGZyb20gJy4uL2tlcm5lbF9uYW1lcyc7XG5pbXBvcnQge05hbWVkQXR0ck1hcH0gZnJvbSAnLi4va2VybmVsX3JlZ2lzdHJ5JztcbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuaW1wb3
J0IHtOYW1lZFRlbnNvck1hcH0gZnJvbSAnLi4vdGVuc29yX3R5cGVzJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtSYW5rLCBTaGFwZU1hcCwgVGVuc29yTGlrZX0gZnJvbSAnLi4vdHlwZXMnO1xuaW1wb3J0IHthc3NlcnROb25OZWdhdGl2ZUludGVnZXJEaW1lbnNpb25zfSBmcm9tICcuLi91dGlsX2Jhc2UnO1xuXG5pbXBvcnQge2Nsb25lfSBmcm9tICcuL2Nsb25lJztcbmltcG9ydCB7b3B9IGZyb20gJy4vb3BlcmF0aW9uJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9yZXNoYXBlJztcblxuLyoqXG4gKiBCcm9hZGNhc3QgYW4gYXJyYXkgdG8gYSBjb21wYXRpYmxlIHNoYXBlIE51bVB5LXN0eWxlLlxuICpcbiAqIFRoZSB0ZW5zb3IncyBzaGFwZSBpcyBjb21wYXJlZCB0byB0aGUgYnJvYWRjYXN0IHNoYXBlIGZyb20gZW5kIHRvIGJlZ2lubmluZy5cbiAqIE9uZXMgYXJlIHByZXBlbmRlZCB0byB0aGUgdGVuc29yJ3Mgc2hhcGUgdW50aWwgaXQgaGFzIHRoZSBzYW1lIGxlbmd0aCBhc1xuICogdGhlIGJyb2FkY2FzdCBzaGFwZS4gSWYgaW5wdXQuc2hhcGVbaV09PXNoYXBlW2ldLCB0aGUgKGkrMSktdGggYXhpcyBpc1xuICogYWxyZWFkeSBicm9hZGNhc3QtY29tcGF0aWJsZS4gSWYgaW5wdXQuc2hhcGVbaV09PTEgYW5kIHNoYXBlW2ldPT1OLCB0aGVuXG4gKiB0aGUgaW5wdXQgdGVuc29yIGlzIHRpbGVkIE4gdGltZXMgYWxvbmcgdGhhdCBheGlzICh1c2luZyB0Zi50aWxlKS5cbiAqXG4gKiBAcGFyYW0gaW5wdXQgVGhlIHRlbnNvciB0aGF0IGlzIHRvIGJlIGJyb2FkY2FzdGVkLlxuICogQHBhcmFtIHNoYXBlIFRoZSBpbnB1dCBpcyB0byBiZSBicm9hZGNhc3QgdG8gdGhpcyBzaGFwZS5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnVGVuc29ycycsIHN1YmhlYWRpbmc6ICdUcmFuc2Zvcm1hdGlvbnMnfVxuICovXG5mdW5jdGlvbiBicm9hZGNhc3RUb188UiBleHRlbmRzIFJhbms+KFxuICAgIHg6IFRlbnNvcnxUZW5zb3JMaWtlLCBzaGFwZTogU2hhcGVNYXBbUl0pOiBUZW5zb3I8Uj4ge1xuICBsZXQgaW5wdXQgPSBjb252ZXJ0VG9UZW5zb3IoeCwgJ2Jyb2FkY2FzdFRvJywgJ3gnKTtcbiAgY29uc3QgeFNoYXBlID0gaW5wdXQuc2hhcGU7XG5cbiAgYXNzZXJ0Tm9uTmVnYXRpdmVJbnRlZ2VyRGltZW5zaW9ucyhzaGFwZSk7XG5cbiAgaWYgKHNoYXBlLmxlbmd0aCA8IGlucHV0LnJhbmspIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYGJyb2FkY2FzdFRvKCk6IHNoYXBlLmxlbmd0aD0ke3NoYXBlLmxlbmd0aH0gPCBpbnB1dC5yYW5rPSR7XG4gICAgICAgIGlucHV0LnJhbmt9LmApO1xuICB9XG5cbiAgaWYgKHNoYXBlLmxlbmd0aCA+IGlucHV0LnJhbmspIHtcbiAgICBjb25zdCBuZXdTaGFwZSA9IGlucHV0LnNoYXBlLnNsaWNlKCk7XG4gICAgd2hpbGUgKG5ld1NoYXBlLmxlbmd0aCA8IHNoYXBlLmxlbmd0aCkge1xuICAgICAgbmV3U2hhcGUudW5zaGlmdCgxKTtcbi
AgICB9XG4gICAgaW5wdXQgPSByZXNoYXBlKGlucHV0LCBuZXdTaGFwZSk7XG4gIH1cblxuICBjb25zdCBpbnB1dFNoYXBlID0gaW5wdXQuc2hhcGU7XG4gIGNvbnN0IHJlcHM6IG51bWJlcltdID0gQXJyYXkuZnJvbShzaGFwZSk7XG4gIGZvciAobGV0IGkgPSBzaGFwZS5sZW5ndGggLSAxOyBpID49IDA7IGktLSkge1xuICAgIGlmIChpbnB1dFNoYXBlW2ldID09PSBzaGFwZVtpXSkge1xuICAgICAgcmVwc1tpXSA9IDE7XG4gICAgfSBlbHNlIGlmIChpbnB1dC5zaGFwZVtpXSAhPT0gMSkge1xuICAgICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAgIGBicm9hZGNhc3RUbygpOiBbJHt4U2hhcGV9XSBjYW5ub3QgYmUgYnJvYWRjYXN0IHRvIFske3NoYXBlfV0uYCk7XG4gICAgfVxuICB9XG4gIGNvbnN0IGF4ZXMgPSByZXBzLm1hcCgobiwgaSkgPT4gbiA+IDEgPyBpIDogLTEpLmZpbHRlcihpID0+IGkgPj0gMCk7XG5cbiAgaWYgKGF4ZXMubGVuZ3RoID09PSAwKSB7XG4gICAgcmV0dXJuIGNsb25lKGlucHV0KSBhcyBUZW5zb3I8Uj47XG4gIH1cblxuICAvLyBUT0RPIGNhbGwgYnJvYWRjYXN0VG8ga2VybmVsIGRpcmVjdGx5IG9uY2UgYmFja2VuZHMgaW1wbGVtZW50IGJyb2FkY3N0VG9cbiAgY29uc3QgaW5wdXRzOiBUaWxlSW5wdXRzID0ge3g6IGlucHV0fTtcbiAgY29uc3QgYXR0cnM6IFRpbGVBdHRycyA9IHtyZXBzfTtcbiAgcmV0dXJuIEVOR0lORS5ydW5LZXJuZWwoXG4gICAgICBUaWxlLCBpbnB1dHMgYXMgdW5rbm93biBhcyBOYW1lZFRlbnNvck1hcCxcbiAgICAgIGF0dHJzIGFzIHVua25vd24gYXMgTmFtZWRBdHRyTWFwKTtcbn1cblxuZXhwb3J0IGNvbnN0IGJyb2FkY2FzdFRvID0gLyogQF9fUFVSRV9fICovIG9wKHticm9hZGNhc3RUb199KTtcbiJdfQ==