UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

224 lines 12 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var globals_1 = require("../globals"); var tensor_util_env_1 = require("../tensor_util_env"); var util_1 = require("../util"); var axis_util_1 = require("./axis_util"); var binary_ops_1 = require("./binary_ops"); var operation_1 = require("./operation"); var tensor_ops_1 = require("./tensor_ops"); var Reduction; (function (Reduction) { Reduction[Reduction["NONE"] = 0] = "NONE"; Reduction[Reduction["MEAN"] = 1] = "MEAN"; Reduction[Reduction["SUM"] = 2] = "SUM"; Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS"; })(Reduction = exports.Reduction || (exports.Reduction = {})); function computeWeightedLoss_(losses, weights, reduction) { if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $losses = tensor_util_env_1.convertToTensor(losses, 'losses', 'computeWeightedLoss'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'computeWeightedLoss'); } var weightedLoss = ($weights == null) ? $losses : $losses.mul($weights); if (reduction === Reduction.NONE) { return weightedLoss; } if (reduction === Reduction.SUM) { return weightedLoss.sum(); } if (reduction === Reduction.MEAN) { if ($weights == null) { return weightedLoss.mean(); } else { var broadcastFactor = util_1.sizeFromShape($losses.shape) / util_1.sizeFromShape($weights.shape); var result = weightedLoss.sum().div($weights.sum()); return broadcastFactor > 1 ? result.div(tensor_ops_1.scalar(broadcastFactor)) : result; } } if (reduction === Reduction.SUM_BY_NONZERO_WEIGHTS) { if ($weights == null) { return weightedLoss.sum().div(tensor_ops_1.scalar($losses.size)); } else { var broadcastedWeights = $weights.mul(tensor_ops_1.ones($losses.shape)); var numNonZeros = broadcastedWeights.notEqual(tensor_ops_1.scalar(0)).sum().toFloat(); return weightedLoss.sum().div(numNonZeros); } } throw Error("Unknown reduction: " + reduction); } function absoluteDifference_(labels, predictions, weights, reduction) { if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $labels = tensor_util_env_1.convertToTensor(labels, 'labels', 'absoluteDifference'); var $predictions = tensor_util_env_1.convertToTensor(predictions, 'predictions', 'absoluteDifference'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'absoluteDifference'); } util_1.assertShapesMatch($labels.shape, $predictions.shape, 'Error in absoluteDifference: '); var losses = $labels.sub($predictions).abs(); return exports.computeWeightedLoss(losses, $weights, reduction); } function meanSquaredError_(labels, predictions, weights, reduction) { if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $labels = tensor_util_env_1.convertToTensor(labels, 'labels', 'meanSquaredError'); var $predictions = tensor_util_env_1.convertToTensor(predictions, 'predictions', 'meanSquaredError'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'meanSquaredError'); } util_1.assertShapesMatch($labels.shape, $predictions.shape, 'Error in meanSquaredError: '); var losses = $labels.squaredDifference($predictions); return exports.computeWeightedLoss(losses, $weights, reduction); } function cosineDistance_(labels, predictions, axis, weights, reduction) { if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $labels = tensor_util_env_1.convertToTensor(labels, 'labels', 'cosineDistance'); var $predictions = tensor_util_env_1.convertToTensor(predictions, 'predictions', 'cosineDistance'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'cosineDistance'); } util_1.assertShapesMatch($labels.shape, $predictions.shape, 'Error in cosineDistance: '); var one = tensor_ops_1.scalar(1); var losses = one.sub($labels.mul($predictions).sum(axis, true)); return exports.computeWeightedLoss(losses, $weights, reduction); } function hingeLoss_(labels, predictions, weights, reduction) { if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $labels = tensor_util_env_1.convertToTensor(labels, 'labels', 'hingeLoss'); var $predictions = tensor_util_env_1.convertToTensor(predictions, 'predictions', 'hingeLoss'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'hingeLoss'); } util_1.assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: '); var one = tensor_ops_1.scalar(1); $labels = tensor_ops_1.scalar(2).mul($labels).sub(one); var losses = one.sub($labels.mul($predictions)).relu(); return exports.computeWeightedLoss(losses, $weights, reduction); } function logLoss_(labels, predictions, weights, epsilon, reduction) { if (epsilon === void 0) { epsilon = 1e-7; } if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $labels = tensor_util_env_1.convertToTensor(labels, 'labels', 'logLoss'); var $predictions = tensor_util_env_1.convertToTensor(predictions, 'predictions', 'logLoss'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'logLoss'); } util_1.assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: '); var one = tensor_ops_1.scalar(1); var epsilonScalar = tensor_ops_1.scalar(epsilon); var losses = $labels.mul($predictions.add(epsilonScalar).log()) .neg() .sub(one.sub($labels).mul(one.sub($predictions).add(epsilonScalar).log())); return exports.computeWeightedLoss(losses, $weights, reduction); } function sigmoidCrossEntropyWithLogits_(labels, logits) { var $labels = tensor_util_env_1.convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits'); var $logits = tensor_util_env_1.convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits'); util_1.assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: '); var maxOutput = $logits.relu(); var outputXTarget = $logits.mul($labels); var sigmoidOutput = $logits.abs().neg().exp().log1p(); return maxOutput.sub(outputXTarget).add(sigmoidOutput); } function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing, reduction) { if (labelSmoothing === void 0) { labelSmoothing = 0; } if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $multiClassLabels = tensor_util_env_1.convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy'); var $logits = tensor_util_env_1.convertToTensor(logits, 'logits', 'sigmoidCrossEntropy'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'sigmoidCrossEntropy'); } util_1.assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: '); if (labelSmoothing > 0) { var labelSmoothingScalar = tensor_ops_1.scalar(labelSmoothing); var one = tensor_ops_1.scalar(1); var half = tensor_ops_1.scalar(0.5); $multiClassLabels = $multiClassLabels.mul(one.sub(labelSmoothingScalar)) .add(half.mul(labelSmoothingScalar)); } var losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits); return exports.computeWeightedLoss(losses, $weights, reduction); } function huberLoss_(labels, predictions, weights, delta, reduction) { if (delta === void 0) { delta = 1.0; } if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $labels = tensor_util_env_1.convertToTensor(labels, 'labels', 'huberLoss'); var $predictions = tensor_util_env_1.convertToTensor(predictions, 'predictions', 'huberLoss'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'huberLoss'); } util_1.assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: '); var deltaScalar = tensor_ops_1.scalar(delta); var error = $predictions.sub($labels).abs(); var quadratic = binary_ops_1.minimum(error, deltaScalar); var linear = error.sub(quadratic); var losses = tensor_ops_1.scalar(0.5).mul(quadratic.square()).add(deltaScalar.mul(linear)); return exports.computeWeightedLoss(losses, $weights, reduction); } function softmaxCrossEntropyWithLogits_(labels, logits, dim) { if (dim === void 0) { dim = -1; } if (dim === -1) { dim = logits.rank - 1; } if (dim !== logits.rank - 1) { throw Error("Softmax cross entropy along a non-last dimension is not yet " + ("supported. Labels / logits was rank " + logits.rank + " ") + ("and dim was " + dim)); } var customOp = globals_1.customGrad(function (labels, logits) { var keepDims = true; var lse = logits.logSumExp([dim], keepDims); var logResult = logits.toFloat().sub(lse); var costVector = logResult.mul(labels).neg(); var value = costVector.sum([dim]); var gradFunc = function (dy) { var dyShape = axis_util_1.expandShapeToKeepDim(dy.shape, [dim]); return [ dy.reshape(dyShape).mul(labels.toFloat().sub(logResult.exp())), dy.reshape(dyShape).mul(logResult.exp().sub(labels.toFloat())), ]; }; return { value: value, gradFunc: gradFunc }; }); return customOp(labels, logits); } function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing, reduction) { if (labelSmoothing === void 0) { labelSmoothing = 0; } if (reduction === void 0) { reduction = Reduction.SUM_BY_NONZERO_WEIGHTS; } var $onehotLabels = tensor_util_env_1.convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy'); var $logits = tensor_util_env_1.convertToTensor(logits, 'logits', 'softmaxCrossEntropy'); var $weights = null; if (weights != null) { $weights = tensor_util_env_1.convertToTensor(weights, 'weights', 'softmaxCrossEntropy'); } util_1.assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: '); if (labelSmoothing > 0) { var labelSmoothingScalar = tensor_ops_1.scalar(labelSmoothing); var one = tensor_ops_1.scalar(1); var numClasses = tensor_ops_1.scalar($onehotLabels.shape[1]); $onehotLabels = $onehotLabels.mul(one.sub(labelSmoothingScalar)) .add(labelSmoothingScalar.div(numClasses)); } var losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits); return exports.computeWeightedLoss(losses, $weights, reduction); } exports.absoluteDifference = operation_1.op({ absoluteDifference_: absoluteDifference_ }); exports.computeWeightedLoss = operation_1.op({ computeWeightedLoss_: computeWeightedLoss_ }); exports.cosineDistance = operation_1.op({ cosineDistance_: cosineDistance_ }); exports.hingeLoss = operation_1.op({ hingeLoss_: hingeLoss_ }); exports.huberLoss = operation_1.op({ huberLoss_: huberLoss_ }); exports.logLoss = operation_1.op({ logLoss_: logLoss_ }); exports.meanSquaredError = operation_1.op({ meanSquaredError_: meanSquaredError_ }); exports.sigmoidCrossEntropy = operation_1.op({ sigmoidCrossEntropy_: sigmoidCrossEntropy_ }); exports.softmaxCrossEntropy = operation_1.op({ softmaxCrossEntropy_: softmaxCrossEntropy_ }); //# sourceMappingURL=loss_ops.js.map