@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
105 lines • 3.73 kB
JavaScript
;
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
var util = require("../util");
/**
 * Checks whether `axes` lists exactly the trailing (inner-most) dimensions of
 * a tensor of the given rank, in ascending order.
 *
 * For example, with rank 4 the axes [2, 3] and [3] qualify, while [1, 3] and
 * [3, 2] do not. An empty `axes` list trivially qualifies.
 *
 * @param {number[]} axes Axis indices to check.
 * @param {number} rank Rank of the tensor the axes refer to.
 * @returns {boolean} True when axes[j] === rank - axes.length + j for every j.
 */
function axesAreInnerMostDims(axes, rank) {
    // The first listed axis must sit exactly axes.length positions from the end.
    var firstInnerAxis = rank - axes.length;
    for (var j = 0; j < axes.length; j++) {
        if (axes[j] !== firstInnerAxis + j) {
            return false;
        }
    }
    return true;
}
exports.axesAreInnerMostDims = axesAreInnerMostDims;
/**
 * Interleaves an output location and a reduction location into one
 * full-rank location.
 *
 * The result has outputLoc.length + reduceLoc.length entries. Every position
 * listed in `axes` takes the next unused entry of `reduceLoc`. Every other
 * position takes the next unused entry of `outputLoc`.
 *
 * @param {number[]} outputLoc Values for the non-reduced positions, in order.
 * @param {number[]} reduceLoc Values for the reduced positions, in order.
 * @param {number[]} axes Positions that are filled from `reduceLoc`.
 * @returns {number[]} The merged location.
 */
function combineLocations(outputLoc, reduceLoc, axes) {
    var totalRank = outputLoc.length + reduceLoc.length;
    var merged = [];
    var nextOut = 0;
    var nextReduce = 0;
    for (var d = 0; d < totalRank; ++d) {
        var isReduced = axes.indexOf(d) !== -1;
        merged.push(isReduced ? reduceLoc[nextReduce++] : outputLoc[nextOut++]);
    }
    return merged;
}
exports.combineLocations = combineLocations;
/**
 * Splits a shape into the part that survives a reduction and the part that
 * is reduced away.
 *
 * @param {number[]} aShape Shape of the input.
 * @param {number[]} axes Axes being reduced.
 * @returns {Array<number[]>} A pair [outShape, reduceShape]. `outShape` holds
 *     the sizes of the dimensions not in `axes`, in ascending dimension order.
 *     `reduceShape` holds the sizes at each entry of `axes`, in the order the
 *     axes were given.
 */
function computeOutAndReduceShapes(aShape, axes) {
    var keptDims = [];
    for (var d = 0; d < aShape.length; ++d) {
        if (axes.indexOf(d) !== -1) {
            continue;
        }
        keptDims.push(aShape[d]);
    }
    var reducedDims = axes.map(function (axis) {
        return aShape[axis];
    });
    return [keptDims, reducedDims];
}
exports.computeOutAndReduceShapes = computeOutAndReduceShapes;
/**
 * Re-inserts size-1 dimensions at the reduced positions of a shape, for
 * "keep dims" style results.
 *
 * @param {number[]} shape Shape with the reduced axes removed.
 * @param {number[]} axes Positions where a dimension of size 1 is inserted.
 * @returns {number[]} The expanded shape, of length shape.length + axes.length.
 */
function expandShapeToKeepDim(shape, axes) {
    // One size-1 entry per reduced axis, merged back in at the `axes` positions.
    var ones = axes.map(function () {
        return 1;
    });
    return combineLocations(shape, ones, axes);
}
exports.expandShapeToKeepDim = expandShapeToKeepDim;
/**
 * Asserts that `axes` are the inner-most dimensions for the given rank.
 *
 * The error message is built lazily through a callback passed to
 * `util.assert`. NOTE(review): util.assert is presumed to throw when its
 * condition is false; confirm against ../util.
 *
 * @param {string} msg Name of the caller, used as the message prefix.
 * @param {number[]} axes Axes to check.
 * @param {number} rank Rank of the input tensor.
 */
function assertAxesAreInnerMostDims(msg, axes, rank) {
    var innerMost = axesAreInnerMostDims(axes, rank);
    util.assert(innerMost, function () {
        var detail = "Got axes " + axes + " and rank-" + rank + " input.";
        return msg + " supports only inner-most axes for now. " + detail;
    });
}
exports.assertAxesAreInnerMostDims = assertAxesAreInnerMostDims;
/**
 * Computes the `tf.transpose` permutation that moves `axes` to the inner-most
 * positions. It is used by operations that can only act on inner-most axes.
 *
 * The permutation lists every dimension not in `axes` in ascending order,
 * followed by the entries of `axes` in the order they were given. If `axes`
 * are already the inner-most dimensions (per axesAreInnerMostDims), no
 * transpose is needed and null is returned.
 *
 * @param {number[]} axes Axes the operation must act on.
 * @param {number} rank Rank of the input tensor.
 * @returns {number[]|null} The permutation, or null if none is required.
 */
function getAxesPermutation(axes, rank) {
    var alreadyInnerMost = axesAreInnerMostDims(axes, rank);
    if (alreadyInnerMost) {
        return null;
    }
    var perm = [];
    // Untouched dimensions come first, keeping their relative order...
    var dim = 0;
    while (dim < rank) {
        if (axes.indexOf(dim) < 0) {
            perm.push(dim);
        }
        dim++;
    }
    // ...then the requested axes are appended so they become the trailing dims.
    axes.forEach(function (axis) {
        perm.push(axis);
    });
    return perm;
}
exports.getAxesPermutation = getAxesPermutation;
/**
 * Computes the permutation that undoes `axes`.
 *
 * The entries of `axes` are ordered by value, and the original index of each
 * entry is read off in that order. For a valid permutation this is its
 * inverse: applying `axes` and then the result restores the original order.
 *
 * @param {number[]} axes A permutation of dimension indices.
 * @returns {number[]} The inverse permutation.
 */
function getUndoAxesPermutation(axes) {
    var pairs = axes.map(function (target, source) {
        return { source: source, target: target };
    });
    pairs.sort(function (p, q) {
        return p.target - q.target;
    });
    return pairs.map(function (p) {
        return p.source;
    });
}
exports.getUndoAxesPermutation = getUndoAxesPermutation;
/**
 * Lists the last `numAxes` dimension indices of a tensor of the given rank.
 *
 * @param {number} numAxes How many trailing axes to return.
 * @param {number} rank Rank of the tensor.
 * @returns {number[]} [rank - numAxes, ..., rank - 1]. The result is empty
 *     when numAxes <= 0.
 */
function getInnerMostAxes(numAxes, rank) {
    var axes = [];
    var axis = rank - numAxes;
    while (axis < rank) {
        axes.push(axis);
        axis += 1;
    }
    return axes;
}
exports.getInnerMostAxes = getInnerMostAxes;
//# sourceMappingURL=axis_util.js.map