/*
 * @tensorflow/tfjs-core
 * Version: (not recorded)
 * Hardware-accelerated JavaScript library for machine intelligence
 * 266 lines • 15.4 kB
 * JavaScript
 */
;
/**
 * Compiler-emitted decorator helper. An existing `this.__decorate` is reused
 * when present; otherwise the fallback function below is used.
 *
 * @param decorators Array of decorator functions; applied from last to first.
 * @param target Object the decorated member lives on.
 * @param key Name of the decorated member.
 * @param desc Property descriptor; when exactly `null`, it is looked up with
 *     Object.getOwnPropertyDescriptor(target, key).
 * @returns The final decorated value (`r`).
 */
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
// c: number of arguments actually passed. r: running result, starting as the
// target (fewer than 3 args) or the property descriptor (3 or more args).
var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;
// Delegate to Reflect.decorate when the runtime provides it.
if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc);
// Otherwise call each decorator in reverse order. The call shape depends on the
// argument count: d(r), d(target, key, r) or d(target, key). A falsy return
// value keeps the previous r.
else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
// With more than 3 args and a truthy result, install r as the property
// descriptor on target; r is returned in every case.
return c > 3 && r && Object.defineProperty(target, key, r), r;
};
Object.defineProperty(exports, "__esModule", { value: true });
var doc_1 = require("../doc");
var environment_1 = require("../environment");
var util = require("../util");
var conv_util = require("./conv_util");
var operation_1 = require("./operation");
var ConvOps = (function () {
// Empty constructor: the convolution ops are attached directly to ConvOps
// as static functions rather than to instances.
function ConvOps() {
}
/**
 * 1D convolution, implemented by delegating to ConvOps.conv2d with an extra
 * dimension of size 1 inserted into both the input and the filter.
 *
 * A rank-2 `x` is first reshaped to rank 3 with a leading dimension of 1, and
 * in that case the result is returned as rank 2.
 *
 * @param x Input tensor of rank 2 or 3; its last dimension must equal
 *     `filter.shape[1]`.
 * @param filter Rank-3 filter tensor.
 * @param stride Stride; either `stride` or `dilation` must be 1.
 * @param pad Padding, forwarded to conv2d; must be an integer when
 *     `dimRoundingMode` is given.
 * @param dataFormat Defaults to 'NWC', the only accepted value.
 * @param dilation Defaults to 1.
 * @param dimRoundingMode Optional rounding mode, forwarded to conv2d.
 * @returns The conv2d result reshaped to rank 2 or rank 3.
 */
ConvOps.conv1d = function (x, filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
    dataFormat = dataFormat === void 0 ? 'NWC' : dataFormat;
    dilation = dilation === void 0 ? 1 : dilation;
    util.assertArgumentsAreTensors({ x: x, filter: filter }, 'conv1d');

    // Treat a rank-2 input as a batch containing a single example.
    var wasRank2 = x.rank === 2;
    var batched = wasRank2 ? x.as3D(1, x.shape[0], x.shape[1]) : x;

    util.assert(batched.rank === 3, "Error in conv1d: input must be rank 3, but got rank " + batched.rank + ".");
    util.assert(filter.rank === 3, "Error in conv1d: filter must be rank 3, but got rank " + filter.rank + ".");
    if (dimRoundingMode != null) {
        util.assert(util.isInt(pad), "Error in conv1d: pad must be an integer when using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
    }
    util.assert(batched.shape[2] === filter.shape[1], "Error in conv1d: depth of input (" + batched.shape[2] + ") must match input depth for filter " + filter.shape[1] + ".");
    util.assert(eitherStridesOrDilationsAreOne(stride, dilation), "Error in conv1D: Either stride or dilation must be 1. Got stride " + stride + " and dilation '" + dilation + "'");
    util.assert(dataFormat === 'NWC', "Error in conv1d: got dataFormat of " + dataFormat + " but only NWC is currently supported.");

    // Insert a height dimension of 1 so the 2D op can do the work.
    var kernel4D = filter.as4D(1, filter.shape[0], filter.shape[1], filter.shape[2]);
    var image4D = batched.as4D(batched.shape[0], 1, batched.shape[1], batched.shape[2]);
    var out = ConvOps.conv2d(image4D, kernel4D, [1, stride], pad, 'NHWC', [1, dilation], dimRoundingMode);

    // Drop the inserted height dimension (and the batch, if we added it).
    return wasRank2 ?
        out.as2D(out.shape[2], out.shape[3]) :
        out.as3D(out.shape[0], out.shape[2], out.shape[3]);
};
/**
 * Computes a 2D convolution of `x` with `filter` by running the backend's
 * `conv2d` kernel.
 *
 * A rank-3 `x` is reshaped to rank 4 with a leading dimension of 1, and in
 * that case the result is reshaped back to rank 3 before it is returned.
 *
 * @param x Input tensor of rank 3 or 4; its last dimension must equal
 *     `filter.shape[2]`.
 * @param filter Rank-4 filter tensor.
 * @param strides Strides, either a single number or a pair.
 * @param pad Padding, forwarded to conv_util.computeConv2DInfo; must be an
 *     integer when `dimRoundingMode` is given.
 * @param dataFormat Defaults to 'NHWC', the only accepted value.
 * @param dilations Defaults to [1, 1]; either `strides` or `dilations` must
 *     be all ones.
 * @param dimRoundingMode Optional rounding mode, forwarded to
 *     conv_util.computeConv2DInfo.
 * @returns The kernel result (rank 3 if the input was rank 3).
 */
ConvOps.conv2d = function (x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
    if (dataFormat === void 0) { dataFormat = 'NHWC'; }
    if (dilations === void 0) { dilations = [1, 1]; }
    util.assertArgumentsAreTensors({ x: x, filter: filter }, 'conv2d');
    // Treat a rank-3 input as a batch containing a single example.
    var x4D = x;
    var reshapedTo4D = false;
    if (x.rank === 3) {
        reshapedTo4D = true;
        x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);
    }
    util.assert(x4D.rank === 4, "Error in conv2d: input must be rank 4, but got rank " + x4D.rank + ".");
    util.assert(filter.rank === 4, "Error in conv2d: filter must be rank 4, but got rank " +
        (filter.rank + "."));
    if (dimRoundingMode != null) {
        util.assert(util.isInt(pad), "Error in conv2d: pad must be an integer when using, " +
            ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."));
    }
    util.assert(x4D.shape[3] === filter.shape[2], "Error in conv2d: depth of input (" + x4D.shape[3] + ") must match " +
        ("input depth for filter " + filter.shape[2] + "."));
    util.assert(eitherStridesOrDilationsAreOne(strides, dilations), 'Error in conv2D: Either strides or dilations must be 1. ' +
        ("Got strides " + strides + " and dilations '" + dilations + "'"));
    util.assert(dataFormat === 'NHWC', "Error in conv2d: got dataFormat of " + dataFormat + " but only NHWC is currently supported.");
    var convInfo = conv_util.computeConv2DInfo(x4D.shape, filter.shape, strides, dilations, pad, dimRoundingMode);
    // Gradient function: only available when every dilation is 1.
    // NOTE(review): dimRoundingMode is not forwarded to the two gradient ops
    // below - confirm that this is intended.
    var grad = function (dy) {
        // Fixed: the two message fragments used to join as "are notyet supported".
        util.assert(tupleValuesAreOne(dilations), 'Error in gradient of conv2D: dilation rates greater than 1 are not ' +
            ("yet supported in gradients. Got dilations '" + dilations + "'"));
        return {
            x: function () { return ConvOps.conv2dDerInput(x4D.shape, dy, filter, strides, pad); },
            filter: function () {
                return ConvOps.conv2dDerFilter(x4D, dy, filter.shape, strides, pad);
            }
        };
    };
    var res = environment_1.ENV.engine.runKernel(function (backend) { return backend.conv2d(x4D, filter, convInfo); }, { x: x4D, filter: filter }, grad);
    if (reshapedTo4D) {
        // Drop the batch dimension that was added above.
        return res.as3D(res.shape[1], res.shape[2], res.shape[3]);
    }
    return res;
};
/**
 * Runs the backend's `conv2dDerInput` kernel, i.e. the derivative of a 2D
 * convolution with respect to its input.
 *
 * A rank-3 `dy` (with a length-3 `xShape`) is promoted to rank 4 with a
 * leading dimension of 1, and the result is then returned as rank 3.
 *
 * @param xShape Shape of the convolution input; its length must equal
 *     `dy.rank`, and its last entry must equal `filter.shape[2]`.
 * @param dy Rank-3 or rank-4 tensor; its last dimension must equal
 *     `filter.shape[3]`.
 * @param filter Rank-4 filter tensor.
 * @param strides Strides, forwarded to conv_util.computeConv2DInfo.
 * @param pad Padding; must be an integer when `dimRoundingMode` is given.
 * @param dimRoundingMode Optional rounding mode, forwarded to
 *     conv_util.computeConv2DInfo.
 * @returns The kernel result (rank 3 if `dy` was rank 3).
 */
ConvOps.conv2dDerInput = function (xShape, dy, filter, strides, pad, dimRoundingMode) {
    util.assertArgumentsAreTensors({ dy: dy, filter: filter }, 'conv2dDerInput');
    util.assert(xShape.length === dy.rank, "Length of inShape (" + xShape.length + ") and rank of dy (" + dy.rank + ") must match");

    // Promote a rank-3 gradient and its matching input shape to a batch of one.
    var promoted = dy.rank === 3;
    var gradient = promoted ? dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]) : dy;
    var inputShape = promoted ? [1, xShape[0], xShape[1], xShape[2]] : xShape;

    var inChannels = inputShape[3];
    var outChannels = gradient.shape[3];
    util.assert(inputShape.length === 4, "Error in conv2dDerInput: inShape must be length 4, but got length " + inputShape.length + ".");
    util.assert(gradient.rank === 4, "Error in conv2dDerInput: dy must be rank 4, but got rank " + gradient.rank);
    util.assert(filter.rank === 4, "Error in conv2dDerInput: filter must be rank 4, but got rank " + filter.rank);
    util.assert(inChannels === filter.shape[2], "Error in conv2dDerInput: depth of input (" + inChannels + ") must match input depth for filter " + filter.shape[2] + ".");
    util.assert(outChannels === filter.shape[3], "Error in conv2dDerInput: depth of output (" + outChannels + ") must match output depth for filter " + filter.shape[3] + ".");
    if (dimRoundingMode != null) {
        util.assert(util.isInt(pad), "Error in conv2dDerInput: pad must be an integer when using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
    }

    // The conv info is always computed with a dilation of 1 here.
    var info = conv_util.computeConv2DInfo(inputShape, filter.shape, strides, 1, pad, dimRoundingMode);
    var result = environment_1.ENV.engine.runKernel(function (backend) {
        return backend.conv2dDerInput(gradient, filter, info);
    }, { dy4D: gradient });

    if (!promoted) {
        return result;
    }
    // Drop the batch dimension that was added above.
    return result.as3D(result.shape[1], result.shape[2], result.shape[3]);
};
/**
 * Runs the backend's `conv2dDerFilter` kernel, i.e. the derivative of a 2D
 * convolution with respect to its filter.
 *
 * Rank-3 `x` and `dy` are each reshaped to rank 4 with a leading dimension
 * of 1 before validation.
 *
 * @param x Rank-3 or rank-4 input tensor; its last dimension must equal
 *     `filterShape[2]`.
 * @param dy Rank-3 or rank-4 tensor; its last dimension must equal
 *     `filterShape[3]`.
 * @param filterShape Length-4 shape of the filter.
 * @param strides Strides, forwarded to conv_util.computeConv2DInfo.
 * @param pad Padding; must be an integer when `dimRoundingMode` is given.
 * @param dimRoundingMode Optional rounding mode, forwarded to
 *     conv_util.computeConv2DInfo.
 * @returns The kernel result.
 */
ConvOps.conv2dDerFilter = function (x, dy, filterShape, strides, pad, dimRoundingMode) {
    util.assertArgumentsAreTensors({ x: x, dy: dy }, 'conv2dDerFilter');
    var x4D = x;
    if (x.rank === 3) {
        x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);
    }
    var dy4D = dy;
    if (dy4D.rank === 3) {
        dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]);
    }
    util.assert(x4D.rank === 4, "Error in conv2dDerFilter: input must be rank 4, but got shape " +
        (x4D.shape + "."));
    util.assert(dy4D.rank === 4, "Error in conv2dDerFilter: dy must be rank 4, but got shape " +
        (dy4D.shape + "."));
    util.assert(filterShape.length === 4, "Error in conv2dDerFilter: filterShape must be length 4, but got " +
        (filterShape + "."));
    // Fixed: this message had unbalanced parentheses ("depth of input N) ... filter (M.");
    // it now mirrors the dy-depth message below.
    util.assert(x4D.shape[3] === filterShape[2], "Error in conv2dDerFilter: depth of input (" + x4D.shape[3] + ") must " +
        ("match input depth in filter (" + filterShape[2] + ")."));
    util.assert(dy4D.shape[3] === filterShape[3], "Error in conv2dDerFilter: depth of dy (" + dy4D.shape[3] + ") must " +
        ("match output depth for filter (" + filterShape[3] + ")."));
    if (dimRoundingMode != null) {
        util.assert(util.isInt(pad), "Error in conv2dDerFilter: pad must be an integer when using, " +
            ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."));
    }
    // The conv info is always computed with a dilation of 1 here.
    var dilations = 1;
    var convInfo = conv_util.computeConv2DInfo(x4D.shape, filterShape, strides, dilations, pad, dimRoundingMode);
    return environment_1.ENV.engine.runKernel(function (backend) { return backend.conv2dDerFilter(x4D, dy4D, convInfo); }, { x4D: x4D, dy4D: dy4D });
};
/**
 * Transposed 2D convolution, computed by calling ConvOps.conv2dDerInput with
 * `x` in the `dy` position and `outputShape` in the `xShape` position.
 *
 * @param x Input tensor (validated as a tensor here; further checks happen in
 *     conv2dDerInput).
 * @param filter Filter tensor.
 * @param outputShape Shape of the result, forwarded as conv2dDerInput's `xShape`.
 * @param strides Strides, forwarded unchanged.
 * @param pad Padding, forwarded unchanged.
 * @param dimRoundingMode Optional rounding mode, forwarded unchanged.
 * @returns Whatever conv2dDerInput returns.
 */
ConvOps.conv2dTranspose = function (x, filter, outputShape, strides, pad, dimRoundingMode) {
    var tensorArgs = { x: x, filter: filter };
    util.assertArgumentsAreTensors(tensorArgs, 'conv2dTranspose');
    var transposed = ConvOps.conv2dDerInput(outputShape, x, filter, strides, pad, dimRoundingMode);
    return transposed;
};
/**
 * Depthwise 2D convolution, run through the backend's `depthwiseConv2D`
 * kernel.
 *
 * A rank-3 `x` is reshaped to rank 4 with a leading dimension of 1, and in
 * that case the result is reshaped back to rank 3 before it is returned.
 *
 * @param x Input tensor of rank 3 or 4; its last dimension must equal
 *     `filter.shape[2]`.
 * @param filter Rank-4 filter tensor.
 * @param strides Strides, either a single number or a pair.
 * @param pad Padding; must be an integer when `dimRoundingMode` is given.
 * @param dataFormat Defaults to 'NHWC'; not otherwise read by this function.
 * @param dilations Defaults to [1, 1] (also when `null`); either `strides`
 *     or `dilations` must be all ones.
 * @param dimRoundingMode Optional rounding mode, forwarded to
 *     conv_util.computeConv2DInfo.
 * @returns The kernel result (rank 3 if the input was rank 3).
 */
ConvOps.depthwiseConv2d = function (x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
    dataFormat = dataFormat === void 0 ? 'NHWC' : dataFormat;
    dilations = dilations === void 0 ? [1, 1] : dilations;
    util.assertArgumentsAreTensors({ x: x, filter: filter }, 'depthwiseConv2d');

    // Treat a rank-3 input as a batch containing a single example.
    var addedBatch = x.rank === 3;
    var input4D = addedBatch ? x.as4D(1, x.shape[0], x.shape[1], x.shape[2]) : x;

    util.assert(input4D.rank === 4, "Error in depthwiseConv2D: input must be rank 4, but got rank " + input4D.rank + ".");
    util.assert(filter.rank === 4, "Error in depthwiseConv2D: filter must be rank 4, but got rank " + filter.rank + ".");
    util.assert(input4D.shape[3] === filter.shape[2], "Error in depthwiseConv2D: number of input channels (" + input4D.shape[3] + ") must match the inChannels dimension in filter " + filter.shape[2] + ".");

    // An explicit null is handled the same way as an omitted argument.
    if (dilations == null) {
        dilations = [1, 1];
    }
    util.assert(eitherStridesOrDilationsAreOne(strides, dilations), "Error in depthwiseConv2d: Either strides or dilations must be 1. Got strides " + strides + " and dilations '" + dilations + "'");
    if (dimRoundingMode != null) {
        util.assert(util.isInt(pad), "Error in depthwiseConv2D: pad must be an integer when using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
    }

    // The trailing `true` is the extra flag this op passes to computeConv2DInfo.
    var info = conv_util.computeConv2DInfo(input4D.shape, filter.shape, strides, dilations, pad, dimRoundingMode, true);
    var output = environment_1.ENV.engine.runKernel(function (backend) {
        return backend.depthwiseConv2D(input4D, filter, info);
    }, { x4D: input4D, filter: filter });

    if (!addedBatch) {
        return output;
    }
    // Drop the batch dimension that was added above.
    return output.as3D(output.shape[1], output.shape[2], output.shape[3]);
};
/**
 * Separable 2D convolution: a depthwise convolution (ConvOps.depthwiseConv2d)
 * followed by a 1x1 pointwise convolution (ConvOps.conv2d with stride 1 and
 * 'valid' padding).
 *
 * A rank-3 `x` is reshaped to rank 4 with a leading dimension of 1, and in
 * that case the result is reshaped back to rank 3 before it is returned.
 *
 * @param x Input tensor of rank 3 or 4.
 * @param depthwiseFilter Rank-4 filter for the depthwise step.
 * @param pointwiseFilter Rank-4 filter for the pointwise step; its first two
 *     dimensions must be 1 and its third must equal
 *     `depthwiseFilter.shape[2] * depthwiseFilter.shape[3]`.
 * @param strides Strides for the depthwise step.
 * @param pad Padding for the depthwise step.
 * @param dilation Defaults to [1, 1]; forwarded to the depthwise step.
 * @param dataFormat Defaults to 'NHWC'; 'NCHW' throws an Error.
 * @returns The pointwise convolution result (rank 3 if the input was rank 3).
 */
ConvOps.separableConv2d = function (x, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
    if (dilation === void 0) { dilation = [1, 1]; }
    if (dataFormat === void 0) { dataFormat = 'NHWC'; }
    util.assertArgumentsAreTensors({ x: x, depthwiseFilter: depthwiseFilter, pointwiseFilter: pointwiseFilter }, 'separableConv2d');
    // Treat a rank-3 input as a batch containing a single example.
    var x4D = x;
    var reshapedTo4D = false;
    if (x.rank === 3) {
        reshapedTo4D = true;
        x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);
    }
    if (dataFormat === 'NCHW') {
        throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' +
            'NHWC is supported');
    }
    util.assert(x4D.rank === 4, "Error in separableConv2d: input must be rank 4, but got " +
        ("rank " + x4D.rank + "."));
    util.assert(depthwiseFilter.rank === 4, "Error in separableConv2d: depthwise filter must be rank 4, but got " +
        ("rank " + depthwiseFilter.rank + "."));
    // Fixed: the message used to report depthwiseFilter.rank, which hid the
    // offending rank of the pointwise filter.
    util.assert(pointwiseFilter.rank === 4, "Error in separableConv2d: pointwise filter must be rank 4, but got " +
        ("rank " + pointwiseFilter.rank + "."));
    util.assert(pointwiseFilter.shape[0] === 1, "Error in separableConv2d: the first dimension of pointwise filter " +
        (" must be 1, but got " + pointwiseFilter.shape[0] + "."));
    util.assert(pointwiseFilter.shape[1] === 1, "Error in separableConv2d: the second dimension of pointwise filter " +
        (" must be 1, but got " + pointwiseFilter.shape[1] + "."));
    // The pointwise step consumes every channel the depthwise step produces.
    var inChannels = depthwiseFilter.shape[2];
    var channelMultiplier = depthwiseFilter.shape[3];
    util.assert(pointwiseFilter.shape[2] === inChannels * channelMultiplier, "Error in separableConv2d: the third dimension of pointwise filter " +
        ("must be " + inChannels * channelMultiplier + ", ") +
        ("but got " + pointwiseFilter.shape[2] + "."));
    var depthwise = ConvOps.depthwiseConv2d(x4D, depthwiseFilter, strides, pad, dataFormat, dilation);
    var pointwiseStride = 1;
    var res = ConvOps.conv2d(depthwise, pointwiseFilter, pointwiseStride, 'valid', dataFormat);
    if (reshapedTo4D) {
        // Drop the batch dimension that was added above.
        return res.as3D(res.shape[1], res.shape[2], res.shape[3]);
    }
    return res;
};
// Apply decorators to each static op, in declaration order. Ops flagged as
// `documented` get the doc decorator (heading 'Operations', subheading
// 'Convolution') followed by `operation`; the two derivative ops get
// `operation` only.
var decoratedOps = [
    { name: 'conv1d', documented: true },
    { name: 'conv2d', documented: true },
    { name: 'conv2dDerInput', documented: false },
    { name: 'conv2dDerFilter', documented: false },
    { name: 'conv2dTranspose', documented: true },
    { name: 'depthwiseConv2d', documented: true },
    { name: 'separableConv2d', documented: true }
];
for (var opIndex = 0; opIndex < decoratedOps.length; opIndex++) {
    var opEntry = decoratedOps[opIndex];
    // A fresh doc decorator is created per op, exactly one call per documented op.
    var opDecorators = opEntry.documented ?
        [doc_1.doc({ heading: 'Operations', subheading: 'Convolution' }), operation_1.operation] :
        [operation_1.operation];
    __decorate(opDecorators, ConvOps, opEntry.name, null);
}
return ConvOps;
}());
exports.ConvOps = ConvOps;
/**
 * Normalizes a parameter that may be given as a single number or as a pair.
 *
 * @param param A number, or any other value (expected to be a pair).
 * @returns `[param, param]` when `param` is a number; otherwise `param`
 *     itself, unchanged.
 */
function parseTupleParam(param) {
    if (typeof param === 'number') {
        return [param, param];
    }
    return param;
}
/**
 * Checks whether both entries of a number-or-pair parameter are exactly 1.
 *
 * @param param A number or a pair, normalized with parseTupleParam.
 * @returns `true` only when the first two entries both strictly equal 1.
 */
function tupleValuesAreOne(param) {
    var pair = parseTupleParam(param);
    var first = pair[0];
    var second = pair[1];
    return !(first !== 1 || second !== 1);
}
/**
 * Checks the constraint shared by the convolution ops: at least one of
 * `strides` and `dilations` must consist solely of ones.
 *
 * @param strides A number or a pair.
 * @param dilations A number or a pair; only inspected when `strides` is not
 *     all ones.
 * @returns `true` when either argument passes tupleValuesAreOne.
 */
function eitherStridesOrDilationsAreOne(strides, dilations) {
    if (tupleValuesAreOne(strides)) {
        return true;
    }
    return tupleValuesAreOne(dilations);
}
//# sourceMappingURL=conv.js.map