/*
 * UNPKG
 *
 * @tensorflow/tfjs-core
 *
 * Version: (not recorded in this capture)
 *
 * Hardware-accelerated JavaScript library for machine intelligence
 *
 * 440 lines (402 loc) 14.5 kB
 */
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as util from '../util';

/**
 * Per-side padding of the two spatial dimensions of a 2D convolution/pooling.
 * `type` records how the padding was derived; within this module it is one
 * of 'SAME', 'VALID' or 'NUMBER' (an explicit scalar pad).
 */
export type PadInfo = {
  top: number,
  left: number,
  right: number,
  bottom: number,
  type: string
};

/**
 * Per-side padding of the three spatial dimensions of a 3D
 * convolution/pooling. Within this module `type` is 'SAME' or 'VALID'.
 */
export type PadInfo3D = {
  top: number,
  left: number,
  right: number,
  bottom: number,
  front: number,
  back: number,
  type: string
};

/**
 * Information about the forward pass of a convolution/pooling operation.
 * It includes input and output shape, strides, filter size and padding
 * information.
 */
export type Conv2DInfo = {
  batchSize: number,
  inHeight: number,
  inWidth: number,
  inChannels: number,
  outHeight: number,
  outWidth: number,
  outChannels: number,
  dataFormat: 'channelsFirst'|'channelsLast',
  strideHeight: number,
  strideWidth: number,
  dilationHeight: number,
  dilationWidth: number,
  filterHeight: number,
  filterWidth: number,
  effectiveFilterHeight: number,
  effectiveFilterWidth: number,
  padInfo: PadInfo,
  inShape: [number, number, number, number],
  outShape: [number, number, number, number],
  filterShape: [number, number, number, number]
};

/**
 * Computes the forward-pass information for a 2D pooling operation.
 *
 * The pooling window is modelled as a filter of shape
 * `[filterHeight, filterWidth, C, C]`, where `C` is the input channel count
 * (`inShape[3]` for 'channelsLast', `inShape[1]` for 'channelsFirst'), and
 * the work is delegated to `computeConv2DInfo` with `depthwise = false`.
 *
 * @param inShape Rank-4 input shape, laid out according to `dataFormat`.
 * @param filterSize Window size as `[height, width]`, or one number for both.
 * @param strides Stride as `[height, width]`, or one number for both.
 * @param dilations Dilation as `[height, width]`, or one number for both.
 * @param pad 'same', 'valid', or an explicit zero-padding amount per side.
 * @param roundingMode Rounding applied to the output size when `pad` is a
 *     number; when omitted that size must already be an integer.
 * @param dataFormat Layout of `inShape`; defaults to 'channelsLast'.
 * @returns The computed `Conv2DInfo`.
 * @throws Error if `dataFormat` is not recognized.
 */
export function computePool2DInfo(
    inShape: [number, number, number, number],
    filterSize: [number, number]|number, strides: number|[number, number],
    dilations: number|[number, number], pad: 'same'|'valid'|number,
    roundingMode?: 'floor'|'round'|'ceil',
    dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv2DInfo {
  const [filterHeight, filterWidth] = parseTupleParam(filterSize);

  let filterShape: [number, number, number, number];
  if (dataFormat === 'channelsLast') {
    filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
  } else if (dataFormat === 'channelsFirst') {
    filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
  } else {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }

  return computeConv2DInfo(
      inShape, filterShape, strides, dilations, pad, roundingMode, false,
      dataFormat);
}

/**
 * Computes the information for a forward pass of a convolution/pooling
 * operation.
 *
 * @param inShape Rank-4 input shape, laid out according to `dataFormat`.
 * @param filterShape Rank-4 filter shape. Entries 0 and 1 are the filter
 *     height and width. Entry 3 is the output channel count, or the
 *     per-input-channel multiplier when `depthwise` is true. Entry 2 is not
 *     read here.
 * @param strides Stride as `[height, width]`, or one number for both.
 * @param dilations Dilation as `[height, width]`, or one number for both.
 * @param pad 'same', 'valid', or an explicit zero-padding amount per side.
 * @param roundingMode Rounding applied to the output size when `pad` is a
 *     number; when omitted that size must already be an integer.
 * @param depthwise When true, `outChannels = filterShape[3] * inChannels`.
 * @param dataFormat Layout of `inShape` and of the returned `outShape`.
 * @returns The computed `Conv2DInfo`.
 * @throws Error if `dataFormat` or `pad` is not recognized. Also fails the
 *     `util.assert` check when a numeric `pad` yields a non-integer output
 *     size and no `roundingMode` is given.
 */
export function computeConv2DInfo(
    inShape: [number, number, number, number],
    filterShape: [number, number, number, number],
    strides: number|[number, number], dilations: number|[number, number],
    pad: 'same'|'valid'|number, roundingMode?: 'floor'|'round'|'ceil',
    depthwise = false,
    dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv2DInfo {
  let [batchSize, inHeight, inWidth, inChannels] = [-1, -1, -1, -1];
  if (dataFormat === 'channelsLast') {
    [batchSize, inHeight, inWidth, inChannels] = inShape;
  } else if (dataFormat === 'channelsFirst') {
    [batchSize, inChannels, inHeight, inWidth] = inShape;
  } else {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }

  const [filterHeight, filterWidth, , filterChannels] = filterShape;
  const [strideHeight, strideWidth] = parseTupleParam(strides);
  const [dilationHeight, dilationWidth] = parseTupleParam(dilations);

  // Padding and output size are computed with the dilated filter extent.
  const effectiveFilterHeight =
      getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth =
      getEffectiveFilterSize(filterWidth, dilationWidth);
  const {padInfo, outHeight, outWidth} = getPadAndOutInfo(
      pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight,
      effectiveFilterWidth, roundingMode);

  const outChannels = depthwise ? filterChannels * inChannels : filterChannels;

  // dataFormat was validated above, so the fallback branch is 'channelsLast'.
  const outShape: [number, number, number, number] =
      dataFormat === 'channelsFirst' ?
      [batchSize, outChannels, outHeight, outWidth] :
      [batchSize, outHeight, outWidth, outChannels];

  return {
    batchSize,
    dataFormat,
    inHeight,
    inWidth,
    inChannels,
    outHeight,
    outWidth,
    outChannels,
    padInfo,
    strideHeight,
    strideWidth,
    filterHeight,
    filterWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}

/**
 * Information about the forward pass of a 3D convolution/pooling operation.
 * It includes input and output shape, strides, filter size and padding
 * information.
 */
export type Conv3DInfo = {
  batchSize: number,
  inDepth: number,
  inHeight: number,
  inWidth: number,
  inChannels: number,
  outDepth: number,
  outHeight: number,
  outWidth: number,
  outChannels: number,
  dataFormat: 'channelsFirst'|'channelsLast',
  strideDepth: number,
  strideHeight: number,
  strideWidth: number,
  dilationDepth: number,
  dilationHeight: number,
  dilationWidth: number,
  filterDepth: number,
  filterHeight: number,
  filterWidth: number,
  padInfo: PadInfo3D,
  inShape: [number, number, number, number, number],
  outShape: [number, number, number, number, number],
  filterShape: [number, number, number, number, number]
};

/**
 * Computes the information for a forward pass of a 3D convolution/pooling
 * operation.
 *
 * @param inShape Rank-5 input shape, laid out according to `dataFormat`.
 * @param filterShape Rank-5 filter shape. Entries 0-2 are the filter depth,
 *     height and width. Entry 4 is the output channel count, or the
 *     per-input-channel multiplier when `depthwise` is true. Entry 3 is not
 *     read here.
 * @param strides Stride as `[depth, height, width]`, or one number for all.
 * @param dilations Dilation as `[depth, height, width]`, or one number for
 *     all.
 * @param pad 'same' or 'valid'.
 * @param depthwise When true, `outChannels = filterShape[4] * inChannels`.
 * @param dataFormat Layout of `inShape` and of the returned `outShape`.
 * @returns The computed `Conv3DInfo`.
 * @throws Error if `dataFormat` or `pad` is not recognized.
 */
export function computeConv3DInfo(
    inShape: [number, number, number, number, number],
    filterShape: [number, number, number, number, number],
    strides: number|[number, number, number],
    dilations: number|[number, number, number], pad: 'same'|'valid',
    depthwise = false,
    dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv3DInfo {
  let [batchSize, inDepth, inHeight, inWidth, inChannels] =
      [-1, -1, -1, -1, -1];
  if (dataFormat === 'channelsLast') {
    [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape;
  } else if (dataFormat === 'channelsFirst') {
    [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape;
  } else {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }

  const [filterDepth, filterHeight, filterWidth, , filterChannels] =
      filterShape;
  const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
  const [dilationDepth, dilationHeight, dilationWidth] =
      parse3TupleParam(dilations);

  // Padding and output size are computed with the dilated filter extent.
  const effectiveFilterDepth =
      getEffectiveFilterSize(filterDepth, dilationDepth);
  const effectiveFilterHeight =
      getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth =
      getEffectiveFilterSize(filterWidth, dilationWidth);
  const {padInfo, outDepth, outHeight, outWidth} = get3DPadAndOutInfo(
      pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth,
      effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth);

  const outChannels = depthwise ? filterChannels * inChannels : filterChannels;

  // dataFormat was validated above, so the fallback branch is 'channelsLast'.
  const outShape: [number, number, number, number, number] =
      dataFormat === 'channelsFirst' ?
      [batchSize, outChannels, outDepth, outHeight, outWidth] :
      [batchSize, outDepth, outHeight, outWidth, outChannels];

  return {
    batchSize,
    dataFormat,
    inDepth,
    inHeight,
    inWidth,
    inChannels,
    outDepth,
    outHeight,
    outWidth,
    outChannels,
    padInfo,
    strideDepth,
    strideHeight,
    strideWidth,
    filterDepth,
    filterHeight,
    filterWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}

/**
 * Computes `[rows, cols]` of a 2D convolution output from an explicit,
 * symmetric zero padding. Each axis uses
 * `(in - field + 2 * pad) / stride + 1`.
 *
 * @param inShape `[rows, cols]` of the input.
 * @param fieldSize Filter extent as `[rows, cols]`, or one number for both.
 * @param stride Stride as `[rows, cols]`, or one number for both.
 * @param zeroPad Padding per side. When omitted, `computeDefaultPad` is
 *     applied using the row axis, and the result is used for both axes.
 * @param roundingMode Rounding applied to each output size. When omitted,
 *     each size must already be an integer.
 * @returns `[outputRows, outputCols]`.
 */
function computeOutputShape2D(
    inShape: [number, number], fieldSize: number|[number, number],
    stride: number|[number, number], zeroPad?: number,
    roundingMode?: 'floor'|'round'|'ceil'): [number, number] {
  const [fieldRows, fieldCols] = parseTupleParam(fieldSize);
  const [strideRows, strideCols] = parseTupleParam(stride);
  if (zeroPad == null) {
    zeroPad = computeDefaultPad(inShape, fieldRows, strideRows);
  }
  const [inputRows, inputCols] = inShape;

  const outputRows = conditionalRound(
      (inputRows - fieldRows + 2 * zeroPad) / strideRows + 1, roundingMode);
  util.assert(
      util.isInt(outputRows),
      () => `The output # of rows (${outputRows}) must be an integer. ` +
          `Change the stride and/or zero pad parameters`);

  // Columns use their own field size and stride so that non-square filters
  // and strides are handled correctly.
  const outputCols = conditionalRound(
      (inputCols - fieldCols + 2 * zeroPad) / strideCols + 1, roundingMode);
  util.assert(
      util.isInt(outputCols),
      () => `The output # of columns (${outputCols}) must be an integer. ` +
          `Change the stride and/or zero pad parameters`);

  return [outputRows, outputCols];
}

/**
 * Computes a per-side zero padding for which `(in - k + 2 * pad) / s + 1`
 * equals `in` along the first input dimension, where `k` is the dilated
 * field size. This holds exactly when the halving below is exact; otherwise
 * the pad is floored.
 *
 * @param inputShape Spatial input shape. Only `inputShape[0]` is read.
 * @param fieldSize Filter extent before dilation.
 * @param stride Convolution stride.
 * @param dilation Dilation rate; defaults to 1.
 * @returns The padding amount per side.
 */
export function computeDefaultPad(
    inputShape: [number, number], fieldSize: number, stride: number,
    dilation = 1): number {
  const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
  return Math.floor(
      (inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
}

/** Expands a scalar into `[param, param]`; returns tuples unchanged. */
function parseTupleParam(param: number|[number, number]): [number, number] {
  return typeof param === 'number' ? [param, param] : param;
}

/** Expands a scalar into `[param, param, param]`; returns tuples unchanged. */
function parse3TupleParam(param: number|[number, number, number]):
    [number, number, number] {
  return typeof param === 'number' ? [param, param, param] : param;
}

/* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
 * Atrous convolution is equivalent to standard convolution with upsampled
 * filters with effective_filter_height =
 * filter_height + (filter_height - 1) * (dilation - 1)
 * and effective_filter_width =
 * filter_width + (filter_width - 1) * (dilation - 1),
 * produced by inserting dilation - 1 zeros along consecutive elements across
 * the filters' spatial dimensions.
 * When there is a dilation, this converts a filter dimension to the
 * effective filter dimension, so it can be used in a standard convolution.
 * A dilation of 1 or less returns the filter size unchanged.
 */
function getEffectiveFilterSize(filterSize: number, dilation: number): number {
  if (dilation <= 1) {
    return filterSize;
  }
  return filterSize + (filterSize - 1) * (dilation - 1);
}

/**
 * Resolves the padding mode of a 2D convolution/pooling into per-side
 * padding and output spatial size.
 *
 * - number: symmetric explicit padding ('VALID' when 0, else 'NUMBER'). The
 *   output size comes from `computeOutputShape2D`, with per-axis filter and
 *   stride.
 * - 'same': output is `ceil(in / stride)`. The total padding needed is split
 *   with the extra element, if any, on the bottom/right, and clamped so it is
 *   never negative.
 * - 'valid': no padding; output is `ceil((in - filter + 1) / stride)`.
 *
 * `filterHeight`/`filterWidth` are expected to be the effective (dilated)
 * sizes, as passed by `computeConv2DInfo`.
 *
 * @throws Error for an unrecognized `pad` value.
 */
function getPadAndOutInfo(
    pad: 'same'|'valid'|number, inHeight: number, inWidth: number,
    strideHeight: number, strideWidth: number, filterHeight: number,
    filterWidth: number, roundingMode?: 'floor'|'round'|'ceil'):
    {padInfo: PadInfo, outHeight: number, outWidth: number} {
  let padInfo: PadInfo;
  let outHeight: number;
  let outWidth: number;

  if (typeof pad === 'number') {
    const padType = (pad === 0) ? 'VALID' : 'NUMBER';
    padInfo = {top: pad, bottom: pad, left: pad, right: pad, type: padType};
    const outShape = computeOutputShape2D(
        [inHeight, inWidth], [filterHeight, filterWidth],
        [strideHeight, strideWidth], pad, roundingMode);
    outHeight = outShape[0];
    outWidth = outShape[1];
  } else if (pad === 'same') {
    outHeight = Math.ceil(inHeight / strideHeight);
    outWidth = Math.ceil(inWidth / strideWidth);
    const padAlongHeight =
        Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
    const padAlongWidth =
        Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
    const top = Math.floor(padAlongHeight / 2);
    const bottom = padAlongHeight - top;
    const left = Math.floor(padAlongWidth / 2);
    const right = padAlongWidth - left;
    padInfo = {top, bottom, left, right, type: 'SAME'};
  } else if (pad === 'valid') {
    padInfo = {top: 0, bottom: 0, left: 0, right: 0, type: 'VALID'};
    outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
    outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
  } else {
    throw new Error(`Unknown padding parameter: ${pad}`);
  }
  return {padInfo, outHeight, outWidth};
}

/**
 * Resolves 'same' or 'valid' padding of a 3D convolution/pooling into
 * per-side padding and output spatial size. The semantics match the 2D
 * `getPadAndOutInfo`. In particular, 'same' padding is clamped to be
 * non-negative, which matters when a filter is smaller than its stride.
 *
 * @throws Error for an unrecognized `pad` value.
 */
function get3DPadAndOutInfo(
    pad: 'same'|'valid', inDepth: number, inHeight: number, inWidth: number,
    strideDepth: number, strideHeight: number, strideWidth: number,
    filterDepth: number, filterHeight: number, filterWidth: number): {
  padInfo: PadInfo3D,
  outDepth: number,
  outHeight: number,
  outWidth: number
} {
  let padInfo: PadInfo3D;
  let outDepth: number;
  let outHeight: number;
  let outWidth: number;

  if (pad === 'same') {
    outDepth = Math.ceil(inDepth / strideDepth);
    outHeight = Math.ceil(inHeight / strideHeight);
    outWidth = Math.ceil(inWidth / strideWidth);
    // Without the clamp, e.g. in=5, stride=3, filter=1 would yield -1.
    const padAlongDepth =
        Math.max(0, (outDepth - 1) * strideDepth + filterDepth - inDepth);
    const padAlongHeight =
        Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
    const padAlongWidth =
        Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
    const front = Math.floor(padAlongDepth / 2);
    const back = padAlongDepth - front;
    const top = Math.floor(padAlongHeight / 2);
    const bottom = padAlongHeight - top;
    const left = Math.floor(padAlongWidth / 2);
    const right = padAlongWidth - left;
    padInfo = {top, bottom, left, right, front, back, type: 'SAME'};
  } else if (pad === 'valid') {
    padInfo = {
      top: 0,
      bottom: 0,
      left: 0,
      right: 0,
      front: 0,
      back: 0,
      type: 'VALID'
    };
    outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
    outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
    outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
  } else {
    throw new Error(`Unknown padding parameter: ${pad}`);
  }
  return {padInfo, outDepth, outHeight, outWidth};
}

/**
 * Rounds a value depending on the rounding mode.
 * @param value The value to round.
 * @param roundingMode 'round', 'ceil' or 'floor'. When omitted, `value` is
 *     returned unchanged.
 * @throws Error for an unrecognized rounding mode.
 */
function conditionalRound(
    value: number, roundingMode?: 'floor'|'round'|'ceil'): number {
  if (!roundingMode) {
    return value;
  }
  switch (roundingMode) {
    case 'round':
      // used for Caffe Conv
      return Math.round(value);
    case 'ceil':
      // used for Caffe Pool
      return Math.ceil(value);
    case 'floor':
      return Math.floor(value);
    default:
      throw new Error(`Unknown roundingMode ${roundingMode}`);
  }
}

/** Returns true when both entries of a scalar-or-pair parameter equal 1. */
export function tupleValuesAreOne(param: number|[number, number]): boolean {
  const [dimA, dimB] = parseTupleParam(param);
  return dimA === 1 && dimB === 1;
}

/** Returns true when either the strides or the dilations are all ones. */
export function eitherStridesOrDilationsAreOne(
    strides: number|[number, number],
    dilations: number|[number, number]): boolean {
  return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}