UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

263 lines 10.9 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var environment_1 = require("../environment"); var globals_1 = require("../globals"); var tensor_util_env_1 = require("../tensor_util_env"); var util = require("../util"); var axis_util = require("./axis_util"); var operation_1 = require("./operation"); var tensor_ops_1 = require("./tensor_ops"); function logSumExp_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'logSumExp'); var axes = util.parseAxisParam(axis, $x.shape); var xMax = $x.max(axes, true); var a = $x.sub(xMax); var b = a.exp(); var c = b.sum(axes); var d = c.log(); var res = xMax.reshape(d.shape).add(d); if (keepDims) { var newShape = axis_util.expandShapeToKeepDim(res.shape, axes); return res.reshape(newShape); } return res; } function sum_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'sum'); if ($x.dtype === 'bool') { $x = $x.toInt(); } var axes = util.parseAxisParam(axis, $x.shape); var customOp = globals_1.customGrad(function (x) { var permutation = axis_util.getAxesPermutation(axes, x.rank); var reductionAxes = axes; var permutedX = x; if (permutation != null) { permutedX = x.transpose(permutation); reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, x.rank); } var value = environment_1.ENV.engine.runKernel(function (backend) { return backend.sum(permutedX, reductionAxes); }, { permutedX: permutedX }); if (keepDims) { var newShape = axis_util.expandShapeToKeepDim(value.shape, axes); value = value.reshape(newShape); } var gradFunc = function (dy) { var expandedDyShape = x.shape.slice(); axes.forEach(function (axis) { expandedDyShape[axis] = 1; }); var expandedDy = dy.reshape(expandedDyShape); var derX = expandedDy.mul(tensor_ops_1.ones(x.shape, 'float32')); return derX; }; return { value: value, gradFunc: gradFunc }; }); return customOp($x); } function prod_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'prod'); if ($x.dtype === 'bool') { $x = $x.toInt(); } var axes = util.parseAxisParam(axis, $x.shape); var permutation = axis_util.getAxesPermutation(axes, $x.rank); var reductionAxes = axes; var permutedX = $x; if (permutation != null) { permutedX = $x.transpose(permutation); reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, $x.rank); } var value = environment_1.ENV.engine.runKernel(function (backend) { return backend.prod(permutedX, reductionAxes); }, { permutedX: permutedX }); if (keepDims) { var newShape = axis_util.expandShapeToKeepDim(value.shape, axes); value = value.reshape(newShape); } return value; } function mean_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'mean'); var axes = util.parseAxisParam(axis, $x.shape); var shapes = axis_util.computeOutAndReduceShapes($x.shape, axes); var reduceShape = shapes[1]; var reduceSize = util.sizeFromShape(reduceShape); var customOp = globals_1.customGrad(function (x) { var reduceSizeScalar = tensor_ops_1.scalar(reduceSize); var xReduce = reduceSizeScalar.dtype === x.dtype ? x : x.cast(reduceSizeScalar.dtype); var res = xReduce.div(reduceSizeScalar); var value = res.sum(axis, keepDims); var gradFunc = function (dy) { var expandedDyShape = x.shape.slice(); axes.forEach(function (axis) { expandedDyShape[axis] = 1; }); var expandedDy = dy.reshape(expandedDyShape); var derX = expandedDy.mul(tensor_ops_1.ones(x.shape, 'float32')).div(reduceSizeScalar); return derX; }; return { value: value, gradFunc: gradFunc }; }); return customOp($x); } function gradForMinAndMax(dy, saved, xOrig, origAxes, permutedAxes) { var y = saved[0]; if (y.rank < xOrig.rank) { y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)); } if (dy.rank < xOrig.rank) { dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)); } return { $x: function () { var dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); return permutedAxes == null ? dx : dx.transpose(permutedAxes); } }; } function min_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'min'); var xOrig = $x; var origAxes = util.parseAxisParam(axis, $x.shape); var axes = origAxes; var permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { $x = $x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } var grad = function (dy, saved) { return gradForMinAndMax(dy, saved, xOrig, origAxes, permutedAxes); }; var res = environment_1.ENV.engine.runKernel(function (backend, save) { return save(backend.min($x, axes)); }, { $x: $x }, grad); if (keepDims) { var newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); res = res.reshape(newShape); } return res; } function max_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'max'); var xOrig = $x; var origAxes = util.parseAxisParam(axis, $x.shape); var axes = origAxes; var permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { $x = $x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } var grad = function (dy, saved) { return gradForMinAndMax(dy, saved, xOrig, origAxes, permutedAxes); }; var res = environment_1.ENV.engine.runKernel(function (backend, save) { return save(backend.max($x, axes)); }, { $x: $x }, grad); if (keepDims) { var newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); res = res.reshape(newShape); } return res; } function argMin_(x, axis) { if (axis === void 0) { axis = 0; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'argMin'); if (axis == null) { axis = 0; } var axes = util.parseAxisParam(axis, $x.shape); var permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { $x = $x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } var grad = function (dy) { return { $x: function () { return tensor_ops_1.zerosLike($x); } }; }; return environment_1.ENV.engine.runKernel(function (backend) { return backend.argMin($x, axes[0]); }, { $x: $x }, grad); } function argMax_(x, axis) { if (axis === void 0) { axis = 0; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'argMax'); if (axis == null) { axis = 0; } var axes = util.parseAxisParam(axis, $x.shape); var permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { $x = $x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } var grad = function (dy) { return { $x: function () { return tensor_ops_1.zerosLike($x); } }; }; return environment_1.ENV.engine.runKernel(function (backend) { return backend.argMax($x, axes[0]); }, { $x: $x }, grad); } function all_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'all', 'bool'); var origAxes = util.parseAxisParam(axis, $x.shape); var axes = origAxes; var permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { $x = $x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } var res = environment_1.ENV.engine.runKernel(function (backend) { return backend.all($x, axes); }, { $x: $x }); if (keepDims) { var newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); return res.reshape(newShape); } return res; } function any_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = tensor_util_env_1.convertToTensor(x, 'x', 'any', 'bool'); var origAxes = util.parseAxisParam(axis, $x.shape); var axes = origAxes; var permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { $x = $x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } var res = environment_1.ENV.engine.runKernel(function (backend) { return backend.any($x, axes); }, { $x: $x }); if (keepDims) { var newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); return res.reshape(newShape); } return res; } function moments_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } x = tensor_util_env_1.convertToTensor(x, 'x', 'moments'); var axes = util.parseAxisParam(axis, x.shape); var mean = x.mean(axes, keepDims); var keepDimsShape = mean.shape; if (!keepDims) { keepDimsShape = axis_util.expandShapeToKeepDim(mean.shape, axes); } var devSquared = x.toFloat().sub(mean.reshape(keepDimsShape)).square(); var variance = devSquared.mean(axes, keepDims); return { mean: mean, variance: variance }; } exports.all = operation_1.op({ all_: all_ }); exports.any = operation_1.op({ any_: any_ }); exports.argMax = operation_1.op({ argMax_: argMax_ }); exports.argMin = operation_1.op({ argMin_: argMin_ }); exports.logSumExp = operation_1.op({ logSumExp_: logSumExp_ }); exports.max = operation_1.op({ max_: max_ }); exports.mean = operation_1.op({ mean_: mean_ }); exports.min = operation_1.op({ min_: min_ }); exports.moments = operation_1.op({ moments_: moments_ }); exports.sum = operation_1.op({ sum_: sum_ }); exports.prod = operation_1.op({ prod_: prod_ }); //# sourceMappingURL=reduction_ops.js.map