@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
440 lines (402 loc) • 14.5 kB
text/typescript
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as util from '../util';
/**
 * Amount of padding on each side of the two spatial dimensions of a 2D
 * convolution/pooling input.
 *
 * `type` is a tag describing how the padding was chosen; the padding
 * computation in this file sets it to 'SAME', 'VALID' or 'NUMBER'.
 */
export type PadInfo = {
  // Padding along the height axis.
  top: number,
  bottom: number,
  // Padding along the width axis.
  left: number,
  right: number,
  // Tag naming the padding scheme that produced the values above.
  type: string
};
/**
 * Amount of padding on each side of the three spatial dimensions of a 3D
 * convolution/pooling input.
 *
 * `type` is a tag describing how the padding was chosen; the 3D padding
 * computation in this file sets it to 'SAME' or 'VALID'.
 */
export type PadInfo3D = {
  // Padding along the depth axis.
  front: number,
  back: number,
  // Padding along the height axis.
  top: number,
  bottom: number,
  // Padding along the width axis.
  left: number,
  right: number,
  // Tag naming the padding scheme that produced the values above.
  type: string
};
/**
 * Everything the forward pass of a 2D convolution/pooling operation needs to
 * know: input and output dimensions, strides, dilations, filter sizes (raw and
 * dilated) and padding.
 */
export type Conv2DInfo = {
  // Layout of `inShape` and `outShape`.
  dataFormat: 'channelsFirst'|'channelsLast',
  batchSize: number,

  // Input dimensions, unpacked from `inShape` according to `dataFormat`.
  inHeight: number,
  inWidth: number,
  inChannels: number,

  // Output dimensions; `outShape` holds the same values in `dataFormat` order.
  outHeight: number,
  outWidth: number,
  outChannels: number,

  // Per-axis stride and dilation.
  strideHeight: number,
  strideWidth: number,
  dilationHeight: number,
  dilationWidth: number,

  // Filter size as given, and the size it covers once dilation is applied.
  filterHeight: number,
  filterWidth: number,
  effectiveFilterHeight: number,
  effectiveFilterWidth: number,

  // Padding applied to the input's spatial dimensions.
  padInfo: PadInfo,

  // The full shapes the fields above were derived from / assembled into.
  inShape: [number, number, number, number],
  outShape: [number, number, number, number],
  filterShape: [number, number, number, number]
};
/**
 * Computes the Conv2DInfo describing a 2D pooling operation.
 *
 * Pooling is expressed as a non-depthwise convolution whose filter has as many
 * input and output channels as the input itself, and the actual computation is
 * delegated to `computeConv2DInfo`.
 *
 * @param inShape 4D input shape, laid out according to `dataFormat`.
 * @param filterSize Pooling window size; a single number is used for both
 *     height and width.
 * @param strides Stride, as one number or a [height, width] pair.
 * @param dilations Dilation, as one number or a [height, width] pair.
 * @param pad Padding scheme ('same' or 'valid') or an explicit padding amount.
 * @param roundingMode Optional rounding mode forwarded to
 *     `computeConv2DInfo`.
 * @param dataFormat Layout of `inShape`; selects which entry holds the
 *     channel count.
 * @throws Error if `dataFormat` is neither 'channelsLast' nor 'channelsFirst'.
 */
export function computePool2DInfo(
    inShape: [number, number, number, number],
    filterSize: [number, number]|number, strides: number|[number, number],
    dilations: number|[number, number], pad: 'same'|'valid'|number,
    roundingMode?: 'floor'|'round'|'ceil',
    dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv2DInfo {
  const [poolHeight, poolWidth] = parseTupleParam(filterSize);

  // Locate the channel count inside inShape for the requested layout.
  let channels: number;
  switch (dataFormat) {
    case 'channelsLast':
      channels = inShape[3];
      break;
    case 'channelsFirst':
      channels = inShape[1];
      break;
    default:
      throw new Error(`Unknown dataFormat ${dataFormat}`);
  }

  const poolFilterShape: [number, number, number, number] =
      [poolHeight, poolWidth, channels, channels];
  const depthwise = false;
  return computeConv2DInfo(
      inShape, poolFilterShape, strides, dilations, pad, roundingMode,
      depthwise, dataFormat);
}
/**
 * Computes the information for a forward pass of a 2D convolution/pooling
 * operation.
 *
 * @param inShape 4D input shape, read as [batch, height, width, channels] for
 *     'channelsLast' and as [batch, channels, height, width] for
 *     'channelsFirst'.
 * @param filterShape 4D filter shape. Entries 0, 1 and 3 are read as filter
 *     height, filter width and filter channels; entry 2 is not read here.
 * @param strides Stride, as one number or a [height, width] pair.
 * @param dilations Dilation, as one number or a [height, width] pair.
 * @param pad Padding scheme ('same' or 'valid') or an explicit padding amount.
 * @param roundingMode Optional rounding mode forwarded to the padding/output
 *     size computation.
 * @param depthwise When true, the number of output channels is
 *     filter channels * input channels instead of just filter channels.
 * @param dataFormat Layout of `inShape` and of the returned `outShape`.
 * @returns The fully populated Conv2DInfo.
 * @throws Error if `dataFormat` is neither 'channelsLast' nor 'channelsFirst'.
 */
export function computeConv2DInfo(
    inShape: [number, number, number, number],
    filterShape: [number, number, number, number],
    strides: number|[number, number], dilations: number|[number, number],
    pad: 'same'|'valid'|number, roundingMode?: 'floor'|'round'|'ceil',
    depthwise = false,
    dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv2DInfo {
  // Reject unsupported layouts before doing any work.
  if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
  const channelsLast = dataFormat === 'channelsLast';

  // Unpack the input shape according to the layout.
  const batchSize = inShape[0];
  const inHeight = channelsLast ? inShape[1] : inShape[2];
  const inWidth = channelsLast ? inShape[2] : inShape[3];
  const inChannels = channelsLast ? inShape[3] : inShape[1];

  const filterHeight = filterShape[0];
  const filterWidth = filterShape[1];
  const filterChannels = filterShape[3];

  const [strideHeight, strideWidth] = parseTupleParam(strides);
  const [dilationHeight, dilationWidth] = parseTupleParam(dilations);

  // Padding and output size are computed from the dilated filter extent.
  const effectiveFilterHeight =
      getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth =
      getEffectiveFilterSize(filterWidth, dilationWidth);
  const padAndOut = getPadAndOutInfo(
      pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight,
      effectiveFilterWidth, roundingMode);
  const outHeight = padAndOut.outHeight;
  const outWidth = padAndOut.outWidth;

  let outChannels = filterChannels;
  if (depthwise) {
    outChannels = filterChannels * inChannels;
  }

  // Assemble the output shape in the same layout as the input.
  const outShape: [number, number, number, number] = channelsLast ?
      [batchSize, outHeight, outWidth, outChannels] :
      [batchSize, outChannels, outHeight, outWidth];

  return {
    batchSize,
    dataFormat,
    inHeight,
    inWidth,
    inChannels,
    outHeight,
    outWidth,
    outChannels,
    padInfo: padAndOut.padInfo,
    strideHeight,
    strideWidth,
    filterHeight,
    filterWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}
/**
 * Everything the forward pass of a 3D convolution/pooling operation needs to
 * know: input and output dimensions, strides, dilations, filter sizes and
 * padding.
 */
export type Conv3DInfo = {
  // Layout of `inShape` and `outShape`.
  dataFormat: 'channelsFirst'|'channelsLast',
  batchSize: number,

  // Input dimensions, unpacked from `inShape` according to `dataFormat`.
  inDepth: number,
  inHeight: number,
  inWidth: number,
  inChannels: number,

  // Output dimensions; `outShape` holds the same values in `dataFormat` order.
  outDepth: number,
  outHeight: number,
  outWidth: number,
  outChannels: number,

  // Per-axis stride and dilation.
  strideDepth: number,
  strideHeight: number,
  strideWidth: number,
  dilationDepth: number,
  dilationHeight: number,
  dilationWidth: number,

  // Filter size as given (before dilation is applied).
  filterDepth: number,
  filterHeight: number,
  filterWidth: number,

  // Padding applied to the input's spatial dimensions.
  padInfo: PadInfo3D,

  // The full shapes the fields above were derived from / assembled into.
  inShape: [number, number, number, number, number],
  outShape: [number, number, number, number, number],
  filterShape: [number, number, number, number, number]
};
/**
 * Computes the information for a forward pass of a 3D convolution/pooling
 * operation.
 *
 * @param inShape 5D input shape, read as
 *     [batch, depth, height, width, channels] for 'channelsLast' and as
 *     [batch, channels, depth, height, width] for 'channelsFirst'.
 * @param filterShape 5D filter shape. Entries 0, 1, 2 and 4 are read as
 *     filter depth, height, width and filter channels; entry 3 is not read
 *     here.
 * @param strides Stride, as one number or a [depth, height, width] triple.
 * @param dilations Dilation, as one number or a [depth, height, width]
 *     triple.
 * @param pad Padding scheme, 'same' or 'valid'.
 * @param depthwise When true, the number of output channels is
 *     filter channels * input channels instead of just filter channels.
 * @param dataFormat Layout of `inShape` and of the returned `outShape`.
 * @returns The fully populated Conv3DInfo.
 * @throws Error if `dataFormat` is neither 'channelsLast' nor 'channelsFirst'.
 */
export function computeConv3DInfo(
    inShape: [number, number, number, number, number],
    filterShape: [number, number, number, number, number],
    strides: number|[number, number, number],
    dilations: number|[number, number, number], pad: 'same'|'valid',
    depthwise = false,
    dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv3DInfo {
  // Reject unsupported layouts before doing any work.
  if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
  const channelsLast = dataFormat === 'channelsLast';

  // Unpack the input shape according to the layout.
  const batchSize = inShape[0];
  const inDepth = channelsLast ? inShape[1] : inShape[2];
  const inHeight = channelsLast ? inShape[2] : inShape[3];
  const inWidth = channelsLast ? inShape[3] : inShape[4];
  const inChannels = channelsLast ? inShape[4] : inShape[1];

  const filterDepth = filterShape[0];
  const filterHeight = filterShape[1];
  const filterWidth = filterShape[2];
  const filterChannels = filterShape[4];

  const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
  const [dilationDepth, dilationHeight, dilationWidth] =
      parse3TupleParam(dilations);

  // Padding and output size are computed from the dilated filter extent.
  const dilatedFilterDepth =
      getEffectiveFilterSize(filterDepth, dilationDepth);
  const dilatedFilterHeight =
      getEffectiveFilterSize(filterHeight, dilationHeight);
  const dilatedFilterWidth =
      getEffectiveFilterSize(filterWidth, dilationWidth);
  const padAndOut = get3DPadAndOutInfo(
      pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth,
      dilatedFilterDepth, dilatedFilterHeight, dilatedFilterWidth);
  const outDepth = padAndOut.outDepth;
  const outHeight = padAndOut.outHeight;
  const outWidth = padAndOut.outWidth;

  let outChannels = filterChannels;
  if (depthwise) {
    outChannels = filterChannels * inChannels;
  }

  // Assemble the output shape in the same layout as the input.
  const outShape: [number, number, number, number, number] = channelsLast ?
      [batchSize, outDepth, outHeight, outWidth, outChannels] :
      [batchSize, outChannels, outDepth, outHeight, outWidth];

  return {
    batchSize,
    dataFormat,
    inDepth,
    inHeight,
    inWidth,
    inChannels,
    outDepth,
    outHeight,
    outWidth,
    outChannels,
    padInfo: padAndOut.padInfo,
    strideDepth,
    strideHeight,
    strideWidth,
    filterDepth,
    filterHeight,
    filterWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}
/**
 * Computes the [rows, cols] output size of a 2D convolution/pooling that uses
 * the same field size, stride and zero padding along both axes.
 *
 * Each axis is computed as `(input - fieldSize + 2 * zeroPad) / stride + 1`,
 * passed through `conditionalRound`, and must come out as an integer.
 *
 * @param inShape Input size as [rows, cols].
 * @param fieldSize Filter extent, applied to both axes.
 * @param stride Stride, applied to both axes.
 * @param zeroPad Padding on each side; when null/undefined the value from
 *     `computeDefaultPad` is used.
 * @param roundingMode Optional rounding applied to each output size.
 * @returns The output size as [rows, cols].
 */
function computeOutputShape2D(
    inShape: [number, number], fieldSize: number, stride: number,
    zeroPad?: number, roundingMode?: 'floor'|'round'|'ceil'): [number, number] {
  const padding =
      zeroPad == null ? computeDefaultPad(inShape, fieldSize, stride) : zeroPad;

  // Shared per-axis formula; rows and columns only differ in their input size.
  const outputSize = (inputSize: number) => conditionalRound(
      (inputSize - fieldSize + 2 * padding) / stride + 1, roundingMode);

  const [inputRows, inputCols] = inShape;

  const outputRows = outputSize(inputRows);
  util.assert(
      util.isInt(outputRows),
      () => `The output # of rows (${outputRows}) must be an integer. ` +
          `Change the stride and/or zero pad parameters`);

  const outputCols = outputSize(inputCols);
  util.assert(
      util.isInt(outputCols),
      () => `The output # of columns (${outputCols}) must be an integer. ` +
          `Change the stride and/or zero pad parameters`);

  return [outputRows, outputCols];
}
/**
 * Computes the default per-side zero padding for a convolution/pooling.
 *
 * The result is `floor((rows * (stride - 1) - stride + effectiveField) / 2)`,
 * where `effectiveField` is the field size after dilation. Before flooring,
 * this is the padding for which
 * `(rows - effectiveField + 2 * pad) / stride + 1` equals `rows`.
 *
 * @param inputShape Input size; only its first entry is used.
 * @param fieldSize Filter extent before dilation.
 * @param stride Stride of the operation.
 * @param dilation Dilation rate, defaults to 1 (no dilation).
 * @returns The padding to apply on each side.
 */
export function computeDefaultPad(
    inputShape: [number, number], fieldSize: number, stride: number,
    dilation = 1): number {
  const dilatedFieldSize = getEffectiveFilterSize(fieldSize, dilation);
  const totalPadding =
      inputShape[0] * (stride - 1) - stride + dilatedFieldSize;
  return Math.floor(totalPadding / 2);
}
/**
 * Normalizes a parameter that may be given as a single number or as a pair.
 * A number is expanded to [n, n]; a pair is returned as is (same reference).
 */
function parseTupleParam(param: number|[number, number]): [number, number] {
  if (typeof param !== 'number') {
    return param;
  }
  return [param, param];
}
/**
 * Normalizes a parameter that may be given as a single number or as a triple.
 * A number is expanded to [n, n, n]; a triple is returned as is (same
 * reference).
 */
function parse3TupleParam(param: number|[number, number, number]):
    [number, number, number] {
  if (typeof param !== 'number') {
    return param;
  }
  return [param, param, param];
}
/* Converts a filter dimension to the dimension it effectively covers once
 * dilation is applied, so that it can be used as in a standard convolution.
 *
 * See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
 * Atrous (dilated) convolution is equivalent to a standard convolution with an
 * upsampled filter, produced by inserting dilation - 1 zeros between
 * consecutive elements along the filter's spatial dimensions. The resulting
 * size is
 *   effective_filter_size =
 *       filter_size + (filter_size - 1) * (dilation - 1).
 * A dilation of 1 or less leaves the filter size unchanged.
 */
function getEffectiveFilterSize(filterSize: number, dilation: number) {
  return dilation <= 1 ? filterSize :
                         filterSize + (filterSize - 1) * (dilation - 1);
}
/**
 * Computes the padding and the output height/width of a 2D
 * convolution/pooling.
 *
 * - numeric `pad`: that amount is applied on all four sides and the output
 *   size comes from `computeOutputShape2D` (honoring `roundingMode`). The
 *   padding type is 'VALID' when the amount is 0 and 'NUMBER' otherwise.
 * - 'same': output size is `ceil(input / stride)`; the total padding needed
 *   (never negative) is split with the extra element going to the
 *   bottom/right.
 * - 'valid': no padding; output size is
 *   `ceil((input - filter + 1) / stride)`.
 *
 * @param pad Padding scheme or explicit per-side padding amount.
 * @param inHeight Input height.
 * @param inWidth Input width.
 * @param strideHeight Stride along the height axis.
 * @param strideWidth Stride along the width axis.
 * @param filterHeight Filter height (callers in this file pass the dilated
 *     size).
 * @param filterWidth Filter width (callers in this file pass the dilated
 *     size).
 * @param roundingMode Optional rounding, only used for numeric `pad`.
 * @returns The padding info together with the output height and width.
 * @throws Error if `pad` is not a number, 'same' or 'valid'.
 */
function getPadAndOutInfo(
    pad: 'same'|'valid'|number, inHeight: number, inWidth: number,
    strideHeight: number, strideWidth: number, filterHeight: number,
    filterWidth: number, roundingMode?: 'floor'|'round'|'ceil'):
    {padInfo: PadInfo, outHeight: number, outWidth: number} {
  let padInfo: PadInfo;
  let outHeight: number;
  let outWidth: number;
  if (typeof pad === 'number') {
    const padType = (pad === 0) ? 'VALID' : 'NUMBER';
    padInfo = {top: pad, bottom: pad, left: pad, right: pad, type: padType};
    // computeOutputShape2D applies a single filter size and stride to both
    // axes. Passing only the height values (as this code used to) produced a
    // wrong width for non-square filters or strides, so the helper is invoked
    // once per axis with that axis' own filter size and stride.
    // The unused slot of the shape is filled with `filter - 2 * pad`, an
    // extent whose output size is exactly 1 for any non-zero stride. It thus
    // never trips the helper's integer check, and a failure is reported for
    // the axis that actually caused it.
    outHeight = computeOutputShape2D(
        [inHeight, filterHeight - 2 * pad], filterHeight, strideHeight, pad,
        roundingMode)[0];
    outWidth = computeOutputShape2D(
        [filterWidth - 2 * pad, inWidth], filterWidth, strideWidth, pad,
        roundingMode)[1];
  } else if (pad === 'same') {
    outHeight = Math.ceil(inHeight / strideHeight);
    outWidth = Math.ceil(inWidth / strideWidth);
    // Total padding required per axis; clamped because a stride larger than
    // the filter can make the raw value negative.
    const padAlongHeight =
        Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
    const padAlongWidth =
        Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
    // When the total is odd, the extra element goes to the bottom/right.
    const top = Math.floor(padAlongHeight / 2);
    const bottom = padAlongHeight - top;
    const left = Math.floor(padAlongWidth / 2);
    const right = padAlongWidth - left;
    padInfo = {top, bottom, left, right, type: 'SAME'};
  } else if (pad === 'valid') {
    padInfo = {top: 0, bottom: 0, left: 0, right: 0, type: 'VALID'};
    outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
    outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
  } else {
    throw Error(`Unknown padding parameter: ${pad}`);
  }
  return {padInfo, outHeight, outWidth};
}
/**
 * Computes the padding and the output depth/height/width of a 3D
 * convolution/pooling.
 *
 * - 'same': output size is `ceil(input / stride)`; the total padding needed
 *   (never negative) is split with the extra element going to the
 *   back/bottom/right.
 * - 'valid': no padding; output size is
 *   `ceil((input - filter + 1) / stride)`.
 *
 * @param pad Padding scheme, 'same' or 'valid'.
 * @param inDepth Input depth.
 * @param inHeight Input height.
 * @param inWidth Input width.
 * @param strideDepth Stride along the depth axis.
 * @param strideHeight Stride along the height axis.
 * @param strideWidth Stride along the width axis.
 * @param filterDepth Filter depth (the caller in this file passes the dilated
 *     size).
 * @param filterHeight Filter height (dilated size, as above).
 * @param filterWidth Filter width (dilated size, as above).
 * @returns The padding info together with the output depth, height and width.
 * @throws Error if `pad` is neither 'same' nor 'valid'.
 */
function get3DPadAndOutInfo(
    pad: 'same'|'valid', inDepth: number, inHeight: number, inWidth: number,
    strideDepth: number, strideHeight: number, strideWidth: number,
    filterDepth: number, filterHeight: number, filterWidth: number): {
  padInfo: PadInfo3D,
  outDepth: number,
  outHeight: number,
  outWidth: number
} {
  let padInfo: PadInfo3D;
  let outDepth: number;
  let outHeight: number;
  let outWidth: number;
  if (pad === 'same') {
    outDepth = Math.ceil(inDepth / strideDepth);
    outHeight = Math.ceil(inHeight / strideHeight);
    outWidth = Math.ceil(inWidth / strideWidth);
    // Total padding required per axis. It is clamped at 0, matching the 2D
    // computation: when the stride exceeds what the filter needs (e.g.
    // input 5, stride 3, filter 1) the raw value is negative, which would
    // otherwise surface as negative padding.
    const padAlongDepth =
        Math.max(0, (outDepth - 1) * strideDepth + filterDepth - inDepth);
    const padAlongHeight =
        Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
    const padAlongWidth =
        Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
    // When the total is odd, the extra element goes to the back/bottom/right.
    const front = Math.floor(padAlongDepth / 2);
    const back = padAlongDepth - front;
    const top = Math.floor(padAlongHeight / 2);
    const bottom = padAlongHeight - top;
    const left = Math.floor(padAlongWidth / 2);
    const right = padAlongWidth - left;
    padInfo = {top, bottom, left, right, front, back, type: 'SAME'};
  } else if (pad === 'valid') {
    padInfo = {
      top: 0,
      bottom: 0,
      left: 0,
      right: 0,
      front: 0,
      back: 0,
      type: 'VALID'
    };
    outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
    outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
    outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
  } else {
    throw Error(`Unknown padding parameter: ${pad}`);
  }
  return {padInfo, outDepth, outHeight, outWidth};
}
/**
 * Rounds a value according to the requested rounding mode.
 *
 * @param value The number to round.
 * @param roundingMode 'round' (used for Caffe Conv), 'ceil' (used for Caffe
 *     Pool) or 'floor'. When omitted (or otherwise falsy) the value is
 *     returned unchanged.
 * @returns The rounded value, or `value` itself when no mode is given.
 * @throws Error if `roundingMode` is set to an unrecognized value.
 */
function conditionalRound(
    value: number, roundingMode?: 'floor'|'round'|'ceil') {
  if (!roundingMode) {
    return value;
  }
  if (roundingMode === 'round') {
    // used for Caffe Conv
    return Math.round(value);
  }
  if (roundingMode === 'ceil') {
    // used for Caffe Pool
    return Math.ceil(value);
  }
  if (roundingMode === 'floor') {
    return Math.floor(value);
  }
  throw new Error(`Unknown roundingMode ${roundingMode}`);
}
/**
 * Returns true when the parameter, given as a single number or as a pair,
 * equals 1 in both of its entries.
 */
export function tupleValuesAreOne(param: number|[number, number]): boolean {
  const dims = parseTupleParam(param);
  return dims[0] === 1 && dims[1] === 1;
}
/**
 * Returns true when the strides are all 1 or the dilations are all 1, i.e.
 * when striding and dilation are not both in effect.
 *
 * @param strides Stride, as one number or a [height, width] pair.
 * @param dilations Dilation, as one number or a [height, width] pair.
 */
export function eitherStridesOrDilationsAreOne(
    strides: number|[number, number],
    dilations: number|[number, number]): boolean {
  if (tupleValuesAreOne(strides)) {
    return true;
  }
  return tupleValuesAreOne(dilations);
}