onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
234 lines (214 loc) • 8.5 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Logger } from '../../instrument';
import { assert } from '../../util';
/**
* Layout preferences: optional hints a caller can pass to
* TextureLayoutStrategy.computeTextureWH to influence the resulting texture
* dimensions. Every field is optional; a strategy only honours the ones it reads.
*/
export interface WidthHeightPrefs {
/**
* Index at which the shape is split in two. The product of the dimensions from
* `breakAxis` onward becomes the first element of the result and the product of
* the dimensions before it becomes the second. Both strategies in this file
* drop the hint (and log a verbose message) when either product exceeds their
* `maxTextureSize`.
*/
breakAxis?: number;
/**
* Marks the tensor as packed. Only PreferLogicalStrategy reads it: it rounds the
* last two dimensions up to even numbers, doubles its size limit while choosing
* a layout, and halves both resulting dimensions before returning.
*/
isPacked?: boolean;
/**
* When truthy, PreferLogicalStrategy.computeTextureWH returns its two result
* elements in swapped order. AlwaysKeepOriginalSizeStrategy does not read it.
*/
reverseWH?: boolean;
}
/**
* TextureLayoutStrategy is an abstraction for different plans
* for mapping n-dimensional arrays to 2D textures (and back)
*/
export interface TextureLayoutStrategy {
/**
* Computes the two texture dimensions used to store a tensor of the given shape.
*
* @param shape dimensions of the tensor; an empty array denotes a scalar.
* @param prefs optional layout hints; implementations may ignore some or all of them.
* @returns a two-element tuple of texture dimensions. NOTE(review): the method
* name suggests [width, height], but not every code path in this file labels
* the elements - confirm the order against the callers.
*/
computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number];
}
/**
 * This strategy tries to find the minimal max(W, H) that fulfills
 * (W * H == totalSize), i.e. the texture holds exactly as many texels as the
 * tensor has elements and is as close to square as an exact factorization
 * allows.
 */
export class AlwaysKeepOriginalSizeStrategy implements TextureLayoutStrategy {
  /**
   * @param maxTextureSize largest value (inclusive) that either texture
   * dimension may take.
   */
  constructor(public maxTextureSize: number) {}

  /**
   * Computes texture dimensions whose product equals the number of elements
   * in `shape`.
   *
   * Only `prefs.breakAxis` is honoured; `isPacked` and `reverseWH` are not
   * read by this strategy.
   *
   * @param shape dimensions of the tensor; an empty array denotes a scalar.
   * @param prefs optional layout hints.
   * @returns `[width, height]` with `width * height === totalSize` and both
   * values no larger than `maxTextureSize`.
   * @throws Error when no exact factorization fits inside `maxTextureSize`
   * (this includes shapes whose total size is 0).
   */
  computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
    // scalar tensor
    if (shape.length === 0) {
      return [1, 1];
    }
    const maxTextureSize = this.maxTextureSize;
    if (prefs && prefs.breakAxis !== undefined) {
      // check to see if dims fit
      const wsize = prefs.breakAxis >= shape.length ? 1 : shape.slice(prefs.breakAxis).reduce((a, b) => a * b);
      const hsize = prefs.breakAxis <= 0 ? 1 : shape.slice(0, prefs.breakAxis).reduce((a, b) => a * b);
      if (wsize > maxTextureSize || hsize > maxTextureSize) {
        // ignore preferences
        // continue with default layout
        Logger.verbose(
          'TextureLayout',
          `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`,
        );
      } else {
        return [wsize, hsize];
      }
    }
    const totalSize = shape.reduce((a, b) => a * b);
    // Walk the candidate widths upward from floor(sqrt(totalSize)). The matching
    // height shrinks as the width grows, so the first exact divisor whose height
    // also fits is the most square-like layout available. A dimension equal to
    // maxTextureSize is accepted, consistent with the breakAxis check above.
    for (
      let width = Math.floor(Math.sqrt(totalSize));
      width <= maxTextureSize && width <= totalSize;
      width++
    ) {
      if (totalSize % width === 0 && totalSize / width <= maxTextureSize) {
        return [width, totalSize / width];
      }
    }
    throw new Error(`The given dimensions are outside this GPU's boundaries: ${shape}`);
  }
}
/**
 * Layout strategy that keeps the texture dimensions as close as possible to the
 * tensor's logical shape: low-rank shapes are used directly (rank 3 and 4 by
 * collapsing leading or trailing dimensions) and anything that does not fit
 * falls back to a roughly square texture.
 */
export class PreferLogicalStrategy implements TextureLayoutStrategy {
  /**
   * @param maxTextureSize size limit each resulting dimension is compared
   * against (doubled internally for packed layouts).
   */
  constructor(public maxTextureSize: number) {}

  /**
   * Computes the texture dimensions for `shape`, then applies the
   * post-processing requested by `prefs`: packed results are halved and
   * `reverseWH` swaps the two elements.
   *
   * @param shape dimensions of the tensor; an empty array denotes a scalar.
   * @param prefs optional layout hints.
   * @returns the two texture dimensions.
   */
  computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
    const dims = this.computeTexture(shape, prefs);
    const packed = prefs && prefs.isPacked;
    const swapped = prefs && prefs.reverseWH;
    if (packed) {
      dims[0] /= 2;
      dims[1] /= 2;
    }
    return swapped ? [dims[1], dims[0]] : dims;
  }

  /**
   * Core layout computation. For packed tensors the result is expressed in
   * channel units (twice the texel count per dimension); computeTextureWH
   * converts it back by halving.
   *
   * @param shape dimensions of the tensor; an empty array denotes a scalar.
   * @param prefs optional layout hints (`breakAxis` and `isPacked` are read here).
   * @returns the two texture dimensions before halving/swapping.
   */
  computeTexture(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
    const packed = prefs && prefs.isPacked;

    // A scalar occupies a single texel (2x2 channels when packed).
    if (shape.length === 0) {
      return packed ? [2, 2] : [1, 1];
    }

    let limit = this.maxTextureSize;

    if (prefs && prefs.breakAxis !== undefined) {
      // Split the shape at breakAxis and use the two products if both fit.
      const breakAxis = prefs.breakAxis;
      const multiply = (x: number, y: number): number => x * y;
      const wsize = breakAxis >= shape.length ? 1 : shape.slice(breakAxis).reduce(multiply);
      const hsize = breakAxis <= 0 ? 1 : shape.slice(0, breakAxis).reduce(multiply);
      const unattainable = wsize > limit || hsize > limit;
      if (!unattainable) {
        return [wsize, hsize];
      }
      // The preference cannot be met: report it and use the default layout.
      Logger.verbose(
        'TextureLayout',
        `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${breakAxis}`,
      );
    }

    let logical = shape.slice(0);
    if (packed) {
      limit *= 2;
      // Values share a texel only when they come from adjacent row/column pairs
      // of the same batch, so an odd row or column count still consumes a full
      // (half-empty) texel. Round the two innermost dimensions up to even
      // numbers so those texels are counted.
      const rank = logical.length;
      logical = logical.map((dim, idx) => (idx >= rank - 2 && dim % 2 !== 0 ? dim + 1 : dim));
      // A packed texture is at least 2 high (the channel height of one texel).
      if (logical.length === 1) {
        logical = [2, logical[0]];
      }
    }

    // Rank-2 shapes are left untouched so the texture matches the physical
    // layout; every other rank drops its size-1 dimensions first.
    if (logical.length !== 2) {
      logical = squeezeShape(logical).newShape;
    }

    const size = sizeFromShape(logical);
    const fits = (extent: number): boolean => extent <= limit;
    const rank = logical.length;

    if (rank <= 1) {
      if (fits(size)) {
        return [1, size];
      }
    } else if (rank === 2) {
      if (fits(logical[0]) && fits(logical[1])) {
        return logical as [number, number];
      }
    } else if (rank === 3) {
      const [d0, d1, d2] = logical;
      if (fits(d0 * d1) && fits(d2)) {
        return [d0 * d1, d2];
      }
      if (fits(d0) && fits(d1 * d2)) {
        return [d0, d1 * d2];
      }
    } else if (rank === 4) {
      const [d0, d1, d2, d3] = logical;
      if (fits(d0 * d1 * d2) && fits(d3)) {
        return [d0 * d1 * d2, d3];
      }
      if (fits(d0) && fits(d1 * d2 * d3)) {
        return [d0, d1 * d2 * d3];
      }
    }

    // Nothing matched the logical shape: fall back to a roughly square texture.
    if (packed) {
      // `size` counts channels here. Squarify in texel units (size / 4) so the
      // inner dimensions stay even, then scale back up to channel units.
      return sizeToSquarishShape(size / 4).map((d) => d * 2) as [number, number];
    }
    return sizeToSquarishShape(size);
  }
}
/**
 * Removes size-1 dimensions from a shape.
 *
 * @param shape the dimensions to squeeze.
 * @param axis optional list of axes to squeeze (negative values count from the
 * end). When omitted, null or empty, every size-1 dimension is removed;
 * otherwise only the listed axes are removed and other size-1 dimensions stay.
 * @returns `newShape`, the remaining dimensions, and `keptDims`, the indices
 * in `shape` those dimensions came from.
 * @throws Error when a listed axis refers to a dimension whose size is not 1.
 */
export function squeezeShape(shape: number[], axis?: number[]): { newShape: number[]; keptDims: number[] } {
  const newShape: number[] = [];
  const keptDims: number[] = [];
  const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
  // Sort numerically: the default Array.prototype.sort compares elements as
  // strings, which would order axis 10 before axis 2 and break the single
  // forward pass below.
  const axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort((a, b) => a - b);
  // Index of the next requested axis that has not been consumed yet.
  let j = 0;
  for (let i = 0; i < shape.length; ++i) {
    if (axes != null) {
      if (axes[j] === i && shape[i] !== 1) {
        throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
      }
      // A size-1 dimension that was not requested is kept.
      if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
        newShape.push(shape[i]);
        keptDims.push(i);
      }
      if (axes[j] <= i) {
        j++;
      }
    }
    // Dimensions larger than 1 are always kept.
    if (shape[i] !== 1) {
      newShape.push(shape[i]);
      keptDims.push(i);
    }
  }
  return { newShape, keptDims };
}
/**
 * Normalizes an axis argument into an array of non-negative axis indices.
 *
 * @param axis a single axis or a list of axes; negative values count from the
 * end of `shape`. A null/undefined value selects every axis of `shape`.
 * @param shape the shape the axes refer to; only its length (the rank) is used.
 * @returns a new array with each axis mapped into the range [0, rank).
 * Assertion failures (out-of-range or non-integer axes) are reported through
 * `assert`.
 */
export function parseAxisParam(axis: number | number[], shape: number[]): number[] {
  const rank = shape.length;

  // Bring the input into array form.
  let axes: number[];
  if (axis == null) {
    axes = shape.map((_dim, index) => index);
  } else {
    axes = ([] as number[]).concat(axis);
  }

  // Every axis has to lie inside [-rank, rank).
  const inRange = (ax: number): boolean => ax >= -rank && ax < rank;
  assert(
    axes.every(inRange),
    () => `All values in axis param must be in range [-${rank}, ${rank}) but got axis ${axes}`,
  );

  // Fractional axes are rejected.
  assert(axes.every(isInt), () => `All values in axis param must be integers but got axis ${axes}`);

  // Negative axes are counted from the end.
  return axes.map((ax) => {
    if (ax < 0) {
      return rank + ax;
    }
    return ax;
  });
}
/**
 * Tells whether a number has no fractional part.
 *
 * @param a the value to test.
 * @returns true when `a % 1` is exactly 0; NaN and infinities yield false.
 */
export function isInt(a: number): boolean {
  const fractionalPart = a % 1;
  return fractionalPart === 0;
}
/**
 * Computes the number of elements described by a shape, i.e. the product of
 * all its dimensions.
 *
 * @param shape the dimensions to multiply.
 * @returns the product of the dimensions, or 1 for an empty shape (a scalar).
 */
export function sizeFromShape(shape: number[]): number {
  // An empty shape denotes a scalar, which holds exactly one element.
  return shape.length === 0 ? 1 : shape.reduce((product, dim) => product * dim);
}
/**
 * Extracts the row and column counts from the two innermost dimensions of a
 * shape.
 *
 * @param shape the dimensions to inspect; must not be empty.
 * @returns `[rows, cols]`, where `cols` is the last dimension and `rows` is
 * the second-to-last one, or 1 when the shape has a single dimension.
 * @throws Error when `shape` is empty.
 */
export function getRowsCols(shape: number[]): [number, number] {
  const rank = shape.length;
  if (rank === 0) {
    throw Error('Cannot get rows and columns of an empty shape array.');
  }
  const cols = shape[rank - 1];
  const rows = rank > 1 ? shape[rank - 2] : 1;
  return [rows, cols];
}
/**
 * Finds a roughly square pair of dimensions whose product is at least `size`.
 *
 * @param size the number of elements that must fit.
 * @returns `[width, height]` with `width = ceil(sqrt(size))` and
 * `height = ceil(size / width)`; `[0, 0]` when `size` is 0.
 */
export function sizeToSquarishShape(size: number): [number, number] {
  // Without this guard a zero size computes 0 / 0 and yields a NaN height.
  if (size === 0) {
    return [0, 0];
  }
  const width = Math.ceil(Math.sqrt(size));
  return [width, Math.ceil(size / width)];
}
/**
 * Computes the number of batches in a shape: the product of every dimension
 * except the trailing `dimsToSkip` ones.
 *
 * @param shape the dimensions to inspect.
 * @param dimsToSkip how many innermost dimensions to exclude (default 2).
 * @returns the product of the leading dimensions, or 1 when the shape has no
 * more than `dimsToSkip` dimensions.
 */
export function getBatchDim(shape: number[], dimsToSkip = 2): number {
  // Clamp at 0: a negative end index would make slice() count from the end of
  // the array and wrongly keep leading dimensions.
  const batchRank = Math.max(0, shape.length - dimsToSkip);
  return sizeFromShape(shape.slice(0, batchRank));
}