UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

155 lines 8.1 kB
"use strict";
// TypeScript-emitted helper that applies legacy (experimental) decorators to a
// class member; reproduced unchanged from the compiler output.
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
    var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;
    if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc);
    else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
    return c > 3 && r && Object.defineProperty(target, key, r), r;
};
Object.defineProperty(exports, "__esModule", { value: true });
var doc_1 = require("../doc");
var util = require("../util");
var operation_1 = require("./operation");
var ops = require("./ops");

/**
 * Reduction modes accepted by `LossOps.computeWeightedLoss`.
 * Emitted as a TypeScript numeric enum, so it maps name -> value and
 * value -> name (e.g. `Reduction.MEAN === 1`, `Reduction[1] === "MEAN"`).
 */
var Reduction;
(function (Reduction) {
    Reduction[Reduction["NONE"] = 0] = "NONE";
    Reduction[Reduction["MEAN"] = 1] = "MEAN";
    Reduction[Reduction["SUM"] = 2] = "SUM";
    Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS";
})(Reduction = exports.Reduction || (exports.Reduction = {}));

var LossOps = (function () {
    function LossOps() {
    }

    /**
     * Validates the arguments shared by every pairwise loss below.
     * Delegates to util.assertArgumentsAreTensors / util.assertShapesMatch
     * (which presumably throw on failure — see ../util).
     * @param labels Ground-truth tensor.
     * @param predictions Predicted tensor; must have the same shape as labels.
     * @param weights Optional weights tensor; only checked when non-null.
     * @param opName Name of the calling op, used in error messages.
     */
    function assertLossArgs(labels, predictions, weights, opName) {
        util.assertArgumentsAreTensors({ labels: labels, predictions: predictions }, opName);
        if (weights != null) {
            util.assertArgumentsAreTensors({ weights: weights }, opName);
        }
        util.assertShapesMatch(labels.shape, predictions.shape, 'Error in ' + opName + ': ');
    }

    /**
     * How many times each element of `weights` is repeated when it is
     * broadcast by `losses.mul(weights)` to produce `weightedLoss`.
     * Equals 1 when the weights already have the full (broadcast) shape.
     */
    function weightBroadcastFactor(weightedLoss, weights) {
        return weightedLoss.size / weights.size;
    }

    /**
     * Applies optional weights to a tensor of per-element losses and reduces it.
     * @param losses Tensor of per-element losses.
     * @param weights Optional tensor multiplied element-wise (with
     *     broadcasting) into `losses`.
     * @param reduction One of `Reduction`; defaults to SUM_BY_NONZERO_WEIGHTS.
     *   - NONE: the weighted losses, unreduced.
     *   - SUM: sum of the weighted losses.
     *   - MEAN: mean of the losses when unweighted; otherwise the sum of the
     *     weighted losses divided by the sum of the weights broadcast to the
     *     weighted-loss shape.
     *   - SUM_BY_NONZERO_WEIGHTS: sum divided by the number of elements when
     *     unweighted; otherwise divided by the number of nonzero entries of
     *     the broadcast weights.
     * @throws Error if `reduction` is not a known `Reduction` value.
     */
    LossOps.computeWeightedLoss = function (losses, weights, reduction) {
        if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; }
        util.assertArgumentsAreTensors({ losses: losses }, 'computeWeightedLoss');
        if (weights != null) {
            util.assertArgumentsAreTensors({ weights: weights }, 'computeWeightedLoss');
        }
        var weightedLoss = (weights == null) ? losses : losses.mul(weights);
        if (reduction === Reduction.NONE) {
            return weightedLoss;
        }
        if (reduction === Reduction.SUM) {
            return weightedLoss.sum();
        }
        if (reduction === Reduction.MEAN) {
            if (weights == null) {
                return weightedLoss.mean();
            }
            // The denominator must be the sum of the *broadcast* weights.
            // weights.sum() alone undercounts when weights (e.g. a scalar) were
            // broadcast over the losses, so rescale by the broadcast factor.
            var meanResult = weightedLoss.sum().div(weights.sum());
            var meanFactor = weightBroadcastFactor(weightedLoss, weights);
            return meanFactor > 1 ? meanResult.div(ops.scalar(meanFactor)) : meanResult;
        }
        if (reduction === Reduction.SUM_BY_NONZERO_WEIGHTS) {
            if (weights == null) {
                return weightedLoss.sum().div(ops.scalar(losses.size));
            }
            // Count nonzero entries of the broadcast weights: every nonzero
            // element of `weights` appears `nonZeroFactor` times after broadcasting.
            var numNonZeros = weights.notEqual(ops.scalar(0)).sum().toFloat();
            var nonZeroFactor = weightBroadcastFactor(weightedLoss, weights);
            if (nonZeroFactor > 1) {
                numNonZeros = numNonZeros.mul(ops.scalar(nonZeroFactor));
            }
            return weightedLoss.sum().div(numNonZeros);
        }
        throw new Error("Unknown reduction: " + reduction);
    };

    /**
     * Absolute-difference loss: |labels - predictions|, then weighted/reduced
     * by `computeWeightedLoss`.
     * @param labels Ground-truth tensor.
     * @param predictions Predicted tensor, same shape as labels.
     * @param weights Optional weights tensor.
     * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
     */
    LossOps.absoluteDifference = function (labels, predictions, weights, reduction) {
        if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; }
        assertLossArgs(labels, predictions, weights, 'absoluteDifference');
        var losses = labels.sub(predictions).abs();
        return LossOps.computeWeightedLoss(losses, weights, reduction);
    };

    /**
     * Squared-error loss: labels.squaredDifference(predictions), then
     * weighted/reduced by `computeWeightedLoss`.
     * @param labels Ground-truth tensor.
     * @param predictions Predicted tensor, same shape as labels.
     * @param weights Optional weights tensor.
     * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
     */
    LossOps.meanSquaredError = function (labels, predictions, weights, reduction) {
        if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; }
        assertLossArgs(labels, predictions, weights, 'meanSquaredError');
        var losses = labels.squaredDifference(predictions);
        return LossOps.computeWeightedLoss(losses, weights, reduction);
    };

    /**
     * Cosine-distance loss: 1 - sum(labels * predictions) along `axis`
     * (dimension kept), then weighted/reduced by `computeWeightedLoss`.
     * NOTE(review): the inputs are not normalized here, so this equals cosine
     * distance only if labels/predictions are already unit-normalized along
     * `axis` — presumably the caller's responsibility.
     * @param labels Ground-truth tensor.
     * @param predictions Predicted tensor, same shape as labels.
     * @param axis Dimension along which the dot product is taken.
     * @param weights Optional weights tensor.
     * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
     */
    LossOps.cosineDistance = function (labels, predictions, axis, weights, reduction) {
        if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; }
        assertLossArgs(labels, predictions, weights, 'cosineDistance');
        var one = ops.scalar(1);
        var losses = one.sub(labels.mul(predictions).sum(axis, true));
        return LossOps.computeWeightedLoss(losses, weights, reduction);
    };

    /**
     * Hinge loss: labels are mapped through 2 * labels - 1 (so {0, 1} labels
     * become {-1, 1}), then relu(1 - signedLabels * predictions) is
     * weighted/reduced by `computeWeightedLoss`.
     * @param labels Ground-truth tensor (presumably 0/1 valued).
     * @param predictions Predicted tensor, same shape as labels.
     * @param weights Optional weights tensor.
     * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
     */
    LossOps.hingeLoss = function (labels, predictions, weights, reduction) {
        if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; }
        assertLossArgs(labels, predictions, weights, 'hingeLoss');
        var one = ops.scalar(1);
        var signedLabels = ops.scalar(2).mul(labels).sub(one);
        var losses = one.sub(signedLabels.mul(predictions)).relu();
        return LossOps.computeWeightedLoss(losses, weights, reduction);
    };

    /**
     * Log (binary cross-entropy) loss:
     *   -labels * log(predictions + eps)
     *   - (1 - labels) * log(1 - predictions + eps),
     * then weighted/reduced by `computeWeightedLoss`.
     * @param labels Ground-truth tensor.
     * @param predictions Predicted tensor, same shape as labels.
     * @param weights Optional weights tensor.
     * @param epsilon Added inside each log to avoid log(0); defaults to 1e-7.
     * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
     */
    LossOps.logLoss = function (labels, predictions, weights, epsilon, reduction) {
        if (epsilon === void 0) { epsilon = 1e-7; }
        if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; }
        assertLossArgs(labels, predictions, weights, 'logLoss');
        var one = ops.scalar(1);
        var epsilonScalar = ops.scalar(epsilon);
        var positiveTerm = labels.mul(predictions.add(epsilonScalar).log());
        var negativeTerm = one.sub(labels).mul(one.sub(predictions).add(epsilonScalar).log());
        var losses = positiveTerm.neg().sub(negativeTerm);
        return LossOps.computeWeightedLoss(losses, weights, reduction);
    };

    /**
     * Huber loss. With error = |predictions - labels|,
     * quadratic = min(error, delta) and linear = error - quadratic, the
     * per-element loss is 0.5 * quadratic^2 + delta * linear, which is then
     * weighted/reduced by `computeWeightedLoss`.
     * @param labels Ground-truth tensor.
     * @param predictions Predicted tensor, same shape as labels.
     * @param weights Optional weights tensor.
     * @param delta Threshold between the quadratic and linear regimes;
     *     defaults to 1.0.
     * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
     */
    LossOps.huberLoss = function (labels, predictions, weights, delta, reduction) {
        if (delta === void 0) { delta = 1.0; }
        if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; }
        assertLossArgs(labels, predictions, weights, 'huberLoss');
        var deltaScalar = ops.scalar(delta);
        var error = predictions.sub(labels).abs();
        var quadratic = ops.minimum(error, deltaScalar);
        var linear = error.sub(quadratic);
        var losses = ops.scalar(0.5).mul(quadratic.square()).add(deltaScalar.mul(linear));
        return LossOps.computeWeightedLoss(losses, weights, reduction);
    };

    // Every public loss op gets the same doc metadata and the `operation`
    // decorator (from ./operation), applied in method-definition order.
    // Methods above call LossOps.computeWeightedLoss through the class
    // property, so they pick up the decorated version at call time.
    [
        "computeWeightedLoss",
        "absoluteDifference",
        "meanSquaredError",
        "cosineDistance",
        "hingeLoss",
        "logLoss",
        "huberLoss"
    ].forEach(function (methodName) {
        __decorate([
            doc_1.doc({ heading: 'Training', subheading: 'Losses', namespace: 'losses' }),
            operation_1.operation
        ], LossOps, methodName, null);
    });
    return LossOps;
}());
exports.LossOps = LossOps;
// NOTE(review): this compiled output was edited by hand; the source map below no
// longer matches line/column positions — port the fix to the TypeScript source
// and regenerate.
//# sourceMappingURL=loss_ops.js.map