UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

253 lines 11.7 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var util = require("../util"); function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) { if (dataFormat === void 0) { dataFormat = 'channelsLast'; } var _a = parseTupleParam(filterSize), filterHeight = _a[0], filterWidth = _a[1]; var filterShape; if (dataFormat === 'channelsLast') { filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]]; } else if (dataFormat === 'channelsFirst') { filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]]; } else { throw new Error("Unknown dataFormat " + dataFormat); } return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat); } exports.computePool2DInfo = computePool2DInfo; function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) { if (depthwise === void 0) { depthwise = false; } if (dataFormat === void 0) { dataFormat = 'channelsLast'; } var _a = [-1, -1, -1, -1], batchSize = _a[0], inHeight = _a[1], inWidth = _a[2], inChannels = _a[3]; if (dataFormat === 'channelsLast') { batchSize = inShape[0], inHeight = inShape[1], inWidth = inShape[2], inChannels = inShape[3]; } else if (dataFormat === 'channelsFirst') { batchSize = inShape[0], inChannels = inShape[1], inHeight = inShape[2], inWidth = inShape[3]; } else { throw new Error("Unknown dataFormat " + dataFormat); } var filterHeight = filterShape[0], filterWidth = filterShape[1], filterChannels = filterShape[3]; var _b = parseTupleParam(strides), strideHeight = _b[0], strideWidth = _b[1]; var _c = parseTupleParam(dilations), dilationHeight = _c[0], dilationWidth = _c[1]; var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); var _d = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outHeight = _d.outHeight, outWidth = _d.outWidth; var outChannels = depthwise ? filterChannels * inChannels : filterChannels; var outShape; if (dataFormat === 'channelsFirst') { outShape = [batchSize, outChannels, outHeight, outWidth]; } else if (dataFormat === 'channelsLast') { outShape = [batchSize, outHeight, outWidth, outChannels]; } return { batchSize: batchSize, dataFormat: dataFormat, inHeight: inHeight, inWidth: inWidth, inChannels: inChannels, outHeight: outHeight, outWidth: outWidth, outChannels: outChannels, padInfo: padInfo, strideHeight: strideHeight, strideWidth: strideWidth, filterHeight: filterHeight, filterWidth: filterWidth, effectiveFilterHeight: effectiveFilterHeight, effectiveFilterWidth: effectiveFilterWidth, dilationHeight: dilationHeight, dilationWidth: dilationWidth, inShape: inShape, outShape: outShape, filterShape: filterShape }; } exports.computeConv2DInfo = computeConv2DInfo; function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat) { if (depthwise === void 0) { depthwise = false; } if (dataFormat === void 0) { dataFormat = 'channelsLast'; } var _a = [-1, -1, -1, -1, -1], batchSize = _a[0], inDepth = _a[1], inHeight = _a[2], inWidth = _a[3], inChannels = _a[4]; if (dataFormat === 'channelsLast') { batchSize = inShape[0], inDepth = inShape[1], inHeight = inShape[2], inWidth = inShape[3], inChannels = inShape[4]; } else if (dataFormat === 'channelsFirst') { batchSize = inShape[0], inChannels = inShape[1], inDepth = inShape[2], inHeight = inShape[3], inWidth = inShape[4]; } else { throw new Error("Unknown dataFormat " + dataFormat); } var filterDepth = filterShape[0], filterHeight = filterShape[1], filterWidth = filterShape[2], filterChannels = filterShape[4]; var _b = parse3TupleParam(strides), strideDepth = _b[0], strideHeight = _b[1], strideWidth = _b[2]; var _c = parse3TupleParam(dilations), dilationDepth = _c[0], dilationHeight = _c[1], dilationWidth = _c[2]; var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth); var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); var _d = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth), padInfo = _d.padInfo, outDepth = _d.outDepth, outHeight = _d.outHeight, outWidth = _d.outWidth; var outChannels = depthwise ? filterChannels * inChannels : filterChannels; var outShape; if (dataFormat === 'channelsFirst') { outShape = [batchSize, outChannels, outDepth, outHeight, outWidth]; } else if (dataFormat === 'channelsLast') { outShape = [batchSize, outDepth, outHeight, outWidth, outChannels]; } return { batchSize: batchSize, dataFormat: dataFormat, inDepth: inDepth, inHeight: inHeight, inWidth: inWidth, inChannels: inChannels, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth, outChannels: outChannels, padInfo: padInfo, strideDepth: strideDepth, strideHeight: strideHeight, strideWidth: strideWidth, filterDepth: filterDepth, filterHeight: filterHeight, filterWidth: filterWidth, dilationDepth: dilationDepth, dilationHeight: dilationHeight, dilationWidth: dilationWidth, inShape: inShape, outShape: outShape, filterShape: filterShape }; } exports.computeConv3DInfo = computeConv3DInfo; function computeOutputShape3D(inShape, fieldSize, outDepth, stride, zeroPad, roundingMode) { if (zeroPad == null) { zeroPad = computeDefaultPad(inShape, fieldSize, stride); } var inputRows = inShape[0]; var inputCols = inShape[1]; var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); util.assert(util.isInt(outputRows), "The output # of rows (" + outputRows + ") must be an integer. Change the " + "stride and/or zero pad parameters"); var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); util.assert(util.isInt(outputCols), "The output # of columns (" + outputCols + ") must be an integer. Change " + "the stride and/or zero pad parameters"); return [outputRows, outputCols, outDepth]; } function computeDefaultPad(inputShape, fieldSize, stride, dilation) { if (dilation === void 0) { dilation = 1; } var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation); return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2); } exports.computeDefaultPad = computeDefaultPad; function parseTupleParam(param) { return typeof param === 'number' ? [param, param] : param; } function parse3TupleParam(param) { return typeof param === 'number' ? [param, param, param] : param; } function getEffectiveFilterSize(filterSize, dilation) { if (dilation <= 1) { return filterSize; } return filterSize + (filterSize - 1) * (dilation - 1); } function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode) { var padInfo; var outHeight; var outWidth; if (typeof pad === 'number') { var padType = (pad === 0) ? 'VALID' : 'NUMBER'; padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType }; var outShape = computeOutputShape3D([inHeight, inWidth, 1], filterHeight, 1, strideHeight, pad, roundingMode); outHeight = outShape[0]; outWidth = outShape[1]; } else if (pad === 'same') { outHeight = Math.ceil(inHeight / strideHeight); outWidth = Math.ceil(inWidth / strideWidth); var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; var top_1 = Math.floor(padAlongHeight / 2); var bottom = padAlongHeight - top_1; var left = Math.floor(padAlongWidth / 2); var right = padAlongWidth - left; padInfo = { top: top_1, bottom: bottom, left: left, right: right, type: 'SAME' }; } else if (pad === 'valid') { padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' }; outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); } else { throw Error("Unknown padding parameter: " + pad); } return { padInfo: padInfo, outHeight: outHeight, outWidth: outWidth }; } function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth) { var padInfo; var outDepth; var outHeight; var outWidth; if (pad === 'same') { outDepth = Math.ceil(inDepth / strideDepth); outHeight = Math.ceil(inHeight / strideHeight); outWidth = Math.ceil(inWidth / strideWidth); var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth; var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; var front = Math.floor(padAlongDepth / 2); var back = padAlongDepth - front; var top_2 = Math.floor(padAlongHeight / 2); var bottom = padAlongHeight - top_2; var left = Math.floor(padAlongWidth / 2); var right = padAlongWidth - left; padInfo = { top: top_2, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' }; } else if (pad === 'valid') { padInfo = { top: 0, bottom: 0, left: 0, right: 0, front: 0, back: 0, type: 'VALID' }; outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth); outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); } else { throw Error("Unknown padding parameter: " + pad); } return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth }; } function conditionalRound(value, roundingMode) { if (!roundingMode) { return value; } switch (roundingMode) { case 'round': return Math.round(value); case 'ceil': return Math.ceil(value); case 'floor': return Math.floor(value); default: throw new Error("Unknown roundingMode " + roundingMode); } } function tupleValuesAreOne(param) { var _a = parseTupleParam(param), dimA = _a[0], dimB = _a[1]; return dimA === 1 && dimB === 1; } exports.tupleValuesAreOne = tupleValuesAreOne; function eitherStridesOrDilationsAreOne(strides, dilations) { return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); } exports.eitherStridesOrDilationsAreOne = eitherStridesOrDilationsAreOne; //# sourceMappingURL=conv_util.js.map