/*
 * @tensorflow/tfjs-core
 * Version: (not given)
 * Hardware-accelerated JavaScript library for machine intelligence
 * 241 lines • 10.7 kB
 * JavaScript
 */
;
/**
 * Decorator-application helper (it has the shape of the helper the TypeScript
 * compiler emits). An already-installed `this.__decorate` is reused if present.
 *
 * The number of arguments actually passed selects the mode:
 *   - fewer than 3: each decorator is called as `d(current)`, starting from `target`;
 *   - exactly 3:    each decorator is called as `d(target, key)`;
 *   - more than 3:  each decorator is called as `d(target, key, current)`, and the
 *                   final truthy result is installed with `Object.defineProperty`.
 * A `desc` of exactly `null` means "look the descriptor up on `target`".
 *
 * @param decorators Array of decorator functions; falsy entries are skipped.
 * @param target Object (or constructor) being decorated.
 * @param key Property name being decorated.
 * @param desc Property descriptor, or `null` to read it from `target`.
 * @returns The last truthy decorator result, or the starting value if none.
 */
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
    var argCount = arguments.length;
    var current;
    if (argCount < 3) {
        current = target;
    }
    else if (desc === null) {
        desc = Object.getOwnPropertyDescriptor(target, key);
        current = desc;
    }
    else {
        current = desc;
    }
    if (typeof Reflect === "object" && typeof Reflect.decorate === "function") {
        // Defer entirely to a metadata polyfill when one is installed.
        current = Reflect.decorate(decorators, target, key, desc);
    }
    else {
        // Walk from the last decorator to the first.
        for (var idx = decorators.length - 1; idx >= 0; idx--) {
            var decorator = decorators[idx];
            if (!decorator) {
                continue;
            }
            var produced;
            if (argCount < 3) {
                produced = decorator(current);
            }
            else if (argCount > 3) {
                produced = decorator(target, key, current);
            }
            else {
                produced = decorator(target, key);
            }
            // A decorator that returns nothing leaves the current value in place.
            current = produced || current;
        }
    }
    if (argCount > 3 && current) {
        Object.defineProperty(target, key, current);
    }
    return current;
};
Object.defineProperty(exports, "__esModule", { value: true });
var doc_1 = require("../doc");
var environment_1 = require("../environment");
var globals_1 = require("../globals");
var util = require("../util");
var axis_util = require("./axis_util");
var operation_1 = require("./operation");
var ops = require("./ops");
var ReductionOps = (function () {
function ReductionOps() {
}
/**
 * Computes `log(sum(exp(x)))` over the requested axes.
 *
 * The per-axis maximum is subtracted before exponentiating and added back
 * after the logarithm (the usual shifted formulation of log-sum-exp).
 *
 * @param x Tensor to reduce.
 * @param axis Axis or axes to reduce over; defaults to `null`, which is
 *     handed to `axis_util.parseAxisParam` unchanged.
 * @param keepDims When truthy, the result is reshaped with
 *     `axis_util.expandShapeToKeepDim`. Defaults to `false`.
 * @returns The reduced tensor.
 */
ReductionOps.logSumExp = function (x, axis, keepDims) {
    if (axis === undefined) {
        axis = null;
    }
    if (keepDims === undefined) {
        keepDims = false;
    }
    util.assertArgumentsAreTensors({ x: x }, 'logSumExp');
    var reduceAxes = axis_util.parseAxisParam(axis, x.shape);
    // Keep the reduced dims on the max so it can be subtracted from x directly.
    var maxKept = x.max(reduceAxes, true);
    var logOfShiftedSum = x.sub(maxKept).exp().sum(reduceAxes).log();
    var result = maxKept.reshape(logOfShiftedSum.shape).add(logOfShiftedSum);
    if (!keepDims) {
        return result;
    }
    return result.reshape(axis_util.expandShapeToKeepDim(result.shape, reduceAxes));
};
/**
 * Sums the elements of `x` along the requested axes.
 *
 * Inputs of dtype `'bool'` are converted with `toInt()` first. The reduction
 * is wrapped in `customGrad`; the gradient reshapes `dy` so every reduced
 * axis has size 1 and multiplies it by a float32 tensor of ones shaped like
 * the input.
 *
 * @param x Tensor to reduce.
 * @param axis Axis or axes to reduce over; defaults to `null`, which is
 *     handed to `axis_util.parseAxisParam` unchanged.
 * @param keepDims When truthy, the result is reshaped with
 *     `axis_util.expandShapeToKeepDim`. Defaults to `false`.
 * @returns The value produced by the custom-gradient op.
 */
ReductionOps.sum = function (x, axis, keepDims) {
    if (axis === undefined) {
        axis = null;
    }
    if (keepDims === undefined) {
        keepDims = false;
    }
    util.assertArgumentsAreTensors({ x: x }, 'sum');
    var input = x.dtype === 'bool' ? x.toInt() : x;
    var reduceAxes = axis_util.parseAxisParam(axis, input.shape);
    var sumWithGrad = globals_1.customGrad(function (t) {
        // When a permutation is reported, transpose and reduce over the
        // inner-most axes instead of the originally requested ones.
        var perm = axis_util.getAxesPermutation(reduceAxes, t.rank);
        var kernelInput = t;
        var kernelAxes = reduceAxes;
        if (perm != null) {
            kernelInput = t.transpose(perm);
            kernelAxes = axis_util.getInnerMostAxes(kernelAxes.length, t.rank);
        }
        var reduced = environment_1.ENV.engine.runKernel(function (backend) { return backend.sum(kernelInput, kernelAxes); }, { permutedX: kernelInput });
        var value = keepDims
            ? reduced.reshape(axis_util.expandShapeToKeepDim(reduced.shape, reduceAxes))
            : reduced;
        return {
            value: value,
            gradFunc: function (dy) {
                // Shape of dy with each reduced axis restored as size 1.
                var dyShape = t.shape.slice();
                for (var i = 0; i < reduceAxes.length; i++) {
                    dyShape[reduceAxes[i]] = 1;
                }
                return dy.reshape(dyShape).mul(ops.ones(t.shape, 'float32'));
            }
        };
    });
    return sumWithGrad(input);
};
/**
 * Computes the mean of `x` along the requested axes.
 *
 * The input is divided by the number of reduced elements first and then
 * summed. The reduction is wrapped in `customGrad`; the gradient reshapes
 * `dy` so every reduced axis has size 1, multiplies by a float32 tensor of
 * ones shaped like the input, and divides by the same element count.
 *
 * @param x Tensor to reduce.
 * @param axis Axis or axes to reduce over; defaults to `null`, which is
 *     handed to `axis_util.parseAxisParam` unchanged.
 * @param keepDims Forwarded to the inner `sum` call. Defaults to `false`.
 * @returns The value produced by the custom-gradient op.
 */
ReductionOps.mean = function (x, axis, keepDims) {
    if (axis === undefined) {
        axis = null;
    }
    if (keepDims === undefined) {
        keepDims = false;
    }
    util.assertArgumentsAreTensors({ x: x }, 'mean');
    var reduceAxes = axis_util.parseAxisParam(axis, x.shape);
    // Index 1 of the returned pair is the shape of the reduced dimensions.
    var reducedDims = axis_util.computeOutAndReduceShapes(x.shape, reduceAxes)[1];
    var elementCount = util.sizeFromShape(reducedDims);
    var meanWithGrad = globals_1.customGrad(function (t) {
        var divisor = ops.scalar(elementCount);
        var numerator;
        if (divisor.dtype === t.dtype) {
            numerator = t;
        }
        else {
            // Bring the input to the divisor's dtype before dividing.
            numerator = t.cast(divisor.dtype);
        }
        var value = numerator.div(divisor).sum(axis, keepDims);
        return {
            value: value,
            gradFunc: function (dy) {
                // Shape of dy with each reduced axis restored as size 1.
                var dyShape = t.shape.slice();
                for (var i = 0; i < reduceAxes.length; i++) {
                    dyShape[reduceAxes[i]] = 1;
                }
                return dy.reshape(dyShape).mul(ops.ones(t.shape, 'float32')).div(divisor);
            }
        };
    });
    return meanWithGrad(x);
};
/**
 * Computes the minimum of `x` along the requested axes via the backend's
 * `min` kernel.
 *
 * @param x Tensor to reduce.
 * @param axis Axis or axes to reduce over; defaults to `null`, which is
 *     handed to `axis_util.parseAxisParam` unchanged.
 * @param keepDims When truthy, the result is reshaped with
 *     `axis_util.expandShapeToKeepDim` using the originally requested axes.
 *     Defaults to `false`.
 * @returns The reduced tensor.
 */
ReductionOps.min = function (x, axis, keepDims) {
    if (axis === undefined) {
        axis = null;
    }
    if (keepDims === undefined) {
        keepDims = false;
    }
    util.assertArgumentsAreTensors({ x: x }, 'min');
    var requestedAxes = axis_util.parseAxisParam(axis, x.shape);
    var kernelAxes = requestedAxes;
    // When a permutation is reported, transpose and reduce over the
    // inner-most axes instead of the originally requested ones.
    var perm = axis_util.getAxesPermutation(kernelAxes, x.rank);
    if (perm != null) {
        x = x.transpose(perm);
        kernelAxes = axis_util.getInnerMostAxes(kernelAxes.length, x.rank);
    }
    var reduced = environment_1.ENV.engine.runKernel(function (backend) { return backend.min(x, kernelAxes); }, { x: x });
    if (!keepDims) {
        return reduced;
    }
    return reduced.reshape(axis_util.expandShapeToKeepDim(reduced.shape, requestedAxes));
};
/**
 * Computes the maximum of `x` along the requested axes via the backend's
 * `max` kernel.
 *
 * @param x Tensor to reduce.
 * @param axis Axis or axes to reduce over; defaults to `null`, which is
 *     handed to `axis_util.parseAxisParam` unchanged.
 * @param keepDims When truthy, the result is reshaped with
 *     `axis_util.expandShapeToKeepDim` using the originally requested axes.
 *     Defaults to `false`.
 * @returns The reduced tensor.
 */
ReductionOps.max = function (x, axis, keepDims) {
    if (axis === undefined) {
        axis = null;
    }
    if (keepDims === undefined) {
        keepDims = false;
    }
    util.assertArgumentsAreTensors({ x: x }, 'max');
    var requestedAxes = axis_util.parseAxisParam(axis, x.shape);
    var kernelAxes = requestedAxes;
    // When a permutation is reported, transpose and reduce over the
    // inner-most axes instead of the originally requested ones.
    var perm = axis_util.getAxesPermutation(kernelAxes, x.rank);
    if (perm != null) {
        x = x.transpose(perm);
        kernelAxes = axis_util.getInnerMostAxes(kernelAxes.length, x.rank);
    }
    var reduced = environment_1.ENV.engine.runKernel(function (backend) { return backend.max(x, kernelAxes); }, { x: x });
    if (!keepDims) {
        return reduced;
    }
    return reduced.reshape(axis_util.expandShapeToKeepDim(reduced.shape, requestedAxes));
};
/**
 * Runs the backend's `argMin` kernel on `x` along a single axis.
 *
 * @param x Tensor to search.
 * @param axis Axis to reduce over; both `undefined` and `null` fall back to 0.
 *     Only the first parsed axis is passed to the kernel.
 * @returns The kernel's result.
 */
ReductionOps.argMin = function (x, axis) {
    util.assertArgumentsAreTensors({ x: x }, 'argMin');
    // `== null` matches both an omitted argument and an explicit null.
    if (axis == null) {
        axis = 0;
    }
    var kernelAxes = axis_util.parseAxisParam(axis, x.shape);
    // When a permutation is reported, transpose and use the inner-most axis.
    var perm = axis_util.getAxesPermutation(kernelAxes, x.rank);
    if (perm != null) {
        x = x.transpose(perm);
        kernelAxes = axis_util.getInnerMostAxes(kernelAxes.length, x.rank);
    }
    return environment_1.ENV.engine.runKernel(function (backend) { return backend.argMin(x, kernelAxes[0]); }, { x: x });
};
/**
 * Runs the backend's `argMax` kernel on `x` along a single axis.
 *
 * @param x Tensor to search.
 * @param axis Axis to reduce over; both `undefined` and `null` fall back to 0.
 *     Only the first parsed axis is passed to the kernel.
 * @returns The kernel's result.
 */
ReductionOps.argMax = function (x, axis) {
    util.assertArgumentsAreTensors({ x: x }, 'argMax');
    // `== null` matches both an omitted argument and an explicit null.
    if (axis == null) {
        axis = 0;
    }
    var kernelAxes = axis_util.parseAxisParam(axis, x.shape);
    // When a permutation is reported, transpose and use the inner-most axis.
    var perm = axis_util.getAxesPermutation(kernelAxes, x.rank);
    if (perm != null) {
        x = x.transpose(perm);
        kernelAxes = axis_util.getInnerMostAxes(kernelAxes.length, x.rank);
    }
    return environment_1.ENV.engine.runKernel(function (backend) { return backend.argMax(x, kernelAxes[0]); }, { x: x });
};
/**
 * Computes the mean and variance of `x` along the requested axes.
 *
 * The variance is the mean of the squared deviation of `x.toFloat()` from
 * the mean, taken over the same axes.
 *
 * @param x Tensor to reduce.
 * @param axis Axis or axes to reduce over; defaults to `null`, which is
 *     handed to `axis_util.parseAxisParam` unchanged.
 * @param keepDims Forwarded to both `mean` calls. Defaults to `false`.
 * @returns An object `{ mean, variance }`.
 */
ReductionOps.moments = function (x, axis, keepDims) {
    if (axis === undefined) {
        axis = null;
    }
    if (keepDims === undefined) {
        keepDims = false;
    }
    util.assertArgumentsAreTensors({ x: x }, 'moments');
    var reduceAxes = axis_util.parseAxisParam(axis, x.shape);
    var meanValue = x.mean(reduceAxes, keepDims);
    // The mean needs its reduced dims back (as size 1) to subtract from x.
    var broadcastableShape = keepDims
        ? meanValue.shape
        : axis_util.expandShapeToKeepDim(meanValue.shape, reduceAxes);
    var squaredDeviation = x.toFloat().sub(meanValue.reshape(broadcastableShape)).square();
    return {
        mean: meanValue,
        variance: squaredDeviation.mean(reduceAxes, keepDims)
    };
};
/**
 * Sums the slices of `x` along `axis`, grouped by segment id.
 *
 * For every segment id `i` in `[0, numSegments)`, a float32 mask marks the
 * positions where `segmentIds` equals `i`; the mask is multiplied with `x`,
 * summed over `axis`, and the per-segment sums are stacked back along `axis`.
 *
 * @param x Tensor holding the values to sum.
 * @param segmentIds Tensor of dtype `'int32'`. Its first dimension is laid
 *     out along `axis` and broadcast against `x`.
 * @param numSegments Number of segments to produce; must be an integer.
 * @param axis Axis of `x` that the segment ids index. Defaults to 0; only the
 *     first axis returned by `axis_util.parseAxisParam` is used.
 * @returns The stacked per-segment sums.
 */
ReductionOps.unsortedSegmentSum = function (x, segmentIds, numSegments, axis) {
    if (axis === void 0) { axis = 0; }
    util.assertArgumentsAreTensors({ x: x, segmentIds: segmentIds }, 'unsortedSegmentSum');
    util.assert(segmentIds.dtype === 'int32', 'Segment Ids must be of dtype `int32`');
    // Without this check a fractional count silently produced ceil(n)
    // segments, undefined/NaN produced none, and Infinity never terminated.
    util.assert(numSegments % 1 === 0, 'numSegments must be an integer');
    var segmentAxis = axis_util.parseAxisParam(axis, x.shape)[0];
    var idCount = segmentIds.shape[0];
    // Same rank as x: the ids run along `segmentAxis`, every other dim is 1
    // so the mask broadcasts against x.
    var broadcastShape = [];
    for (var dim = 0; dim < x.shape.length; dim++) {
        broadcastShape.push(dim === segmentAxis ? idCount : 1);
    }
    var idsAlongAxis = ops.reshape(segmentIds, broadcastShape);
    var segmentSums = [];
    for (var segment = 0; segment < numSegments; segment++) {
        var segmentMask = ops.equal(ops.scalar(segment, 'int32'), idsAlongAxis).asType('float32');
        segmentSums.push(segmentMask.mul(x).sum(segmentAxis));
    }
    return ops.stack(segmentSums, segmentAxis);
};
// Apply the `doc` and `operation` decorators to every static op, in the
// listed order. Each entry is [method name, documentation subheading].
var decoratedOps = [
    ['logSumExp', 'Reduction'],
    ['sum', 'Reduction'],
    ['mean', 'Reduction'],
    ['min', 'Reduction'],
    ['max', 'Reduction'],
    ['argMin', 'Reduction'],
    ['argMax', 'Reduction'],
    ['moments', 'Normalization'],
    ['unsortedSegmentSum', 'Reduction']
];
for (var opIndex = 0; opIndex < decoratedOps.length; opIndex++) {
    var opName = decoratedOps[opIndex][0];
    var opSubheading = decoratedOps[opIndex][1];
    // A `null` descriptor tells __decorate to read it from ReductionOps.
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: opSubheading }),
        operation_1.operation
    ], ReductionOps, opName, null);
}
return ReductionOps;
}());
exports.ReductionOps = ReductionOps;
//# sourceMappingURL=reduction_ops.js.map