@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
381 lines • 17.1 kB
JavaScript
"use strict";
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
var util = require("../util");
/**
 * Computes the information for a forward pass of a pooling2D operation.
 *
 * Builds a filter shape of [filterHeight, filterWidth, channels, channels],
 * where `channels` is read from the channel axis of `inShape`, and delegates
 * to computeConv2DInfo with `depthwise` set to false.
 *
 * @param inShape Input shape; index 3 holds the channels for 'channelsLast',
 *     index 1 for 'channelsFirst'.
 * @param filterSize A number or tuple, normalized by parseTupleParam.
 * @param strides Forwarded unchanged to computeConv2DInfo.
 * @param dilations Forwarded unchanged to computeConv2DInfo.
 * @param pad Forwarded unchanged to computeConv2DInfo.
 * @param roundingMode Forwarded unchanged to computeConv2DInfo.
 * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
 * @return The object produced by computeConv2DInfo.
 * @throws Error if `dataFormat` is neither of the two supported values.
 */
function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
    if (dataFormat === undefined) {
        dataFormat = 'channelsLast';
    }
    // Normalize the filter size before validating the data format, so that a
    // malformed filterSize fails first.
    var filterDims = parseTupleParam(filterSize);
    var poolHeight = filterDims[0];
    var poolWidth = filterDims[1];
    var channelAxis;
    switch (dataFormat) {
        case 'channelsLast':
            channelAxis = 3;
            break;
        case 'channelsFirst':
            channelAxis = 1;
            break;
        default:
            throw new Error("Unknown dataFormat " + dataFormat);
    }
    var channels = inShape[channelAxis];
    var filterShape = [poolHeight, poolWidth, channels, channels];
    return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
}
exports.computePool2DInfo = computePool2DInfo;
/**
 * Computes the information for a forward pass of a pooling3D operation.
 *
 * Maps the 3D data format names onto the 'channelsLast'/'channelsFirst'
 * names used by computeConv3DInfo, builds a filter shape of
 * [filterDepth, filterHeight, filterWidth, channels, channels], and delegates
 * to computeConv3DInfo with `depthwise` set to false.
 *
 * @param inShape Input shape; index 4 holds the channels for 'NDHWC',
 *     index 1 for 'NCDHW'.
 * @param filterSize A number or 3-tuple, normalized by parse3TupleParam.
 * @param strides Forwarded unchanged to computeConv3DInfo.
 * @param dilations Forwarded unchanged to computeConv3DInfo.
 * @param pad Forwarded unchanged to computeConv3DInfo.
 * @param roundingMode Forwarded unchanged to computeConv3DInfo.
 * @param dataFormat 'NDHWC' (default) or 'NCDHW'.
 * @return The object produced by computeConv3DInfo.
 * @throws Error if `dataFormat` is neither of the two supported values.
 */
function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
    if (dataFormat === undefined) {
        dataFormat = 'NDHWC';
    }
    var filterDims = parse3TupleParam(filterSize);
    var poolDepth = filterDims[0];
    var poolHeight = filterDims[1];
    var poolWidth = filterDims[2];
    var channelAxis;
    var convDataFormat;
    switch (dataFormat) {
        case 'NDHWC':
            convDataFormat = 'channelsLast';
            channelAxis = 4;
            break;
        case 'NCDHW':
            convDataFormat = 'channelsFirst';
            channelAxis = 1;
            break;
        default:
            throw new Error("Unknown dataFormat " + dataFormat);
    }
    var channels = inShape[channelAxis];
    var filterShape = [poolDepth, poolHeight, poolWidth, channels, channels];
    return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, convDataFormat, roundingMode);
}
exports.computePool3DInfo = computePool3DInfo;
/**
 * Computes the information for a forward pass of a convolution/pooling
 * operation.
 *
 * @param inShape Input shape, read as [batch, height, width, channels] for
 *     'channelsLast' and [batch, channels, height, width] for 'channelsFirst'.
 * @param filterShape Filter shape; indices 0 and 1 are read as the filter
 *     height and width, index 3 as the filter channel count.
 * @param strides A number or tuple, normalized by parseTupleParam.
 * @param dilations A number or tuple, normalized by parseTupleParam.
 * @param pad Padding specification forwarded to getPadAndOutInfo.
 * @param roundingMode Rounding mode forwarded to getPadAndOutInfo.
 * @param depthwise When true, outChannels is filterChannels * inChannels;
 *     otherwise it is filterChannels. Defaults to false.
 * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
 * @return An object describing the input, filter, stride, dilation, padding
 *     and output dimensions of the operation.
 * @throws Error if `dataFormat` is neither of the two supported values.
 */
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) {
    if (depthwise === undefined) {
        depthwise = false;
    }
    if (dataFormat === undefined) {
        dataFormat = 'channelsLast';
    }
    var isChannelsFirst = dataFormat === 'channelsFirst';
    if (!isChannelsFirst && dataFormat !== 'channelsLast') {
        throw new Error("Unknown dataFormat " + dataFormat);
    }
    // Pick the axis positions that match the layout.
    var batchSize = inShape[0];
    var inChannels = inShape[isChannelsFirst ? 1 : 3];
    var inHeight = inShape[isChannelsFirst ? 2 : 1];
    var inWidth = inShape[isChannelsFirst ? 3 : 2];
    var filterHeight = filterShape[0];
    var filterWidth = filterShape[1];
    var filterChannels = filterShape[3];
    var strideDims = parseTupleParam(strides);
    var strideHeight = strideDims[0];
    var strideWidth = strideDims[1];
    var dilationDims = parseTupleParam(dilations);
    var dilationHeight = dilationDims[0];
    var dilationWidth = dilationDims[1];
    var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
    var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
    // Padding and output extents are computed from the dilated filter size.
    var padAndOut = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
    var outHeight = padAndOut.outHeight;
    var outWidth = padAndOut.outWidth;
    var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    var outShape = isChannelsFirst ?
        [batchSize, outChannels, outHeight, outWidth] :
        [batchSize, outHeight, outWidth, outChannels];
    return {
        batchSize: batchSize,
        dataFormat: dataFormat,
        inHeight: inHeight,
        inWidth: inWidth,
        inChannels: inChannels,
        outHeight: outHeight,
        outWidth: outWidth,
        outChannels: outChannels,
        padInfo: padAndOut.padInfo,
        strideHeight: strideHeight,
        strideWidth: strideWidth,
        filterHeight: filterHeight,
        filterWidth: filterWidth,
        effectiveFilterHeight: effectiveFilterHeight,
        effectiveFilterWidth: effectiveFilterWidth,
        dilationHeight: dilationHeight,
        dilationWidth: dilationWidth,
        inShape: inShape,
        outShape: outShape,
        filterShape: filterShape
    };
}
exports.computeConv2DInfo = computeConv2DInfo;
/**
 * Computes the information for a forward pass of a 3D convolution/pooling
 * operation.
 *
 * @param inShape Input shape, read as [batch, depth, height, width, channels]
 *     for 'channelsLast' and [batch, channels, depth, height, width] for
 *     'channelsFirst'.
 * @param filterShape Filter shape; indices 0, 1 and 2 are read as the filter
 *     depth, height and width, index 4 as the filter channel count.
 * @param strides A number or 3-tuple, normalized by parse3TupleParam.
 * @param dilations A number or 3-tuple, normalized by parse3TupleParam.
 * @param pad Padding specification forwarded to get3DPadAndOutInfo.
 * @param depthwise When true, outChannels is filterChannels * inChannels;
 *     otherwise it is filterChannels. Defaults to false.
 * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
 * @param roundingMode Rounding mode forwarded to get3DPadAndOutInfo.
 * @return An object describing the input, filter, stride, dilation, padding
 *     and output dimensions of the operation.
 * @throws Error if `dataFormat` is neither of the two supported values.
 */
function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) {
    if (depthwise === undefined) {
        depthwise = false;
    }
    if (dataFormat === undefined) {
        dataFormat = 'channelsLast';
    }
    var isChannelsFirst = dataFormat === 'channelsFirst';
    if (!isChannelsFirst && dataFormat !== 'channelsLast') {
        throw new Error("Unknown dataFormat " + dataFormat);
    }
    // Pick the axis positions that match the layout.
    var batchSize = inShape[0];
    var inChannels = inShape[isChannelsFirst ? 1 : 4];
    var inDepth = inShape[isChannelsFirst ? 2 : 1];
    var inHeight = inShape[isChannelsFirst ? 3 : 2];
    var inWidth = inShape[isChannelsFirst ? 4 : 3];
    var filterDepth = filterShape[0];
    var filterHeight = filterShape[1];
    var filterWidth = filterShape[2];
    var filterChannels = filterShape[4];
    var strideDims = parse3TupleParam(strides);
    var strideDepth = strideDims[0];
    var strideHeight = strideDims[1];
    var strideWidth = strideDims[2];
    var dilationDims = parse3TupleParam(dilations);
    var dilationDepth = dilationDims[0];
    var dilationHeight = dilationDims[1];
    var dilationWidth = dilationDims[2];
    var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
    var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
    var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
    // Padding and output extents are computed from the dilated filter size.
    var padAndOut = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
    var outDepth = padAndOut.outDepth;
    var outHeight = padAndOut.outHeight;
    var outWidth = padAndOut.outWidth;
    var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    var outShape = isChannelsFirst ?
        [batchSize, outChannels, outDepth, outHeight, outWidth] :
        [batchSize, outDepth, outHeight, outWidth, outChannels];
    return {
        batchSize: batchSize,
        dataFormat: dataFormat,
        inDepth: inDepth,
        inHeight: inHeight,
        inWidth: inWidth,
        inChannels: inChannels,
        outDepth: outDepth,
        outHeight: outHeight,
        outWidth: outWidth,
        outChannels: outChannels,
        padInfo: padAndOut.padInfo,
        strideDepth: strideDepth,
        strideHeight: strideHeight,
        strideWidth: strideWidth,
        filterDepth: filterDepth,
        filterHeight: filterHeight,
        filterWidth: filterWidth,
        effectiveFilterDepth: effectiveFilterDepth,
        effectiveFilterHeight: effectiveFilterHeight,
        effectiveFilterWidth: effectiveFilterWidth,
        dilationDepth: dilationDepth,
        dilationHeight: dilationHeight,
        dilationWidth: dilationWidth,
        inShape: inShape,
        outShape: outShape,
        filterShape: filterShape
    };
}
exports.computeConv3DInfo = computeConv3DInfo;
/**
 * Computes the [rows, columns] output extents for a 2D window operation with
 * explicit zero padding.
 *
 * Each extent is (input - fieldSize + 2 * zeroPad) / stride + 1, passed
 * through conditionalRound, and must be an integer. The same `fieldSize`
 * and `stride` are applied to both axes.
 *
 * @param inShape Input extents; index 0 is rows, index 1 is columns.
 * @param fieldSize Window extent used for both axes.
 * @param stride Stride used for both axes.
 * @param zeroPad Padding per side; when null/undefined it is obtained from
 *     computeDefaultPad.
 * @param roundingMode Optional rounding mode passed to conditionalRound.
 * @return [outputRows, outputCols].
 */
function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, fieldSize, stride);
    }
    // Shared per-axis computation; `messagePrefix` names the axis in the
    // assertion message.
    var outputExtent = function (inputExtent, messagePrefix) {
        var extent = conditionalRound((inputExtent - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
        util.assert(util.isInt(extent), function () {
            return messagePrefix + extent + ") must be an integer. " +
                "Change the stride and/or zero pad parameters";
        });
        return extent;
    };
    var outputRows = outputExtent(inShape[0], "The output # of rows (");
    var outputCols = outputExtent(inShape[1], "The output # of columns (");
    return [outputRows, outputCols];
}
/**
 * Computes the [depths, rows, columns, outChannels] output shape for a 3D
 * window operation with explicit zero padding.
 *
 * Each spatial extent is (input - fieldSize + 2 * zeroPad) / stride + 1,
 * passed through conditionalRound, and must be an integer. The same
 * `fieldSize` and `stride` are applied to all three spatial axes.
 *
 * @param inShape Input extents; indices 0, 1, 2 are depth, rows, columns.
 * @param fieldSize Window extent used for every spatial axis.
 * @param outChannels Value copied into the last slot of the result.
 * @param stride Stride used for every spatial axis.
 * @param zeroPad Padding per side; when null/undefined it is obtained from
 *     computeDefaultPad.
 * @param roundingMode Optional rounding mode passed to conditionalRound.
 * @return [outputDepths, outputRows, outputCols, outChannels].
 */
function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, fieldSize, stride);
    }
    // Shared per-axis computation; `messagePrefix` names the axis in the
    // assertion message.
    var outputExtent = function (inputExtent, messagePrefix) {
        var extent = conditionalRound((inputExtent - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
        util.assert(util.isInt(extent), function () {
            return messagePrefix + extent + ") must be an integer. " +
                "Change the stride and/or zero pad parameters";
        });
        return extent;
    };
    var outputDepths = outputExtent(inShape[0], "The output # of depths (");
    var outputRows = outputExtent(inShape[1], "The output # of rows (");
    var outputCols = outputExtent(inShape[2], "The output # of columns (");
    return [outputDepths, outputRows, outputCols, outChannels];
}
/**
 * Computes a default per-side padding from the first input extent.
 *
 * The result is
 *   floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2)
 * where effectiveFieldSize is `fieldSize` adjusted for `dilation` by
 * getEffectiveFilterSize.
 *
 * @param inputShape Input shape; only index 0 is read.
 * @param fieldSize Window extent before dilation.
 * @param stride Stride of the operation.
 * @param dilation Dilation rate, defaults to 1.
 * @return The padding as an integer.
 */
function computeDefaultPad(inputShape, fieldSize, stride, dilation) {
    if (dilation === undefined) {
        dilation = 1;
    }
    var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
    var totalPad = inputShape[0] * (stride - 1) - stride + effectiveFieldSize;
    return Math.floor(totalPad / 2);
}
exports.computeDefaultPad = computeDefaultPad;
/**
 * Normalizes a scalar-or-tuple parameter.
 *
 * - a number `n` becomes [n, n, n];
 * - a 2-element tuple [a, b] becomes [a, b, 1];
 * - anything else is returned as-is (same reference).
 *
 * @param param A number or an array-like value.
 * @return The normalized tuple.
 */
function parseTupleParam(param) {
    if (typeof param === 'number') {
        return [param, param, param];
    }
    return param.length === 2 ? [param[0], param[1], 1] : param;
}
/**
 * Normalizes a scalar-or-3-tuple parameter: a number `n` becomes [n, n, n],
 * any other value is returned as-is (same reference).
 *
 * @param param A number or a 3-tuple.
 * @return The normalized 3-tuple.
 */
function parse3TupleParam(param) {
    if (typeof param !== 'number') {
        return param;
    }
    return [param, param, param];
}
/* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
 * A dilated (atrous) convolution is equivalent to a standard convolution
 * whose filter has been upsampled by inserting dilation - 1 zeros between
 * consecutive elements along each spatial dimension, giving
 *   effective_filter_size =
 *       filter_size + (filter_size - 1) * (dilation - 1).
 * This function converts a filter dimension to that effective dimension so
 * it can be used in standard convolution arithmetic. A dilation of 1 or
 * less leaves the filter size unchanged.
 */
function getEffectiveFilterSize(filterSize, dilation) {
    var insertedPerGap = dilation - 1;
    return dilation <= 1 ?
        filterSize :
        filterSize + (filterSize - 1) * insertedPerGap;
}
/**
 * Computes the padding description and the output height/width for a 2D
 * window operation.
 *
 * - numeric `pad`: the same padding is applied on all four sides and the
 *   output extents come from computeOutputShape2D. The type is 'VALID' when
 *   pad is 0 and 'NUMBER' otherwise.
 * - 'same': output extents are ceil(input / stride); the total padding per
 *   axis is clamped at 0 and split so that any odd remainder goes to the
 *   bottom/right side.
 * - 'valid': no padding; output extents are
 *   ceil((input - filter + 1) / stride).
 *
 * NOTE(review): the numeric path forwards only filterHeight and strideHeight
 * to computeOutputShape2D, so outWidth is derived from the height filter
 * extent and stride. Confirm whether callers may combine numeric padding with
 * non-square filters or strides.
 *
 * @param pad A number, 'same' or 'valid'.
 * @param inHeight Input height.
 * @param inWidth Input width.
 * @param strideHeight Stride along the height axis.
 * @param strideWidth Stride along the width axis.
 * @param filterHeight Filter height used in the arithmetic.
 * @param filterWidth Filter width used in the arithmetic.
 * @param roundingMode Rounding mode, used only on the numeric path.
 * @return { padInfo, outHeight, outWidth }.
 * @throws Error for any other `pad` value.
 */
function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode) {
    if (typeof pad === 'number') {
        var explicitShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
        return {
            padInfo: {
                top: pad,
                bottom: pad,
                left: pad,
                right: pad,
                type: pad === 0 ? 'VALID' : 'NUMBER'
            },
            outHeight: explicitShape[0],
            outWidth: explicitShape[1]
        };
    }
    switch (pad) {
        case 'same': {
            var sameHeight = Math.ceil(inHeight / strideHeight);
            var sameWidth = Math.ceil(inWidth / strideWidth);
            var totalPadHeight = Math.max(0, (sameHeight - 1) * strideHeight + filterHeight - inHeight);
            var totalPadWidth = Math.max(0, (sameWidth - 1) * strideWidth + filterWidth - inWidth);
            // The leading side gets the smaller half.
            var padTop = Math.floor(totalPadHeight / 2);
            var padLeft = Math.floor(totalPadWidth / 2);
            return {
                padInfo: {
                    top: padTop,
                    bottom: totalPadHeight - padTop,
                    left: padLeft,
                    right: totalPadWidth - padLeft,
                    type: 'SAME'
                },
                outHeight: sameHeight,
                outWidth: sameWidth
            };
        }
        case 'valid':
            return {
                padInfo: { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' },
                outHeight: Math.ceil((inHeight - filterHeight + 1) / strideHeight),
                outWidth: Math.ceil((inWidth - filterWidth + 1) / strideWidth)
            };
        default:
            throw new Error("Unknown padding parameter: " + pad);
    }
}
/**
 * Computes the padding description and the output depth/height/width for a
 * 3D window operation.
 *
 * - numeric `pad`: the same padding is applied on all six sides and the
 *   output extents come from computeOutputShape4D. The type is 'VALID' when
 *   pad is 0 and 'NUMBER' otherwise.
 * - 'same': output extents are ceil(input / stride); the total padding per
 *   axis is clamped at 0 (matching the 2D helper getPadAndOutInfo) and split
 *   so that any odd remainder goes to the back/bottom/right side.
 * - 'valid': no padding; output extents are
 *   ceil((input - filter + 1) / stride).
 *
 * NOTE(review): the numeric path forwards only filterDepth and strideDepth
 * to computeOutputShape4D, so outHeight and outWidth are derived from the
 * depth filter extent and stride. Confirm whether callers may combine numeric
 * padding with non-cubic filters or strides.
 *
 * @param pad A number, 'same' or 'valid'.
 * @param inDepth Input depth.
 * @param inHeight Input height.
 * @param inWidth Input width.
 * @param strideDepth Stride along the depth axis.
 * @param strideHeight Stride along the height axis.
 * @param strideWidth Stride along the width axis.
 * @param filterDepth Filter depth used in the arithmetic.
 * @param filterHeight Filter height used in the arithmetic.
 * @param filterWidth Filter width used in the arithmetic.
 * @param roundingMode Rounding mode, used only on the numeric path.
 * @return { padInfo, outDepth, outHeight, outWidth }.
 * @throws Error for any other `pad` value.
 */
function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
    var padInfo;
    var outDepth;
    var outHeight;
    var outWidth;
    if (typeof pad === 'number') {
        var padType = (pad === 0) ? 'VALID' : 'NUMBER';
        padInfo = {
            top: pad,
            bottom: pad,
            left: pad,
            right: pad,
            front: pad,
            back: pad,
            type: padType
        };
        var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode);
        outDepth = outShape[0];
        outHeight = outShape[1];
        outWidth = outShape[2];
    }
    else if (pad === 'same') {
        outDepth = Math.ceil(inDepth / strideDepth);
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        // When the stride exceeds the filter extent the raw amount can be
        // negative; padding can never be negative, so clamp at 0.
        var padAlongDepth = Math.max(0, (outDepth - 1) * strideDepth + filterDepth - inDepth);
        var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
        var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
        // The leading side gets the smaller half.
        var front = Math.floor(padAlongDepth / 2);
        var back = padAlongDepth - front;
        var top_2 = Math.floor(padAlongHeight / 2);
        var bottom = padAlongHeight - top_2;
        var left = Math.floor(padAlongWidth / 2);
        var right = padAlongWidth - left;
        padInfo = { top: top_2, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' };
    }
    else if (pad === 'valid') {
        padInfo = {
            top: 0,
            bottom: 0,
            left: 0,
            right: 0,
            front: 0,
            back: 0,
            type: 'VALID'
        };
        outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
        outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
        outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
    }
    else {
        throw new Error("Unknown padding parameter: " + pad);
    }
    return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth };
}
/**
 * Rounds a value depending on the rounding mode.
 *
 * @param value The number to round.
 * @param roundingMode One of 'round', 'ceil' or 'floor'. When falsy, the
 *     value is returned unchanged.
 * @return The rounded (or untouched) value.
 * @throws Error for any other truthy roundingMode.
 */
function conditionalRound(value, roundingMode) {
    if (!roundingMode) {
        return value;
    }
    if (roundingMode === 'round') {
        // used for Caffe Conv
        return Math.round(value);
    }
    if (roundingMode === 'ceil') {
        // used for Caffe Pool
        return Math.ceil(value);
    }
    if (roundingMode === 'floor') {
        return Math.floor(value);
    }
    throw new Error("Unknown roundingMode " + roundingMode);
}
/**
 * Returns true when the first three values of `param`, after normalization
 * by parseTupleParam, are all strictly equal to 1.
 *
 * @param param A number or tuple.
 * @return Whether the first three normalized values are all 1.
 */
function tupleValuesAreOne(param) {
    var dims = parseTupleParam(param);
    var first = dims[0];
    var second = dims[1];
    var third = dims[2];
    if (first !== 1) {
        return false;
    }
    if (second !== 1) {
        return false;
    }
    return third === 1;
}
exports.tupleValuesAreOne = tupleValuesAreOne;
/**
 * Returns true when either `strides` or `dilations` consists solely of ones,
 * as determined by tupleValuesAreOne. `dilations` is only inspected when
 * `strides` is not all ones.
 *
 * @param strides A number or tuple of strides.
 * @param dilations A number or tuple of dilation rates.
 * @return Whether at least one of the two is all ones.
 */
function eitherStridesOrDilationsAreOne(strides, dilations) {
    if (tupleValuesAreOne(strides)) {
        return true;
    }
    return tupleValuesAreOne(dilations);
}
exports.eitherStridesOrDilationsAreOne = eitherStridesOrDilationsAreOne;
/**
 * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
 * 'channelsLast'|'channelsFirst'.
 *
 * @param dataFormat in 'NHWC'|'NCHW' mode
 * @return dataFormat in 'channelsLast'|'channelsFirst' mode
 * @throws Error if dataFormat is neither 'NHWC' nor 'NCHW'
 */
function convertConv2DDataFormat(dataFormat) {
    switch (dataFormat) {
        case 'NHWC':
            return 'channelsLast';
        case 'NCHW':
            return 'channelsFirst';
        default:
            throw new Error("Unknown dataFormat " + dataFormat);
    }
}
exports.convertConv2DDataFormat = convertConv2DDataFormat;
//# sourceMappingURL=conv_util.js.map