UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

349 lines (348 loc) 15 kB
"use strict";
// Decorator helper in the form emitted by the TypeScript compiler. Reuses an
// existing `this.__decorate` if present; otherwise applies `decorators` from
// last to first (or defers to Reflect.decorate when it is available) and, when
// called with a descriptor argument, redefines the property with the result.
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
    var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;
    if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc);
    else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
    return c > 3 && r && Object.defineProperty(target, key, r), r;
};
Object.defineProperty(exports, "__esModule", { value: true });
var doc_1 = require("../doc");
var environment_1 = require("../environment");
var types_1 = require("../types");
var util = require("../util");
var broadcast_util = require("./broadcast_util");
var operation_1 = require("./operation");
var ops_1 = require("./ops");
/**
 * Static collection of element-wise binary operations.
 *
 * Every non-strict op follows the same shape: validate the arguments, compute
 * (or at least assert) the broadcast shape of the two inputs, build a gradient
 * function, and dispatch the forward computation to the backend through
 * `ENV.engine.runKernel`. The `*Strict` variants first assert that both
 * operands have identical shapes and then delegate to the non-strict op.
 */
var BinaryOps = (function () {
    function BinaryOps() {
    }
    /**
     * Element-wise `a + b` with broadcasting.
     *
     * Gradient: `dy` for both inputs, summed over each input's broadcast axes
     * and reshaped to that input's shape.
     *
     * @param a First operand (must pass `util.assertArgumentsAreTensors`).
     * @param b Second operand; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.add(a, b)`.
     */
    BinaryOps.add = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'add');
        util.assertTypesMatch(a, b);
        var outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () {
                var res = dy;
                var reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.reshape(a.shape);
            };
            var derB = function () {
                var res = dy;
                var reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.reshape(b.shape);
            };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.add(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `add`, but asserts that `a` and `b` have matching shapes first
     * (no broadcasting).
     */
    BinaryOps.addStrict = function (a, b) {
        util.assertShapesMatch(a.shape, b.shape, 'Error in addStrict: ');
        return a.add(b);
    };
    /**
     * Element-wise `a - b` with broadcasting.
     *
     * Gradient: `dy` for `a` and `-dy` for `b`, each summed over the input's
     * broadcast axes and reshaped to the input's shape.
     *
     * @param a Minuend.
     * @param b Subtrahend; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.subtract(a, b)`.
     */
    BinaryOps.sub = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'sub');
        util.assertTypesMatch(a, b);
        var outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () {
                var res = dy;
                var reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.reshape(a.shape);
            };
            var derB = function () {
                var res = dy;
                var reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.neg().reshape(b.shape);
            };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.subtract(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `sub`, but asserts that `a` and `b` have matching shapes first.
     */
    BinaryOps.subStrict = function (a, b) {
        util.assertShapesMatch(a.shape, b.shape, 'Error in subStrict: ');
        return a.sub(b);
    };
    /**
     * Element-wise `base ^ exp` with broadcasting.
     *
     * Unlike the other ops, dtypes are not required to match: both operands are
     * cast to `upcastType(base.dtype, exp.dtype)` before the kernel runs. The
     * kernel result is passed through `save` so the gradient can reuse it as
     * `y = saved[0]`.
     *
     * Gradient: `dy * exp * (y / base)` for `base` and `dy * y * log(base)` for
     * `exp`, each summed over the input's broadcast axes and reshaped.
     *
     * @param base Base operand.
     * @param exp Exponent operand.
     * @returns Whatever `runKernel` returns for `backend.pow(base, exp)`.
     */
    BinaryOps.pow = function (base, exp) {
        util.assertArgumentsAreTensors({ base: base, exp: exp }, 'pow');
        var outShape = broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape);
        base = base.cast(types_1.upcastType(base.dtype, exp.dtype));
        exp = exp.cast(types_1.upcastType(base.dtype, exp.dtype));
        var grad = function (dy, saved) {
            var y = saved[0];
            var derBase = function () {
                var res = dy.mul(exp.toFloat().mul(y.div(base)));
                var reduceAxes = broadcast_util.getReductionAxes(base.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.reshape(base.shape);
            };
            var derExp = function () {
                var res = dy.mul(y.mul(base.log()).toFloat());
                var reduceAxes = broadcast_util.getReductionAxes(exp.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.reshape(exp.shape);
            };
            return { base: derBase, exp: derExp };
        };
        return environment_1.ENV.engine.runKernel(function (backend, save) { return save(backend.pow(base, exp)); }, { base: base, exp: exp }, grad);
    };
    /**
     * Same as `pow`, but asserts that `base` and `exp` have matching shapes
     * first.
     */
    BinaryOps.powStrict = function (base, exp) {
        util.assertShapesMatch(base.shape, exp.shape, 'Error in powStrict: ');
        return base.pow(exp);
    };
    /**
     * Element-wise `a * b` with broadcasting.
     *
     * Gradient: `dy * b` for `a` and `dy * a` for `b`; when an input was
     * broadcast, the result is summed over the broadcast axes and reshaped to
     * that input's shape.
     *
     * @param a First factor.
     * @param b Second factor; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.multiply(a, b)`.
     */
    BinaryOps.mul = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'mul');
        util.assertTypesMatch(a, b);
        var outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () {
                var res = dy.mul(b.toFloat());
                var reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
                if (reduceAxes.length > 0) {
                    return res.sum(reduceAxes).reshape(a.shape);
                }
                return res;
            };
            var derB = function () {
                var res = dy.mul(a.toFloat());
                var reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
                if (reduceAxes.length > 0) {
                    return res.sum(reduceAxes).reshape(b.shape);
                }
                return res;
            };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.multiply(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `mul`, but asserts that `a` and `b` have matching shapes first.
     */
    BinaryOps.mulStrict = function (a, b) {
        util.assertShapesMatch(a.shape, b.shape, 'Error in multiplyStrict: ');
        return a.mul(b);
    };
    /**
     * Element-wise `a / b` with broadcasting.
     *
     * Gradient: `dy / b` for `a` and `-(dy * a) / b^2` for `b`. For `b`, the
     * product `dy * a` is summed over the broadcast axes and reshaped to
     * `b.shape` before the division by `b^2`.
     *
     * @param a Numerator.
     * @param b Denominator; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.divide(a, b)`.
     */
    BinaryOps.div = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'div');
        util.assertTypesMatch(a, b);
        var outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () {
                var res = dy.div(b.toFloat());
                var reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
                if (reduceAxes.length > 0) {
                    return res.sum(reduceAxes).reshape(a.shape);
                }
                return res;
            };
            var derB = function () {
                var res = dy.mul(a.toFloat());
                var reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes).reshape(b.shape);
                }
                var tmp = b.square();
                return res.div(tmp.toFloat()).neg();
            };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.divide(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `div`, but asserts that `a` and `b` have matching shapes first.
     */
    BinaryOps.divStrict = function (a, b) {
        util.assertShapesMatch(a.shape, b.shape, 'Error in divideStrict: ');
        return a.div(b);
    };
    /**
     * Element-wise modulo of `a` by `b` with broadcasting.
     *
     * Gradient: `dy` for `a` and `dy * -floor(a / b)` for `b`; when an input
     * was broadcast, the result is summed over the broadcast axes and reshaped
     * to that input's shape.
     *
     * @param a Dividend.
     * @param b Divisor; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.mod(a, b)`.
     */
    BinaryOps.mod = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'mod');
        util.assertTypesMatch(a, b);
        var outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () {
                var reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
                if (reduceAxes.length > 0) {
                    return dy.sum(reduceAxes).reshape(a.shape);
                }
                return dy;
            };
            var derB = function () {
                var res = dy.mul(a.div(b).floor().neg());
                var reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
                if (reduceAxes.length > 0) {
                    return res.sum(reduceAxes).reshape(b.shape);
                }
                return res;
            };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.mod(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `mod`, but asserts that `a` and `b` have matching shapes first.
     */
    BinaryOps.modStrict = function (a, b) {
        util.assertShapesMatch(a.shape, b.shape, 'Error in modStrict: ');
        return a.mod(b);
    };
    /**
     * Element-wise minimum of `a` and `b` with broadcasting. Operands whose
     * dtype is `'bool'` are converted with `toInt()` before the kernel runs.
     *
     * Gradient: `dy * (a <= b)` for `a` and `dy * (a > b)` for `b`.
     * NOTE(review): unlike add/sub/mul/div, these gradients are not summed over
     * broadcast axes, so they appear to keep the broadcast shape - confirm this
     * is intended.
     *
     * @param a First operand.
     * @param b Second operand; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.minimum(a, b)`.
     */
    BinaryOps.minimum = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'minimum');
        util.assertTypesMatch(a, b);
        if (a.dtype === 'bool') {
            a = a.toInt();
        }
        if (b.dtype === 'bool') {
            b = b.toInt();
        }
        // Called only for its compatibility assertion; the shape is unused here.
        broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () { return dy.mul(a.lessEqual(b).toFloat()); };
            var derB = function () { return dy.mul(a.greater(b).toFloat()); };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.minimum(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `minimum`, but asserts that `a` and `b` have matching shapes
     * first.
     */
    BinaryOps.minimumStrict = function (a, b) {
        util.assertShapesMatch(a.shape, b.shape, 'Error in minimumStrict: ');
        return a.minimum(b);
    };
    /**
     * Element-wise maximum of `a` and `b` with broadcasting. Operands whose
     * dtype is `'bool'` are converted with `toInt()` before the kernel runs.
     *
     * Gradient: `dy * (a >= b)` for `a` and `dy * (a < b)` for `b`.
     * NOTE(review): as with `minimum`, these gradients are not summed over
     * broadcast axes - confirm this is intended.
     *
     * @param a First operand.
     * @param b Second operand; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.maximum(a, b)`.
     */
    BinaryOps.maximum = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'maximum');
        util.assertTypesMatch(a, b);
        if (a.dtype === 'bool') {
            a = a.toInt();
        }
        if (b.dtype === 'bool') {
            b = b.toInt();
        }
        // Called only for its compatibility assertion; the shape is unused here.
        broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () { return dy.mul(a.greaterEqual(b).toFloat()); };
            var derB = function () { return dy.mul(a.less(b).toFloat()); };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.maximum(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `maximum`, but asserts that `a` and `b` have matching shapes
     * first.
     */
    BinaryOps.maximumStrict = function (a, b) {
        // The prefix names this op; it previously said 'minimumStrict', which
        // pointed shape-mismatch errors at the wrong function.
        util.assertShapesMatch(a.shape, b.shape, 'Error in maximumStrict: ');
        return a.maximum(b);
    };
    /**
     * Element-wise squared difference of `a` and `b` with broadcasting,
     * dispatched to `backend.squaredDifference`.
     *
     * Gradient: `dy * (a - b) * 2` for `a` and `dy * (b - a) * 2` for `b`.
     * NOTE(review): these gradients are not summed over broadcast axes -
     * confirm this is intended.
     *
     * @param a First operand.
     * @param b Second operand; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for
     *     `backend.squaredDifference(a, b)`.
     */
    BinaryOps.squaredDifference = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'squaredDifference');
        util.assertTypesMatch(a, b);
        // Called only for its compatibility assertion; the shape is unused here.
        broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var two = ops_1.scalar(2);
            var derA = function () { return dy.mul(a.sub(b).mul(two)); };
            var derB = function () { return dy.mul(b.sub(a).mul(two)); };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.squaredDifference(a, b); }, { a: a, b: b }, der);
    };
    /**
     * Same as `squaredDifference`, but asserts that `a` and `b` have matching
     * shapes first.
     */
    BinaryOps.squaredDifferenceStrict = function (a, b) {
        util.assertShapesMatch(a.shape, b.shape, 'Error in squaredDifferenceStrict: ');
        return a.squaredDifference(b);
    };
    /**
     * Element-wise `atan2(a, b)` with broadcasting, dispatched to
     * `backend.atan2`.
     *
     * Gradient: `dy * b / (a^2 + b^2)` for `a` and `-(dy * a / (a^2 + b^2))`
     * for `b`, each summed over the input's broadcast axes and reshaped to the
     * input's shape.
     *
     * @param a First operand.
     * @param b Second operand; `util.assertTypesMatch(a, b)` must hold.
     * @returns Whatever `runKernel` returns for `backend.atan2(a, b)`.
     */
    BinaryOps.atan2 = function (a, b) {
        util.assertArgumentsAreTensors({ a: a, b: b }, 'atan2');
        util.assertTypesMatch(a, b);
        var outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var der = function (dy) {
            var derA = function () {
                var d = BinaryOps.add(ops_1.square(a), ops_1.square(b));
                var res = dy.mul(b.div(d));
                var reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.reshape(a.shape);
            };
            var derB = function () {
                var d = BinaryOps.add(ops_1.square(a), ops_1.square(b));
                var res = ops_1.neg(dy.mul(a.div(d)));
                var reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
                if (reduceAxes.length > 0) {
                    res = res.sum(reduceAxes);
                }
                return res.reshape(b.shape);
            };
            return { a: derA, b: derB };
        };
        return environment_1.ENV.engine.runKernel(function (backend) { return backend.atan2(a, b); }, { a: a, b: b }, der);
    };
    // Decorator applications. Non-strict ops (except atan2) get both the `doc`
    // and `operation` decorators; strict variants and atan2 get `operation`
    // only.
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "add", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "addStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "sub", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "subStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "pow", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "powStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "mul", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "mulStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "div", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "divStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "mod", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "modStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "minimum", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "minimumStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "maximum", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "maximumStrict", null);
    __decorate([
        doc_1.doc({ heading: 'Operations', subheading: 'Arithmetic' }),
        operation_1.operation
    ], BinaryOps, "squaredDifference", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "squaredDifferenceStrict", null);
    __decorate([
        operation_1.operation
    ], BinaryOps, "atan2", null);
    return BinaryOps;
}());
exports.BinaryOps = BinaryOps;