@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
390 lines • 17.5 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var environment_1 = require("../environment");
var tensor_util_1 = require("../tensor_util");
var tensor_util_env_1 = require("../tensor_util_env");
var types_1 = require("../types");
var util = require("../util");
var broadcast_util = require("./broadcast_util");
var operation_1 = require("./operation");
var tensor_ops_1 = require("./tensor_ops");
var unary_ops_1 = require("./unary_ops");
/**
 * Element-wise addition of two operands, dispatched to `backend.add`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`; the broadcast shape is validated with
 * `assertAndGetBroadcastShape` before the kernel is scheduled.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function add_(a, b) {
    var lhs = tensor_util_env_1.convertToTensor(a, 'a', 'add');
    var rhs = tensor_util_env_1.convertToTensor(b, 'b', 'add');
    var matched = tensor_util_1.makeTypesMatch(lhs, rhs);
    lhs = matched[0];
    rhs = matched[1];
    var broadcastShape = broadcast_util.assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    function backward(dy) {
        // Both partial derivatives are 1, so each gradient is just dy summed
        // over the axes reported by getReductionAxes and reshaped back to the
        // operand's own shape.
        function gradientFor(operand) {
            return function () {
                var axes = broadcast_util.getReductionAxes(operand.shape, broadcastShape);
                var reduced = axes.length > 0 ? dy.sum(axes) : dy;
                return reduced.reshape(operand.shape);
            };
        }
        return { $a: gradientFor(lhs), $b: gradientFor(rhs) };
    }
    function forward(backend) {
        return backend.add(lhs, rhs);
    }
    return environment_1.ENV.engine.runKernel(forward, { $a: lhs, $b: rhs }, backward);
}
/**
 * Sums a list of tensors that all share one dtype and one shape, using
 * `backend.addN`.
 *
 * @param tensors Non-empty array of operands; each entry is passed through
 *     `convertToTensor`.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 * @throws {Error} When any entry's dtype or shape differs from the first
 *     entry's. The array and length checks are delegated to `util.assert`.
 */
function addN_(tensors) {
    util.assert(Array.isArray(tensors), function () {
        return 'The argument passed to tf.addN() must be a list of tensors';
    });
    util.assert(tensors.length >= 1, function () {
        return "Must pass at least one tensor to tf.addN(), but got " + ("" + tensors.length);
    });
    var converted = tensors.map(function (tensor, index) {
        return tensor_util_env_1.convertToTensor(tensor, "tensors" + index, 'addN');
    });
    var reference = converted[0];
    // All dtypes are checked before any shape, so a dtype mismatch is always
    // reported first.
    var dtypeMismatch = converted.some(function (tensor) {
        return tensor.dtype !== reference.dtype;
    });
    if (dtypeMismatch) {
        throw new Error('All tensors passed to tf.addN() must have the same dtype');
    }
    var shapeMismatch = converted.some(function (tensor) {
        return !util.arraysEqual(tensor.shape, reference.shape);
    });
    if (shapeMismatch) {
        throw new Error('All tensors passed to tf.addN() must have the same shape');
    }
    // Every input receives its own clone of dy, keyed by position.
    function backward(dy) {
        return converted.reduce(function (gradients, _tensor, index) {
            gradients[index] = function () { return dy.clone(); };
            return gradients;
        }, {});
    }
    function forward(backend) {
        return backend.addN(converted);
    }
    return environment_1.ENV.engine.runKernel(forward, converted, backward);
}
/**
 * Addition that first checks, via `util.assertShapesMatch`, that both
 * operands have the same shape, then delegates to the `add` method of the
 * converted first operand.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The result of `add` on the converted operands.
 */
function addStrict_(a, b) {
    var lhs = tensor_util_env_1.convertToTensor(a, 'a', 'addStrict');
    var rhs = tensor_util_env_1.convertToTensor(b, 'b', 'addStrict');
    var lhsShape = lhs.shape;
    var rhsShape = rhs.shape;
    util.assertShapesMatch(lhsShape, rhsShape, 'Error in addStrict: ');
    var sum = lhs.add(rhs);
    return sum;
}
/**
 * Element-wise subtraction `a - b`, dispatched to `backend.subtract`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`; the broadcast shape is validated before the kernel is
 * scheduled.
 *
 * @param a Minuend.
 * @param b Subtrahend.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function sub_(a, b) {
    var minuend = tensor_util_env_1.convertToTensor(a, 'a', 'sub');
    var subtrahend = tensor_util_env_1.convertToTensor(b, 'b', 'sub');
    var matched = tensor_util_1.makeTypesMatch(minuend, subtrahend);
    minuend = matched[0];
    subtrahend = matched[1];
    var broadcastShape = broadcast_util.assertAndGetBroadcastShape(minuend.shape, subtrahend.shape);
    function backward(dy) {
        // Gradient is +dy for the minuend and -dy for the subtrahend, each
        // summed over the reduction axes and reshaped to the operand shape.
        function gradientFor(operand, negate) {
            return function () {
                var axes = broadcast_util.getReductionAxes(operand.shape, broadcastShape);
                var reduced = axes.length > 0 ? dy.sum(axes) : dy;
                var signed = negate ? reduced.neg() : reduced;
                return signed.reshape(operand.shape);
            };
        }
        return { $a: gradientFor(minuend, false), $b: gradientFor(subtrahend, true) };
    }
    function forward(backend) {
        return backend.subtract(minuend, subtrahend);
    }
    return environment_1.ENV.engine.runKernel(forward, { $a: minuend, $b: subtrahend }, backward);
}
/**
 * Subtraction that first checks, via `util.assertShapesMatch`, that both
 * operands have the same shape, then delegates to the `sub` method of the
 * converted first operand.
 *
 * @param a Minuend.
 * @param b Subtrahend.
 * @returns The result of `sub` on the converted operands.
 */
function subStrict_(a, b) {
    var minuend = tensor_util_env_1.convertToTensor(a, 'a', 'subStrict');
    var subtrahend = tensor_util_env_1.convertToTensor(b, 'b', 'subStrict');
    var minuendShape = minuend.shape;
    var subtrahendShape = subtrahend.shape;
    util.assertShapesMatch(minuendShape, subtrahendShape, 'Error in subStrict: ');
    var difference = minuend.sub(subtrahend);
    return difference;
}
/**
 * Element-wise power `base ^ exp`, dispatched to `backend.pow`.
 *
 * Both arguments are passed through `convertToTensor`, the broadcast shape is
 * validated, and the operands are then aligned with `makeTypesMatch` exactly
 * like the sibling binary ops in this file. The previous implementation cast
 * both operands but stored the results in the `base`/`exp` parameters, which
 * were never read again, so the kernel and the gradients silently used the
 * un-aligned tensors.
 *
 * @param base The base operand.
 * @param exp The exponent operand.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function pow_(base, exp) {
    var _a;
    var $base = tensor_util_env_1.convertToTensor(base, 'base', 'pow');
    var $exp = tensor_util_env_1.convertToTensor(exp, 'exp', 'pow');
    // Shape validation stays ahead of type matching so that incompatible
    // shapes are still reported before any cast is attempted.
    var outShape = broadcast_util.assertAndGetBroadcastShape($base.shape, $exp.shape);
    // Bind the type-matched tensors to the locals that the kernel and the
    // gradient closures actually use.
    _a = tensor_util_1.makeTypesMatch($base, $exp), $base = _a[0], $exp = _a[1];
    var grad = function (dy, saved) {
        // `saved[0]` is the forward result handed to `save` below.
        var y = saved[0];
        var derBase = function () {
            // d(base^exp)/dbase = exp * base^(exp - 1)
            var expFloat = $exp.toFloat();
            var res = dy.mul(expFloat.mul($base.pow(expFloat.sub(tensor_ops_1.scalar(1)))));
            var reduceAxes = broadcast_util.getReductionAxes($base.shape, outShape);
            if (reduceAxes.length > 0) {
                res = res.sum(reduceAxes);
            }
            return res.reshape($base.shape);
        };
        var derExp = function () {
            // d(base^exp)/dexp = y * log(base); log(base) is replaced by zeros
            // wherever base is not greater than 0.
            var condition = $base.greater(0);
            var logBase = $base.log().where(condition, tensor_ops_1.zerosLike($base));
            var res = dy.mul(y.mul(logBase));
            var reduceAxes = broadcast_util.getReductionAxes($exp.shape, outShape);
            if (reduceAxes.length > 0) {
                res = res.sum(reduceAxes);
            }
            return res.reshape($exp.shape);
        };
        return { $base: derBase, $exp: derExp };
    };
    return environment_1.ENV.engine.runKernel(function (backend, save) { return save(backend.pow($base, $exp)); }, { $base: $base, $exp: $exp }, grad);
}
/**
 * Power that first checks, via `util.assertShapesMatch`, that both operands
 * have the same shape, then delegates to the `pow` method of the converted
 * base.
 *
 * Unlike every other `*Strict` op in this file, the previous implementation
 * skipped `convertToTensor` and read `.shape` / `.pow` straight off the raw
 * arguments. Both arguments are now converted first, as the sibling ops do.
 *
 * @param base The base operand.
 * @param exp The exponent operand.
 * @returns The result of `pow` on the converted operands.
 */
function powStrict_(base, exp) {
    var $base = tensor_util_env_1.convertToTensor(base, 'base', 'powStrict');
    var $exp = tensor_util_env_1.convertToTensor(exp, 'exp', 'powStrict');
    util.assertShapesMatch($base.shape, $exp.shape, 'Error in powStrict: ');
    return $base.pow($exp);
}
/**
 * Element-wise multiplication, dispatched to `backend.multiply`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`; the broadcast shape is validated before the kernel is
 * scheduled.
 *
 * @param a First factor.
 * @param b Second factor.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function mul_(a, b) {
    var first = tensor_util_env_1.convertToTensor(a, 'a', 'mul');
    var second = tensor_util_env_1.convertToTensor(b, 'b', 'mul');
    var matched = tensor_util_1.makeTypesMatch(first, second);
    first = matched[0];
    second = matched[1];
    var productShape = broadcast_util.assertAndGetBroadcastShape(first.shape, second.shape);
    function backward(dy) {
        // Product rule: the gradient for one factor is dy times the other
        // factor; it is summed and reshaped only when reduction axes exist.
        function gradientFor(target, other) {
            return function () {
                var scaled = dy.mul(other.toFloat());
                var axes = broadcast_util.getReductionAxes(target.shape, productShape);
                return axes.length > 0 ? scaled.sum(axes).reshape(target.shape) : scaled;
            };
        }
        return { $a: gradientFor(first, second), $b: gradientFor(second, first) };
    }
    function forward(backend) {
        return backend.multiply(first, second);
    }
    return environment_1.ENV.engine.runKernel(forward, { $a: first, $b: second }, backward);
}
/**
 * Multiplication that first checks, via `util.assertShapesMatch`, that both
 * operands have the same shape, then delegates to the `mul` method of the
 * converted first operand.
 *
 * The op name handed to `convertToTensor` is now 'mulStrict' (it was 'mul'),
 * matching how the other `*Strict` ops in this file identify themselves.
 * The 'Error in multiplyStrict: ' prefix is kept unchanged.
 *
 * @param a First factor.
 * @param b Second factor.
 * @returns The result of `mul` on the converted operands.
 */
function mulStrict_(a, b) {
    var $a = tensor_util_env_1.convertToTensor(a, 'a', 'mulStrict');
    var $b = tensor_util_env_1.convertToTensor(b, 'b', 'mulStrict');
    util.assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: ');
    return $a.mul($b);
}
/**
 * Element-wise division `a / b`.
 *
 * After `convertToTensor` and `makeTypesMatch`, a pair of 'int32' operands is
 * handed to `exports.floorDiv`; every other dtype combination is dispatched
 * to `backend.realDivide`.
 *
 * @param a Numerator.
 * @param b Denominator.
 * @returns The `floorDiv` result for int32 operands, otherwise the value
 *     produced by `ENV.engine.runKernel`.
 */
function div_(a, b) {
    var numer = tensor_util_env_1.convertToTensor(a, 'a', 'div');
    var denom = tensor_util_env_1.convertToTensor(b, 'b', 'div');
    var matched = tensor_util_1.makeTypesMatch(numer, denom);
    numer = matched[0];
    denom = matched[1];
    // Integer operands take the floorDiv path; nothing below runs for them.
    if (numer.dtype === 'int32' && denom.dtype === 'int32') {
        return exports.floorDiv(numer, denom);
    }
    var quotientShape = broadcast_util.assertAndGetBroadcastShape(numer.shape, denom.shape);
    function backward(dy) {
        return {
            // d(a / b)/da = 1 / b
            $a: function () {
                var grad = dy.div(denom.toFloat());
                var axes = broadcast_util.getReductionAxes(numer.shape, quotientShape);
                return axes.length > 0 ? grad.sum(axes).reshape(numer.shape) : grad;
            },
            // d(a / b)/db = -a / b^2; the reduction is applied before the
            // division by b^2.
            $b: function () {
                var grad = dy.mul(numer.toFloat());
                var axes = broadcast_util.getReductionAxes(denom.shape, quotientShape);
                if (axes.length > 0) {
                    grad = grad.sum(axes).reshape(denom.shape);
                }
                return grad.div(denom.square().toFloat()).neg();
            }
        };
    }
    function forward(backend) {
        return backend.realDivide(numer, denom);
    }
    return environment_1.ENV.engine.runKernel(forward, { $a: numer, $b: denom }, backward);
}
/**
 * Element-wise floor division, dispatched to `backend.floorDiv`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`. The gradient closures use the same formulas as `div_`.
 *
 * @param a Numerator.
 * @param b Denominator.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function floorDiv_(a, b) {
    var numer = tensor_util_env_1.convertToTensor(a, 'a', 'floorDiv');
    var denom = tensor_util_env_1.convertToTensor(b, 'b', 'floorDiv');
    var matched = tensor_util_1.makeTypesMatch(numer, denom);
    numer = matched[0];
    denom = matched[1];
    var forward = function (backend) {
        return backend.floorDiv(numer, denom);
    };
    var quotientShape = broadcast_util.assertAndGetBroadcastShape(numer.shape, denom.shape);
    // Sums `grad` over the operand's reduction axes and reshapes it to the
    // operand shape; returns `grad` untouched when there is nothing to reduce.
    function reduceTo(grad, operand) {
        var axes = broadcast_util.getReductionAxes(operand.shape, quotientShape);
        return axes.length > 0 ? grad.sum(axes).reshape(operand.shape) : grad;
    }
    function backward(dy) {
        function gradNumer() {
            return reduceTo(dy.div(denom.toFloat()), numer);
        }
        function gradDenom() {
            var reduced = reduceTo(dy.mul(numer.toFloat()), denom);
            var denomSquared = denom.square();
            return reduced.div(denomSquared.toFloat()).neg();
        }
        return { $a: gradNumer, $b: gradDenom };
    }
    return environment_1.ENV.engine.runKernel(forward, { $a: numer, $b: denom }, backward);
}
/**
 * Division that first checks, via `util.assertShapesMatch`, that both
 * operands have the same shape, then delegates to the `div` method of the
 * converted first operand.
 *
 * The op name handed to `convertToTensor` is now 'divStrict' (it was 'div'),
 * matching how the other `*Strict` ops in this file identify themselves.
 * The 'Error in divideStrict: ' prefix is kept unchanged.
 *
 * @param a Numerator.
 * @param b Denominator.
 * @returns The result of `div` on the converted operands.
 */
function divStrict_(a, b) {
    var $a = tensor_util_env_1.convertToTensor(a, 'a', 'divStrict');
    var $b = tensor_util_env_1.convertToTensor(b, 'b', 'divStrict');
    util.assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: ');
    return $a.div($b);
}
/**
 * Element-wise modulo, dispatched to `backend.mod`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`; the broadcast shape is validated before the kernel is
 * scheduled.
 *
 * @param a Dividend.
 * @param b Divisor.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function mod_(a, b) {
    var dividend = tensor_util_env_1.convertToTensor(a, 'a', 'mod');
    var divisor = tensor_util_env_1.convertToTensor(b, 'b', 'mod');
    var matched = tensor_util_1.makeTypesMatch(dividend, divisor);
    dividend = matched[0];
    divisor = matched[1];
    var resultShape = broadcast_util.assertAndGetBroadcastShape(dividend.shape, divisor.shape);
    // Sums `grad` over the operand's reduction axes and reshapes it to the
    // operand shape; returns `grad` untouched when there is nothing to reduce.
    function reduceTo(grad, operand) {
        var axes = broadcast_util.getReductionAxes(operand.shape, resultShape);
        return axes.length > 0 ? grad.sum(axes).reshape(operand.shape) : grad;
    }
    function backward(dy) {
        return {
            // The dividend's gradient is dy itself.
            $a: function () {
                return reduceTo(dy, dividend);
            },
            // The divisor's gradient is dy * -floor(a / b).
            $b: function () {
                var negFloorQuotient = dividend.div(divisor).floor().neg();
                return reduceTo(dy.mul(negFloorQuotient), divisor);
            }
        };
    }
    function forward(backend) {
        return backend.mod(dividend, divisor);
    }
    return environment_1.ENV.engine.runKernel(forward, { $a: dividend, $b: divisor }, backward);
}
/**
 * Modulo that first checks, via `util.assertShapesMatch`, that both operands
 * have the same shape, then delegates to the `mod` method of the converted
 * first operand.
 *
 * @param a Dividend.
 * @param b Divisor.
 * @returns The result of `mod` on the converted operands.
 */
function modStrict_(a, b) {
    var dividend = tensor_util_env_1.convertToTensor(a, 'a', 'modStrict');
    var divisor = tensor_util_env_1.convertToTensor(b, 'b', 'modStrict');
    var dividendShape = dividend.shape;
    var divisorShape = divisor.shape;
    util.assertShapesMatch(dividendShape, divisorShape, 'Error in modStrict: ');
    var remainder = dividend.mod(divisor);
    return remainder;
}
/**
 * Element-wise minimum of two operands, dispatched to `backend.minimum`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`; 'bool' operands are converted with `toInt()` first.
 *
 * The forward pass accepts broadcast-compatible shapes, but the gradients
 * were previously returned in the broadcast output shape. They are now
 * summed over the axes reported by `getReductionAxes` and reshaped to the
 * operand shape, the same way `mul_` handles broadcasting. When no axes
 * need reducing the gradient expression is unchanged.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function minimum_(a, b) {
    var _a;
    var $a = tensor_util_env_1.convertToTensor(a, 'a', 'minimum');
    var $b = tensor_util_env_1.convertToTensor(b, 'b', 'minimum');
    _a = tensor_util_1.makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
    if ($a.dtype === 'bool') {
        $a = $a.toInt();
        $b = $b.toInt();
    }
    var outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
    var der = function (dy) {
        // dy flows to `a` where a <= b ...
        var derA = function () {
            var res = dy.mul($a.lessEqual($b).toFloat());
            var reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
            if (reduceAxes.length > 0) {
                return res.sum(reduceAxes).reshape($a.shape);
            }
            return res;
        };
        // ... and to `b` where a > b.
        var derB = function () {
            var res = dy.mul($a.greater($b).toFloat());
            var reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
            if (reduceAxes.length > 0) {
                return res.sum(reduceAxes).reshape($b.shape);
            }
            return res;
        };
        return { $a: derA, $b: derB };
    };
    return environment_1.ENV.engine.runKernel(function (backend) { return backend.minimum($a, $b); }, { $a: $a, $b: $b }, der);
}
/**
 * Minimum that first checks, via `util.assertShapesMatch`, that both
 * operands have the same shape, then delegates to the `minimum` method of
 * the converted first operand.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The result of `minimum` on the converted operands.
 */
function minimumStrict_(a, b) {
    var lhs = tensor_util_env_1.convertToTensor(a, 'a', 'minimumStrict');
    var rhs = tensor_util_env_1.convertToTensor(b, 'b', 'minimumStrict');
    var lhsShape = lhs.shape;
    var rhsShape = rhs.shape;
    util.assertShapesMatch(lhsShape, rhsShape, 'Error in minimumStrict: ');
    var smaller = lhs.minimum(rhs);
    return smaller;
}
/**
 * Element-wise maximum of two operands, dispatched to `backend.maximum`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`; 'bool' operands are converted with `toInt()` first.
 *
 * The forward pass accepts broadcast-compatible shapes, but the gradients
 * were previously returned in the broadcast output shape. They are now
 * summed over the axes reported by `getReductionAxes` and reshaped to the
 * operand shape, the same way `mul_` handles broadcasting. When no axes
 * need reducing the gradient expression is unchanged.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function maximum_(a, b) {
    var _a;
    var $a = tensor_util_env_1.convertToTensor(a, 'a', 'maximum');
    var $b = tensor_util_env_1.convertToTensor(b, 'b', 'maximum');
    _a = tensor_util_1.makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
    if ($a.dtype === 'bool') {
        $a = $a.toInt();
        $b = $b.toInt();
    }
    var outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
    var der = function (dy) {
        // dy flows to `a` where a >= b ...
        var derA = function () {
            var res = dy.mul($a.greaterEqual($b).toFloat());
            var reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
            if (reduceAxes.length > 0) {
                return res.sum(reduceAxes).reshape($a.shape);
            }
            return res;
        };
        // ... and to `b` where a < b.
        var derB = function () {
            var res = dy.mul($a.less($b).toFloat());
            var reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
            if (reduceAxes.length > 0) {
                return res.sum(reduceAxes).reshape($b.shape);
            }
            return res;
        };
        return { $a: derA, $b: derB };
    };
    return environment_1.ENV.engine.runKernel(function (backend) { return backend.maximum($a, $b); }, { $a: $a, $b: $b }, der);
}
/**
 * Maximum that first checks, via `util.assertShapesMatch`, that both
 * operands have the same shape, then delegates to the `maximum` method of
 * the converted first operand.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The result of `maximum` on the converted operands.
 */
function maximumStrict_(a, b) {
    var lhs = tensor_util_env_1.convertToTensor(a, 'a', 'maximumStrict');
    var rhs = tensor_util_env_1.convertToTensor(b, 'b', 'maximumStrict');
    var lhsShape = lhs.shape;
    var rhsShape = rhs.shape;
    util.assertShapesMatch(lhsShape, rhsShape, 'Error in maximumStrict: ');
    var larger = lhs.maximum(rhs);
    return larger;
}
/**
 * Element-wise squared difference, dispatched to
 * `backend.squaredDifference`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`.
 *
 * The forward pass accepts broadcast-compatible shapes, but the gradients
 * were previously returned in the broadcast output shape. They are now
 * summed over the axes reported by `getReductionAxes` and reshaped to the
 * operand shape, the same way `mul_` handles broadcasting. When no axes
 * need reducing the gradient expression is unchanged.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function squaredDifference_(a, b) {
    var _a;
    var $a = tensor_util_env_1.convertToTensor(a, 'a', 'squaredDifference');
    var $b = tensor_util_env_1.convertToTensor(b, 'b', 'squaredDifference');
    _a = tensor_util_1.makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
    var outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
    var der = function (dy) {
        var two = tensor_ops_1.scalar(2);
        // d(a - b)^2/da = 2 * (a - b)
        var derA = function () {
            var res = dy.mul($a.sub($b).mul(two));
            var reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
            if (reduceAxes.length > 0) {
                return res.sum(reduceAxes).reshape($a.shape);
            }
            return res;
        };
        // d(a - b)^2/db = 2 * (b - a)
        var derB = function () {
            var res = dy.mul($b.sub($a).mul(two));
            var reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
            if (reduceAxes.length > 0) {
                return res.sum(reduceAxes).reshape($b.shape);
            }
            return res;
        };
        return { $a: derA, $b: derB };
    };
    return environment_1.ENV.engine.runKernel(function (backend) { return backend.squaredDifference($a, $b); }, { $a: $a, $b: $b }, der);
}
/**
 * Squared difference that first checks, via `util.assertShapesMatch`, that
 * both operands have the same shape, then delegates to the
 * `squaredDifference` method of the converted first operand.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The result of `squaredDifference` on the converted operands.
 */
function squaredDifferenceStrict_(a, b) {
    var lhs = tensor_util_env_1.convertToTensor(a, 'a', 'squaredDifferenceStrict');
    var rhs = tensor_util_env_1.convertToTensor(b, 'b', 'squaredDifferenceStrict');
    var lhsShape = lhs.shape;
    var rhsShape = rhs.shape;
    util.assertShapesMatch(lhsShape, rhsShape, 'Error in squaredDifferenceStrict: ');
    var result = lhs.squaredDifference(rhs);
    return result;
}
/**
 * Element-wise two-argument arctangent, dispatched to `backend.atan2`.
 *
 * Both arguments are passed through `convertToTensor` and then
 * `makeTypesMatch`; the broadcast shape is validated before the kernel is
 * scheduled.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The value produced by `ENV.engine.runKernel` for this op.
 */
function atan2_(a, b) {
    var lhs = tensor_util_env_1.convertToTensor(a, 'a', 'atan2');
    var rhs = tensor_util_env_1.convertToTensor(b, 'b', 'atan2');
    var matched = tensor_util_1.makeTypesMatch(lhs, rhs);
    lhs = matched[0];
    rhs = matched[1];
    var resultShape = broadcast_util.assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    // a^2 + b^2, rebuilt on every call because each gradient closure runs
    // independently.
    function sumOfSquares() {
        return exports.add(lhs.square(), rhs.square());
    }
    // Sums `grad` over the operand's reduction axes (if any) and always
    // reshapes it to the operand shape.
    function finish(grad, operand) {
        var axes = broadcast_util.getReductionAxes(operand.shape, resultShape);
        var reduced = axes.length > 0 ? grad.sum(axes) : grad;
        return reduced.reshape(operand.shape);
    }
    function backward(dy) {
        return {
            // Gradient for a: dy * b / (a^2 + b^2)
            $a: function () {
                var denominator = sumOfSquares();
                return finish(dy.mul(rhs.div(denominator)), lhs);
            },
            // Gradient for b: -(dy * a / (a^2 + b^2))
            $b: function () {
                var denominator = sumOfSquares();
                return finish(unary_ops_1.neg(dy.mul(lhs.div(denominator))), rhs);
            }
        };
    }
    function forward(backend) {
        return backend.atan2(lhs, rhs);
    }
    return environment_1.ENV.engine.runKernel(forward, { $a: lhs, $b: rhs }, backward);
}
// Public export table. Every `*_` implementation above is handed to
// `operation_1.op` as a single-key object and the result is published under
// the same name without the trailing underscore.
// NOTE(review): presumably `op` uses the object key to name the wrapped
// function — confirm in ./operation.
//
// `div_` reads `exports.floorDiv` and `atan2_` reads `exports.add` at call
// time, so those two entries must stay published on `exports`.
// The assignments are kept as plain `exports.<name> = ...` statements so the
// export names remain statically visible.
exports.add = operation_1.op({ add_: add_ });
exports.addN = operation_1.op({ addN_: addN_ });
exports.addStrict = operation_1.op({ addStrict_: addStrict_ });
exports.atan2 = operation_1.op({ atan2_: atan2_ });
exports.div = operation_1.op({ div_: div_ });
exports.divStrict = operation_1.op({ divStrict_: divStrict_ });
exports.floorDiv = operation_1.op({ floorDiv_: floorDiv_ });
exports.maximum = operation_1.op({ maximum_: maximum_ });
exports.maximumStrict = operation_1.op({ maximumStrict_: maximumStrict_ });
exports.minimum = operation_1.op({ minimum_: minimum_ });
exports.minimumStrict = operation_1.op({ minimumStrict_: minimumStrict_ });
exports.mod = operation_1.op({ mod_: mod_ });
exports.modStrict = operation_1.op({ modStrict_: modStrict_ });
exports.mul = operation_1.op({ mul_: mul_ });
exports.mulStrict = operation_1.op({ mulStrict_: mulStrict_ });
exports.pow = operation_1.op({ pow_: pow_ });
exports.powStrict = operation_1.op({ powStrict_: powStrict_ });
exports.squaredDifference = operation_1.op({ squaredDifference_: squaredDifference_ });
exports.squaredDifferenceStrict = operation_1.op({ squaredDifferenceStrict_: squaredDifferenceStrict_ });
exports.sub = operation_1.op({ sub_: sub_ });
exports.subStrict = operation_1.op({ subStrict_: subStrict_ });
//# sourceMappingURL=binary_ops.js.map