// @tensorflow/tfjs-core — hardware-accelerated JavaScript library for
// machine intelligence (WebGL backend implementation).
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Import webgl flags.
import './flags_webgl';
import * as device_util from '../../device_util';
import {ENGINE, MemoryInfo, TimingInfo} from '../../engine';
import {env} from '../../environment';
import {tidy} from '../../globals';
import {TensorInfo} from '../../kernel_registry';
import {warn} from '../../log';
import {buffer} from '../../ops/array_ops';
import * as array_ops_util from '../../ops/array_ops_util';
import * as axis_util from '../../ops/axis_util';
import {complex, imag, real} from '../../ops/complex_ops';
import {computeOutShape} from '../../ops/concat_util';
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
import * as gather_nd_util from '../../ops/gather_nd_util';
import * as reduce_util from '../../ops/reduce_util';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as segment_util from '../../ops/segment_util';
import * as slice_util from '../../ops/slice_util';
import {softmax} from '../../ops/softmax';
import {range, scalar, tensor} from '../../ops/tensor_ops';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {BackendValues, DataType, DataTypeMap, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';
import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';
import * as backend_util from '../backend_util';
import {mergeRealAndImagArrays} from '../complex_util';
import {nonMaxSuppressionV3} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';
import {AddNProgram} from './addn_gpu';
import {AddNPackedProgram} from './addn_packed_gpu';
import {ArgMinMaxProgram} from './argminmax_gpu';
import {ArgMinMaxPackedProgram} from './argminmax_packed_gpu';
import {AvgPool2DBackpropProgram, AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu';
import {BatchNormProgram} from './batchnorm_gpu';
import {BatchNormPackedProgram} from './batchnorm_packed_gpu';
import * as binaryop_complex_gpu from './binaryop_complex_gpu';
import {BinaryOpComplexProgram} from './binaryop_complex_gpu';
import * as binaryop_gpu from './binaryop_gpu';
import {BinaryOpProgram} from './binaryop_gpu';
import * as binaryop_packed_gpu from './binaryop_packed_gpu';
import {BinaryOpPackedProgram} from './binaryop_packed_gpu';
import {getWebGLContext} from './canvas_util';
import {ClipProgram} from './clip_gpu';
import {ClipPackedProgram} from './clip_packed_gpu';
import {ComplexAbsProgram} from './complex_abs_gpu';
import {ConcatProgram} from './concat_gpu';
import {ConcatPackedProgram} from './concat_packed_gpu';
import {Conv2DDerFilterProgram, Conv2DDerInputProgram, Conv3DDerFilterProgram, Conv3DDerInputProgram} from './conv_backprop_gpu';
import {DepthwiseConv2DDerFilterProgram, DepthwiseConv2DDerInputProgram} from './conv_backprop_gpu_depthwise';
import {Conv2DProgram, Conv3DProgram} from './conv_gpu';
import {DepthwiseConv2DProgram} from './conv_gpu_depthwise';
import {DepthwiseConvPacked2DProgram} from './conv_packed_gpu_depthwise';
import {CropAndResizeProgram} from './crop_and_resize_gpu';
import {CumSumProgram} from './cumsum_gpu';
import {DecodeMatrixProgram} from './decode_matrix_gpu';
import {DecodeMatrixPackedProgram} from './decode_matrix_packed_gpu';
import {DepthToSpaceProgram} from './depth_to_space_gpu';
import {DiagProgram} from './diag_gpu';
import {EncodeFloatProgram} from './encode_float_gpu';
import {EncodeFloatPackedProgram} from './encode_float_packed_gpu';
import {EncodeMatrixProgram} from './encode_matrix_gpu';
import {EncodeMatrixPackedProgram} from './encode_matrix_packed_gpu';
import * as fft_gpu from './fft_gpu';
import {FFTProgram} from './fft_gpu';
import {FillProgram} from './fill_gpu';
import {GatherProgram} from './gather_gpu';
import {GatherNDProgram} from './gather_nd_gpu';
import {GPGPUContext} from './gpgpu_context';
import * as gpgpu_math from './gpgpu_math';
import {GPGPUBinary, GPGPUProgram, TensorData} from './gpgpu_math';
import {Im2ColPackedProgram} from './im2col_packed_gpu';
import {LRNProgram} from './lrn_gpu';
import {LRNGradProgram} from './lrn_grad_gpu';
import {LRNPackedProgram} from './lrn_packed_gpu';
import {MaxPool2DBackpropProgram, MaxPool3DBackpropProgram} from './max_pool_backprop_gpu';
import {MatMulPackedProgram} from './mulmat_packed_gpu';
import {MultinomialProgram} from './multinomial_gpu';
import {OneHotProgram} from './onehot_gpu';
import {PackProgram} from './pack_gpu';
import {PadProgram} from './pad_gpu';
import {PadPackedProgram} from './pad_packed_gpu';
import {Pool2DProgram, Pool3DProgram} from './pool_gpu';
import {ReduceProgram} from './reduce_gpu';
import {ReshapePackedProgram} from './reshape_packed_gpu';
import {ResizeBilinearBackpropProgram} from './resize_bilinear_backprop_gpu';
import {ResizeBilinearProgram} from './resize_bilinear_gpu';
import {ResizeBilinearPackedProgram} from './resize_bilinear_packed_gpu';
import {ResizeNearestNeigborBackpropProgram} from './resize_nearest_neighbor_backprop_gpu';
import {ResizeNearestNeighborProgram} from './resize_nearest_neighbor_gpu';
import {ReverseProgram} from './reverse_gpu';
import {ReversePackedProgram} from './reverse_packed_gpu';
import {ScatterProgram} from './scatter_gpu';
import {SegmentOpProgram} from './segment_gpu';
import {SelectProgram} from './select_gpu';
import {SliceProgram} from './slice_gpu';
import {SlicePackedProgram} from './slice_packed_gpu';
import {StridedSliceProgram} from './strided_slice_gpu';
import * as tex_util from './tex_util';
import {TextureData, TextureUsage} from './tex_util';
import {TextureManager} from './texture_manager';
import {TileProgram} from './tile_gpu';
import {TransposeProgram} from './transpose_gpu';
import {TransposePackedProgram} from './transpose_packed_gpu';
import * as unary_op from './unaryop_gpu';
import {UnaryOpProgram} from './unaryop_gpu';
import * as unary_packed_op from './unaryop_packed_gpu';
import {UnaryOpPackedProgram} from './unaryop_packed_gpu';
import {UnpackProgram} from './unpack_gpu';
import * as webgl_util from './webgl_util';
// A single timed kernel invocation: the kernel's name plus a promise that
// resolves to its GPU execution time in milliseconds.
type KernelInfo = {
  name: string; query: Promise<number>;
};
// Timer results nest when kernels execute inside other timed scopes.
export type TimerNode = RecursiveArray<KernelInfo>|KernelInfo;
// CPU wall-clock timer used when GPU timer queries are unavailable.
export interface CPUTimerQuery {
  startMs: number;
  // Set by endTimer(); absent while the timer is still running.
  endMs?: number;
}
// Memory info extended with WebGL-specific GPU byte accounting.
export interface WebGLMemoryInfo extends MemoryInfo {
  numBytesInGPU: number;
  unreliable: boolean;
}
// Timing info extended with time spent blocking on GPU uploads/downloads.
export interface WebGLTimingInfo extends TimingInfo {
  uploadWaitMs: number;
  downloadWaitMs: number;
}
// Per-WebGL-version cache of compiled shader binaries, shared across backend
// instances so programs are not recompiled.
const binaryCaches: {[webGLVersion: string]: {[key: string]: GPGPUBinary}} = {};

/**
 * Returns the shared shader-binary cache for the given WebGL version,
 * creating an empty cache on first access.
 */
export function getBinaryCache(webGLVersion: number) {
  if (!(webGLVersion in binaryCaches)) {
    binaryCaches[webGLVersion] = {};
  }
  return binaryCaches[webGLVersion];
}
/**
 * Maps a fused-activation name to the GLSL snippet implementing it,
 * selecting the packed or unpacked variant.
 * Throws for activations with no WebGL implementation.
 */
function mapActivationToShaderProgram(
    activation: Activation, packed = false): string {
  switch (activation) {
    case 'linear':
      return packed ? unary_packed_op.LINEAR : unary_op.LINEAR;
    case 'relu':
      return packed ? unary_packed_op.RELU : unary_op.RELU;
    case 'elu':
      return packed ? unary_packed_op.ELU : unary_op.ELU;
    case 'relu6':
      return packed ? unary_packed_op.RELU6 : unary_op.RELU6;
    case 'prelu':
      // PRELU takes a second input (the weights), hence a binary op snippet.
      return packed ? binaryop_packed_gpu.PRELU : binaryop_gpu.PRELU;
    default:
      throw new Error(`Activation ${
          activation} has not been implemented for the WebGL backend.`);
  }
}
// Empirically determined constant used to determine size threshold for handing
// off execution to the CPU.
const CPU_HANDOFF_SIZE_THRESHOLD = 128;

// Empirically determined constant used to decide the number of MB on GPU
// before we warn about high memory use. The MB are this constant * screen area
// * dpi / 1024 / 1024.
const BEFORE_PAGING_CONSTANT = 600;

/**
 * Estimates how many MB of GPU memory may be allocated before a high
 * memory-use warning is issued, scaled by screen area and pixel density.
 * Falls back to 1024 MB (1 GB) when no `screen` object is available
 * (e.g. Node.js).
 */
function numMBBeforeWarning(): number {
  if (env().global.screen == null) {
    return 1024;  // 1 GB.
  }
  // Read devicePixelRatio off env().global rather than `window` directly so
  // this also works when `screen` is provided on a non-window global (the
  // original referenced `window`, which throws in such environments). In
  // browsers `env().global` is `window`, so behavior is unchanged; default
  // to 1 when the property is absent.
  const dpr = env().global.devicePixelRatio || 1;
  return (env().global.screen.height * env().global.screen.width * dpr) *
      BEFORE_PAGING_CONSTANT / 1024 / 1024;
}

// Empirically determined minimal shared dimension in matmul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
export const MATMUL_SHARED_DIM_THRESHOLD = 1000;
export class MathBackendWebGL extends KernelBackend {
  // Per-tensor bookkeeping: maps data ids to texture/CPU storage state.
  texData: DataStorage<TextureData>;
  // The WebGL context wrapper all programs run against.
  gpgpu: GPGPUContext;
  // Maps data ids that have a pending read operation, to list of subscribers.
  private pendingRead = new WeakMap<DataId, Array<(arr: TypedArray) => void>>();
  // List of data ids that are scheduled for disposal, but are waiting on a
  // pending read operation.
  private pendingDisposal = new WeakSet<DataId>();
  // Used to count the number of 'shallow' sliced tensors that point to the
  // same data id.
  private dataRefCount = new WeakMap<DataId, number>();
  // Running total of bytes held in GPU textures (for memory()).
  private numBytesInGPU = 0;
  // Canvas backing the WebGL context (owned or borrowed).
  private canvas: HTMLCanvasElement|OffscreenCanvas;
  // Root of the timer tree for the outermost time() call; null when idle.
  private programTimersStack: TimerNode[];
  // Timer list for the time() scope currently executing.
  private activeTimers: TimerNode[];
  // Accumulated time spent (including blocking) in uploading data to webgl.
  private uploadWaitMs = 0;
  // Accumulated time spent (including blocking in downloading data from webgl.
  private downloadWaitMs = 0;
  // Lazily-resolved CPU backend for small-tensor forwarding.
  private cpuBackend: KernelBackend;
  // Number of bits of precision of this backend.
  private floatPrecisionValue: 32|16;
  private textureManager: TextureManager;
  // Compiled shader cache; shared per WebGL version when the context is
  // created locally, private otherwise.
  private binaryCache: {[key: string]: GPGPUBinary};
  // True when this backend created (and thus owns) the GL context.
  private gpgpuCreatedLocally: boolean;
  // MB threshold above which a high-memory warning is logged once.
  private numMBBeforeWarning: number;
  private warnedAboutMemory = false;
constructor(gpgpu?: GPGPUContext) {
super();
if (!env().getBool('HAS_WEBGL')) {
throw new Error('WebGL is not supported on this device');
}
if (gpgpu == null) {
const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));
this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
this.gpgpu = new GPGPUContext(gl);
this.canvas = gl.canvas;
this.gpgpuCreatedLocally = true;
} else {
this.gpgpu = gpgpu;
this.binaryCache = {};
this.gpgpuCreatedLocally = false;
this.canvas = gpgpu.gl.canvas;
}
this.textureManager = new TextureManager(this.gpgpu);
this.numMBBeforeWarning = numMBBeforeWarning();
this.texData = new DataStorage(this, ENGINE);
}
numDataIds() {
return this.texData.numDataIds() +
(this.cpuBackend ? this.cpuBackend.numDataIds() : 0) -
this.pendingDeletes;
}
write(values: BackendValues, shape: number[], dtype: DataType): DataId {
if (env().getBool('DEBUG')) {
this.checkNumericalProblems(values);
}
if (dtype === 'complex64' && values != null) {
throw new Error(
`Cannot write to a complex64 dtype. ` +
`Please use tf.complex(real, imag).`);
}
const dataId = {};
this.texData.set(
dataId, {shape, dtype, values, usage: TextureUsage.UPLOAD});
return dataId;
}
move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):
void {
if (env().getBool('DEBUG')) {
this.checkNumericalProblems(values);
}
if (dtype === 'complex64') {
throw new Error(
`Cannot write to a complex64 dtype. ` +
`Please use tf.complex(real, imag).`);
}
this.texData.set(
dataId, {shape, dtype, values, usage: TextureUsage.UPLOAD});
}
  /**
   * Synchronously downloads the values backing `dataId`.
   *
   * Shallow slices are first materialized with a CLONE shader so the flat
   * offset is resolved; complex64 values are reassembled from the real and
   * imaginary component tensors. Downloaded results are cached on the CPU,
   * making repeat reads free.
   */
  readSync(dataId: DataId): BackendValues {
    const texData = this.texData.get(dataId);
    const {values, dtype, complexTensors, slice, shape, isPacked} = texData;
    // A shallow slice points into another tensor's texture; run a CLONE
    // program so the sliced region gets its own bucket, then read that.
    if (slice != null) {
      let program;
      if (isPacked) {
        program = new UnaryOpPackedProgram(shape, unary_op.CLONE);
      } else {
        program = new UnaryOpProgram(shape, unary_op.CLONE);
      }
      const res =
          this.runWebGLProgram(program, [{dataId, shape, dtype}], dtype);
      const data = this.readSync(res.dataId);
      this.disposeData(res.dataId);
      return data;
    }
    // Values already on the CPU: convert dtype and cache.
    if (values != null) {
      return this.convertAndCacheOnCPU(dataId);
    }
    // NOTE(review): reached only when a string bucket has null values —
    // appears to be a defensive no-data path; confirm string tensors always
    // keep their values on the CPU.
    if (dtype === 'string') {
      return values;
    }
    const shouldTimeProgram = this.activeTimers != null;
    let start: number;
    if (shouldTimeProgram) {
      start = util.now();
    }
    let result: Float32Array;
    if (dtype === 'complex64') {
      // Blocking download of both components, then interleave them.
      const realValues = complexTensors.real.dataSync() as Float32Array;
      const imagValues = complexTensors.imag.dataSync() as Float32Array;
      result = mergeRealAndImagArrays(realValues, imagValues);
    } else {
      result = this.getValuesFromTexture(dataId);
    }
    if (shouldTimeProgram) {
      // Attribute the blocking download time to downloadWaitMs.
      this.downloadWaitMs += util.now() - start;
    }
    return this.convertAndCacheOnCPU(dataId, result);
  }
  /**
   * Asynchronously downloads the values backing `dataId`, using a GL fence
   * (and, when supported, a pixel-buffer copy) to avoid blocking the main
   * thread while the GPU finishes.
   *
   * Concurrent reads of the same bucket are coalesced: later callers are
   * queued as subscribers and resolved by the first read. A disposal
   * requested mid-read is honored once the read completes.
   */
  async read(dataId: DataId): Promise<BackendValues> {
    // Piggyback on an in-flight read of the same bucket.
    if (this.pendingRead.has(dataId)) {
      const subscribers = this.pendingRead.get(dataId);
      return new Promise<TypedArray>(resolve => subscribers.push(resolve));
    }
    const texData = this.texData.get(dataId);
    const {values, shape, slice, dtype, complexTensors, isPacked} = texData;
    // Shallow slices are materialized with a CLONE program first (same
    // pattern as readSync).
    if (slice != null) {
      let program;
      if (isPacked) {
        program = new UnaryOpPackedProgram(shape, unary_op.CLONE);
      } else {
        program = new UnaryOpProgram(shape, unary_op.CLONE);
      }
      const res =
          this.runWebGLProgram(program, [{dataId, shape, dtype}], dtype);
      const data = this.read(res.dataId);
      this.disposeData(res.dataId);
      return data;
    }
    // Values already on the CPU: just convert dtype and cache.
    if (values != null) {
      return this.convertAndCacheOnCPU(dataId);
    }
    if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
        env().getNumber('WEBGL_VERSION') === 2) {
      throw new Error(
          `tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` +
          `WEBGL_VERSION=2 not yet supported.`);
    }
    let buffer = null;
    let tmpDownloadTarget: TensorInfo;
    if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
      // Possibly copy the texture into a buffer before inserting a fence.
      tmpDownloadTarget = this.decode(dataId);
      const tmpData = this.texData.get(tmpDownloadTarget.dataId);
      buffer = this.gpgpu.createBufferFromTexture(
          tmpData.texture, ...tex_util.getDenseTexShape(shape));
    }
    // Registering the subscriber list marks this read as in-flight.
    this.pendingRead.set(dataId, []);
    if (dtype !== 'complex64') {
      // Create a fence and wait for it to resolve.
      await this.gpgpu.createAndWaitForFence();
    }
    // Download the values from the GPU.
    let vals: Float32Array;
    if (dtype === 'complex64') {
      // Component tensors handle their own downloads via .data().
      const ps = await Promise.all(
          [complexTensors.real.data(), complexTensors.imag.data()]);
      const realValues = ps[0];
      const imagValues = ps[1];
      vals = mergeRealAndImagArrays(
          realValues as Float32Array, imagValues as Float32Array);
    } else if (buffer == null) {
      vals = this.getValuesFromTexture(dataId);
    } else {
      const size = util.sizeFromShape(shape);
      vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
    }
    if (tmpDownloadTarget != null) {
      this.disposeData(tmpDownloadTarget.dataId);
    }
    const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
    const subscribers = this.pendingRead.get(dataId);
    this.pendingRead.delete(dataId);
    // Notify all pending reads.
    subscribers.forEach(resolve => resolve(dTypeVals));
    // Honor a disposal that was requested while this read was in flight.
    if (this.pendingDisposal.has(dataId)) {
      this.pendingDisposal.delete(dataId);
      this.disposeData(dataId);
      this.pendingDeletes--;
    }
    return dTypeVals;
  }
private checkNumericalProblems(values: BackendValues): void {
if (values == null) {
return;
}
for (let i = 0; i < values.length; i++) {
const num = values[i] as number;
if (!webgl_util.canBeRepresented(num)) {
if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
throw Error(
`The value ${num} cannot be represented with your ` +
`current settings. Consider enabling float32 rendering: ` +
`'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`);
}
throw Error(`The value ${num} cannot be represented on this device.`);
}
}
}
  /**
   * Blocking download of `dataId`'s texture contents as float32 values.
   *
   * When float downloads are enabled, the texture is decoded into a dense
   * packed layout and read directly. Otherwise a shader byte-encodes each
   * float and the bytes are reinterpreted as float32 on the CPU.
   */
  private getValuesFromTexture(dataId: DataId): Float32Array {
    const {shape, dtype, isPacked} = this.texData.get(dataId);
    const size = util.sizeFromShape(shape);
    if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
      // Decode to a dense texture, download, and trim any padding texels.
      const tmpTarget = this.decode(dataId);
      const tmpData = this.texData.get(tmpTarget.dataId);
      const vals = this.gpgpu
                       .downloadMatrixFromPackedTexture(
                           tmpData.texture, ...tex_util.getDenseTexShape(shape))
                       .subarray(0, size);
      this.disposeData(tmpTarget.dataId);
      return vals;
    }
    // Fallback path for contexts that cannot read float textures: encode
    // each float into bytes with a shader, then decode on the CPU.
    const shouldUsePackedProgram =
        env().getBool('WEBGL_PACK') && isPacked === true;
    const outputShape =
        shouldUsePackedProgram ? webgl_util.getShapeAs3D(shape) : shape;
    const program = shouldUsePackedProgram ?
        new EncodeFloatPackedProgram(outputShape as [number, number, number]) :
        new EncodeFloatProgram(outputShape);
    const output = this.runWebGLProgram(
        program, [{shape: outputShape, dtype, dataId}], 'float32');
    const tmpData = this.texData.get(output.dataId);
    const vals =
        this.gpgpu
            .downloadByteEncodedFloatMatrixFromOutputTexture(
                tmpData.texture, tmpData.texShape[0], tmpData.texShape[1])
            .subarray(0, size);
    this.disposeData(output.dataId);
    return vals;
  }
  /**
   * Times the kernels launched by `f`, reporting summed kernel GPU time
   * (when timer queries are reliable) plus upload/download blocking time
   * accumulated since the last report. Nested time() calls are supported by
   * stacking timer lists.
   */
  async time(f: () => void): Promise<WebGLTimingInfo> {
    const oldActiveTimers = this.activeTimers;
    const newActiveTimers: TimerNode[] = [];
    let outerMostTime = false;
    if (this.programTimersStack == null) {
      this.programTimersStack = newActiveTimers;
      outerMostTime = true;
    } else {
      // Nested call: record our timers as a child of the current scope.
      this.activeTimers.push(newActiveTimers);
    }
    this.activeTimers = newActiveTimers;
    f();
    // needing to split these up because util.flatten only accepts certain types
    const flattenedActiveTimerQueries =
        util.flatten(this.activeTimers.map((d: KernelInfo) => d.query))
            .filter(d => d != null);
    const flattenedActiveTimerNames =
        util.flatten(this.activeTimers.map((d: KernelInfo) => d.name))
            .filter(d => d != null);
    this.activeTimers = oldActiveTimers;
    if (outerMostTime) {
      this.programTimersStack = null;
    }
    const res: WebGLTimingInfo = {
      uploadWaitMs: this.uploadWaitMs,
      downloadWaitMs: this.downloadWaitMs,
      kernelMs: null,
      wallMs: null  // will be filled by the engine
    };
    if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
      // Resolve every kernel's timer query and sum them.
      const kernelMs = await Promise.all(flattenedActiveTimerQueries);
      res['kernelMs'] = util.sum(kernelMs);
      res['getExtraProfileInfo'] = () =>
          kernelMs.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d}))
              .map(d => `${d.name}: ${d.ms}`)
              .join(', ');
    } else {
      res['kernelMs'] = {
        error: 'WebGL query timers are not supported in this environment.'
      };
    }
    // The wait counters are cumulative; reset them after reporting.
    this.uploadWaitMs = 0;
    this.downloadWaitMs = 0;
    return res;
  }
memory(): WebGLMemoryInfo {
return {unreliable: false, numBytesInGPU: this.numBytesInGPU} as
WebGLMemoryInfo;
}
private startTimer(): WebGLQuery|CPUTimerQuery {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
return this.gpgpu.beginQuery();
}
return {startMs: util.now(), endMs: null};
}
private endTimer(query: WebGLQuery|CPUTimerQuery): WebGLQuery|CPUTimerQuery {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
this.gpgpu.endQuery();
return query;
}
(query as CPUTimerQuery).endMs = util.now();
return query;
}
private async getQueryTime(query: WebGLQuery|CPUTimerQuery): Promise<number> {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
return this.gpgpu.waitForQueryAndGetTime(query as WebGLQuery);
}
const timerQuery = query as CPUTimerQuery;
return timerQuery.endMs - timerQuery.startMs;
}
  // Count of buckets whose disposal is deferred until a pending read ends.
  private pendingDeletes = 0;
  /**
   * Disposes the bucket for `dataId`, releasing its texture and any complex
   * component tensors. If a read is in flight the disposal is deferred and
   * performed when the read completes (see read()).
   */
  disposeData(dataId: DataId): void {
    // Already queued for deferred disposal: nothing more to do.
    if (this.pendingDisposal.has(dataId)) {
      return;
    }
    // A read is in flight; defer disposal until it finishes.
    if (this.pendingRead.has(dataId)) {
      this.pendingDisposal.add(dataId);
      this.pendingDeletes++;
      return;
    }
    // No-op if already disposed.
    if (!this.texData.has(dataId)) {
      return;
    }
    this.releaseGPUData(dataId);
    // complex64 buckets own clones of their component tensors (see
    // complex()); dispose them along with the bucket.
    const {complexTensors} = this.texData.get(dataId);
    if (complexTensors != null) {
      complexTensors.real.dispose();
      complexTensors.imag.dispose();
    }
    this.texData.delete(dataId);
  }
private releaseGPUData(dataId: DataId): void {
const {texture, dtype, texShape, usage, isPacked, slice} =
this.texData.get(dataId);
const key = slice && slice.origDataId || dataId;
const refCount = this.dataRefCount.get(key);
if (refCount > 1) {
this.dataRefCount.set(key, refCount - 1);
} else {
this.dataRefCount.delete(key);
if (texture != null) {
this.numBytesInGPU -= this.computeBytes(texShape, dtype);
this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
}
}
const texData = this.texData.get(dataId);
texData.texture = null;
texData.texShape = null;
texData.isPacked = false;
texData.slice = null;
}
getTexture(dataId: DataId): WebGLTexture {
this.uploadToGPU(dataId);
return this.texData.get(dataId).texture;
}
/**
* Returns internal information for the specific data bucket. Used in unit
* tests.
*/
getDataInfo(dataId: DataId): TextureData {
return this.texData.get(dataId);
}
private getCPUBackend(): KernelBackend|null {
if (!env().getBool('WEBGL_CPU_FORWARD')) {
return null;
}
if (this.cpuBackend == null) {
this.cpuBackend = ENGINE.findBackend('cpu');
}
return this.cpuBackend;
}
/*
Tests whether all the inputs to an op are small and on the CPU. This heuristic
determines when it would be faster to execute a kernel on the CPU. WebGL
kernels opt into running this check and forwarding when appropriate.
TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
sustainable strategy for optimizing backend execution of ops.
*/
private shouldExecuteOnCPU(
inputs: Tensor[], sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean {
return this.getCPUBackend() != null &&
inputs.every(
input => this.texData.get(input.dataId).texture == null &&
input.size < sizeThreshold);
}
getGPGPUContext(): GPGPUContext {
return this.gpgpu;
}
complex<T extends Tensor>(real: T, imag: T): T {
const result = this.makeOutput(real.shape, 'complex64');
const resultData = this.texData.get(result.dataId);
// The backend owns the reference to the underlying real and imaginary
// clones. These will explicitly get disposed when the complex tensor is
// disposed.
resultData.complexTensors = {
real: ENGINE.keep(real.clone()),
imag: ENGINE.keep(imag.clone())
};
return result as T;
}
real<T extends Tensor>(input: T): T {
const resultData = this.texData.get(input.dataId);
return resultData.complexTensors.real.clone() as T;
}
imag<T extends Tensor>(input: T): T {
const resultData = this.texData.get(input.dataId);
return resultData.complexTensors.imag.clone() as T;
}
slice<T extends Tensor>(x: T, begin: number[], size: number[]): T {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.slice(x, begin, size);
}
// Short-circuit computation if the slice is zero-sized.
if (util.sizeFromShape(size) === 0) {
return tensor([], size, x.dtype) as T;
}
const {isPacked} = this.texData.get(x.dataId);
const isContinous = slice_util.isSliceContinous(x.shape, begin, size);
if (isPacked || !isContinous) {
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new SlicePackedProgram(size) :
new SliceProgram(size);
const customSetup = program.getCustomSetupFunc(begin);
return this.compileAndRun(program, [x], null, customSetup);
}
this.uploadToGPU(x.dataId);
return this.shallowSlice(x, begin, size) as T;
}
  /**
   * Creates a zero-copy slice: the result shares the original texture and
   * records only a flat offset into it. The shared bucket is ref-counted so
   * the texture is released only when the last view is disposed.
   */
  private shallowSlice(x: Tensor, begin: number[], size: number[]): Tensor {
    const xTexData = this.texData.get(x.dataId);
    const t = this.makeOutput(size, x.dtype);
    const newTexData = this.texData.get(t.dataId);
    // Copy texture data from the original tensor.
    Object.assign(newTexData, xTexData);
    newTexData.shape = size;
    newTexData.dtype = x.dtype;
    let flatOffset = slice_util.computeFlatOffset(begin, x.strides);
    if (xTexData.slice) {
      // We are slicing an already sliced tensor, so we have to accumulate
      // the offset.
      flatOffset += xTexData.slice.flatOffset;
    }
    newTexData.slice = {
      flatOffset,
      // Point to the original dataId, which is used to do ref counting.
      origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
    };
    // Increase the ref count for that data bucket.
    // (The first slice initializes the count to 2: original plus this view.)
    const refCount = this.dataRefCount.get(newTexData.slice.origDataId) || 1;
    this.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
    return t;
  }
stridedSlice<T extends Tensor>(
x: T, begin: number[], end: number[], strides: number[]): T {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.stridedSlice(x, begin, end, strides);
}
const outShape = slice_util.computeOutShape(begin, end, strides);
if (outShape.some(axis => axis === 0)) {
return tensor([], outShape) as T;
}
const program = new StridedSliceProgram(begin, strides, outShape);
return this.compileAndRun(program, [x]);
}
reverse<T extends Tensor>(x: T, axis: number[]): T {
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new ReversePackedProgram(x.shape, axis) :
new ReverseProgram(x.shape, axis);
return this.compileAndRun(program, [x]);
}
  /**
   * Concatenates tensors along `axis`. complex64 inputs are concatenated
   * component-wise; small CPU-resident inputs go to the CPU backend; concats
   * with more inputs than the shader texture limit are split recursively;
   * otherwise a (packed) concat program runs on the GPU.
   */
  concat(tensors: Tensor[], axis: number): Tensor {
    if (tensors[0].dtype === 'complex64') {
      // Concat real and imaginary parts separately, then recombine.
      const reals = tensors.map((t) => real(t));
      const imags = tensors.map((t) => imag(t));
      return complex(this.concat(reals, axis), this.concat(imags, axis));
    }
    if (this.shouldExecuteOnCPU(tensors)) {
      return this.cpuBackend.concat(tensors, axis);
    }
    if (tensors.length === 1) {
      return tensors[0];
    }
    if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
      // Too many inputs for one shader invocation: divide and conquer.
      const midIndex = Math.floor(tensors.length / 2);
      const leftSide = this.concat(tensors.slice(0, midIndex), axis);
      const rightSide = this.concat(tensors.slice(midIndex), axis);
      return this.concat([leftSide, rightSide], axis);
    }
    if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && tensors[0].rank > 1) {
      const program = new ConcatPackedProgram(tensors.map(t => t.shape), axis);
      return this.compileAndRun(program, tensors);
    }
    // Any concat of n-dimensional tensors across any axis can be reduced to
    // a concatenation of two-dimensional tensors across the axis 1 by first
    // partitioning the axes of the original tensors into those less than the
    // axis to be concatenated and the rest. Then reshape the tensors
    // into a two-dimensional tensor by collapsing these two sets of axes and
    // concatenate the resulting matrices across the axis 1, finally reshaping
    // the result to have the proper shape.
    const outShape = computeOutShape(tensors.map(t => t.shape), axis);
    const tensors2D =
        tensors.map(t => t.as2D(-1, sizeFromShape(t.shape.slice(axis))));
    const program = new ConcatProgram(tensors2D.map(t => t.shape));
    const res: Tensor = this.compileAndRun(program, tensors2D);
    return res.reshape(outShape);
  }
neg<T extends Tensor>(x: T): T {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.neg(x);
}
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
return this.packedUnaryOp(x, unary_op.NEG, x.dtype) as T;
}
const program = new UnaryOpProgram(x.shape, unary_op.NEG);
return this.compileAndRun(program, [x]);
}
  /**
   * Batched matrix multiply a @ b with optional transposes.
   *
   * Matrix-vector products with a large shared dimension are rewritten as
   * mul().sum(), which parallelizes better on the GPU (see
   * MATMUL_SHARED_DIM_THRESHOLD); otherwise the packed matmul shader runs.
   */
  batchMatMul(
      a: Tensor3D, b: Tensor3D, transposeA: boolean,
      transposeB: boolean): Tensor3D {
    const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
    const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
    const sharedDim = transposeA ? a.shape[1] : a.shape[2];
    const [batch, , ] = a.shape;
    // Since the matrices are vectors, it is faster to call mul().sum()
    // because sum() is O(sqrt(N)) due to divide-and-conquer.
    if ((outerShapeA === 1 || outerShapeB === 1) &&
        sharedDim > MATMUL_SHARED_DIM_THRESHOLD) {
      // Materialize the transposes so the reshape below is valid.
      if (transposeA) {
        a = a.transpose([0, 2, 1]);
      }
      if (transposeB) {
        b = b.transpose([0, 2, 1]);
      }
      // Shape the vector operand so it broadcasts against the other operand,
      // then reduce over the shared dimension.
      const a3D = outerShapeB === 1 ? a : a.as3D(batch, sharedDim, 1);
      const axis = outerShapeB === 1 ? 2 : 1;
      const b3D = outerShapeB === 1 ? b.as3D(batch, 1, sharedDim) : b;
      return this.multiply(a3D, b3D).sum(axis, true /* keepDims */);
    }
    const dtype = upcastType(a.dtype, b.dtype);
    const program = new MatMulPackedProgram(
        a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB);
    return this.compileAndRun<Tensor3D>(program, [a, b], dtype);
  }
fusedBatchMatMul(
{a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
FusedBatchMatMulConfig): Tensor3D {
const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
const [batch, , ] = a.shape;
const dtype = upcastType(a.dtype, b.dtype);
const hasBias = bias != null;
const hasPreluActivationWeights = preluActivationWeights != null;
const fusedActivation =
activation ? mapActivationToShaderProgram(activation, true) : null;
const program = new MatMulPackedProgram(
a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
hasBias, fusedActivation, hasPreluActivationWeights);
const inputs: TensorInfo[] = [a, b];
if (bias) {
inputs.push(bias);
}
if (preluActivationWeights) {
inputs.push(preluActivationWeights);
}
return this.compileAndRun<Tensor3D>(program, inputs, dtype);
}
multiply(a: Tensor, b: Tensor): Tensor {
if (a.dtype === 'complex64') {
const aData = this.texData.get(a.dataId);
const bData = this.texData.get(b.dataId);
const realProgram = new BinaryOpComplexProgram(
binaryop_complex_gpu.COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
const imagProgram = new BinaryOpComplexProgram(
binaryop_complex_gpu.COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
const inputs = [
this.makeComplexComponentTensorInfo(a, aData.complexTensors.real),
this.makeComplexComponentTensorInfo(a, aData.complexTensors.imag),
this.makeComplexComponentTensorInfo(b, bData.complexTensors.real),
this.makeComplexComponentTensorInfo(b, bData.complexTensors.imag)
];
const real = this.compileAndRun<Tensor>(realProgram, inputs);
const imag = this.compileAndRun<Tensor>(imagProgram, inputs);
const complex = this.complex(real, imag);
real.dispose();
imag.dispose();
return complex;
}
if (this.shouldExecuteOnCPU([a, b])) {
return this.cpuBackend.multiply(a, b);
}
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_gpu.MUL, a.dtype);
}
const program = new BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape);
return this.compileAndRun(program, [a, b], a.dtype);
}
batchNormalization(
x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D,
varianceEpsilon: number, scale?: Tensor4D|Tensor1D,
offset?: Tensor4D|Tensor1D): Tensor4D {
const inputs = [x, mean, variance];
let offsetShape = null;
if (offset != null) {
offsetShape = offset.shape;
inputs.push(offset);
}
let scaleShape = null;
if (scale != null) {
scaleShape = scale.shape;
inputs.push(scale);
}
if (env().getBool('WEBGL_PACK_NORMALIZATION')) {
const batchNormPackedProgram = new BatchNormPackedProgram(
x.shape, mean.shape, variance.shape, offsetShape, scaleShape,
varianceEpsilon);
return this.compileAndRun<Tensor4D>(batchNormPackedProgram, inputs);
}
const batchNormProgram = new BatchNormProgram(
x.shape, mean.shape, variance.shape, offsetShape, scaleShape,
varianceEpsilon);
return this.compileAndRun(batchNormProgram, inputs);
}
localResponseNormalization4D(
x: Tensor4D, radius: number, bias: number, alpha: number,
beta: number): Tensor4D {
const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
new LRNPackedProgram(x.shape, radius, bias, alpha, beta) :
new LRNProgram(x.shape, radius, bias, alpha, beta);
return this.compileAndRun(program, [x]);
}
LRNGrad(
dy: Tensor4D, inputImage: Tensor4D, outputImage: Tensor4D,
depthRadius: number, bias: number, alpha: number,
beta: number): Tensor4D {
const program =
new LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta);
return this.compileAndRun(program, [inputImage, outputImage, dy]);
}
tile<T extends Tensor>(x: T, reps: number[]): T {
if (x.dtype === 'string') {
const data = this.readSync(x.dataId) as Uint8Array[];
const decodedData = data.map(d => util.decodeString(d));
const buf = buffer(x.shape, x.dtype, decodedData);
return tile(buf, reps) as T;
}
const program = new TileProgram(x.shape, reps);
return this.compileAndRun(program, [x]);
}
pad<T extends Tensor>(
x: T, paddings: Array<[number, number]>, constantValue: number): T {
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new PadPackedProgram(x.shape, paddings, constantValue) :
new PadProgram(x.shape, paddings, constantValue);
return this.compileAndRun(program, [x]);
}
transpose<T extends Tensor>(x: T, perm: number[]): T {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.transpose(x, perm);
}
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new TransposePackedProgram(x.shape, perm) :
new TransposeProgram(x.shape, perm);
return this.compileAndRun(program, [x]);
}
gather<T extends Tensor>(x: T, indices: Tensor1D, axis: number): T {
if (this.shouldExecuteOnCPU([x, indices])) {
return this.cpuBackend.gather(x, indices, axis);
}
const program = new GatherProgram(x.shape, indices.size, axis);
return this.compileAndRun(program, [x, indices]);
}
batchToSpaceND<T extends Tensor>(
x: T, blockShape: number[], crops: number[][]): T {
util.assert(
x.rank <= 4,
() => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
'implemented yet');
const prod = blockShape.reduce((a, b) => a * b);
const reshaped = array_ops_util.getReshaped(x.shape, blockShape, prod);
const permuted =
array_ops_util.getPermuted(reshaped.length, blockShape.length);
const reshapedPermuted =
array_ops_util.getReshapedPermuted(x.shape, blockShape, prod);
const sliceBeginCoords =
array_ops_util.getSliceBeginCoords(crops, blockShape.length);
const sliceSize =
array_ops_util.getSliceSize(reshapedPermuted, crops, blockShape.length);
return x.reshape(reshaped)
.transpose(permuted)
.reshape(reshapedPermuted)
.slice(sliceBeginCoords, sliceSize) as T;
}
spaceToBatchND<T extends Tensor>(
x: T, blockShape: number[], paddings: Array<[number, number]>): T {
util.assert(
x.rank <= 4,
() => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
'implemented yet');
const prod = blockShape.reduce((a, b) => a * b);
const completePaddings: Array<[number, number]> = [[0, 0]];
completePaddings.push(...paddings);
for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
completePaddings.push([0, 0]);
}
const paddedX = x.pad(completePaddings);
const reshapedPaddedShape =
array_ops_util.getReshaped(paddedX.shape, blockShape, prod, false);
const permutedReshapedPaddedPermutation = array_ops_util.getPermuted(
reshapedPaddedShape.length, blockShape.length, false);
const flattenShape = array_ops_util.getReshapedPermuted(
paddedX.shape, blockShape, prod, false);
return paddedX.reshape(reshapedPaddedShape)
.transpose(permutedReshapedPaddedPermutation)
.reshape(flattenShape) as T;
}
private reduce(
x: Tensor2D, reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod',
dtype: DataType): Tensor2D {
const batchSize = x.shape[0];
const inSize = x.shape[1];
const windowSize = reduce_util.computeOptimalWindowSize(inSize);
const reduceInfo = {windowSize, inSize, batchSize};
const program = new ReduceProgram(reduceInfo, reduceType);
const output = this.compileAndRun<Tensor2D>(program, [x], dtype);
// No need to run another GPGPU program.
if (output.shape[1] === 1) {
return output;
}
return this.reduce(output, reduceType, dtype);
}
  // Computes arg-min/arg-max over the inner dimension of a 2D tensor via
  // repeated windowed reduction. `bestIndicesA` carries the winning indices
  // from the previous pass (null on the first pass); recursion ends once a
  // single column of indices remains.
  private argReduce(
      x: Tensor2D, reduceType: 'max'|'min',
      bestIndicesA: Tensor2D = null): Tensor2D {
    let batchSize = x.shape[0];
    let inSize = x.shape[1];
    // After the first pass the reduction size is dictated by the index
    // tensor, not by the original input.
    if (bestIndicesA != null) {
      batchSize = bestIndicesA.shape[0];
      inSize = bestIndicesA.shape[1];
    }
    const windowSize = reduce_util.computeOptimalWindowSize(inSize);
    const reduceInfo = {windowSize, inSize, batchSize};
    // The final flag tells the shader whether this is the first pass (no
    // candidate indices yet).
    const program =
        new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
    const inputs = [x];
    if (bestIndicesA != null) {
      inputs.push(bestIndicesA);
    }
    // Indices are always int32 regardless of x's dtype.
    const output = this.compileAndRun<Tensor2D>(program, inputs, 'int32');
    // No need to run another GPGPU program.
    if (output.shape[1] === 1) {
      return output;
    }
    // Note: recursion keeps passing the ORIGINAL x so the shader can compare
    // values at the candidate indices.
    return this.argReduce(x, reduceType, output);
  }
  // Packed-texture variant of argReduce, operating over the last dimension of
  // a tensor of any rank. Each pass shrinks the last axis by `windowSize`;
  // the reduction is complete when the output's rank drops below x's rank.
  private argReducePacked(
      x: Tensor, reduceType: 'max'|'min', bestIndicesA: Tensor = null): Tensor {
    // After the first pass, sizing comes from the candidate-index tensor.
    const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
    const inSize = inShape[inShape.length - 1];
    const windowSize = reduce_util.computeOptimalWindowSize(inSize);
    const program = new ArgMinMaxPackedProgram(
        inShape, windowSize, reduceType, bestIndicesA == null);
    const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
    // Indices are always int32.
    const output = this.compileAndRun<Tensor>(program, inputs, 'int32');
    // Equal rank means the last axis has not collapsed yet; keep reducing
    // with the freshly computed candidate indices.
    if (output.rank === x.rank) {
      return this.argReducePacked(x, reduceType, output);
    }
    return output;
  }
sum(x: Tensor, axes: number[]): Tensor {
axis_util.assertAxesAreInnerMostDims('sum', axes, x.rank);
const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
const outputDType = sumOutType(x.dtype);
return this.reduce(a2D, 'sum', outputDType).reshape(outShape);
}
prod(x: Tensor, axes: number[]): Tensor {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.prod(x, axes);
}
const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
const outputDType = sumOutType(x.dtype);
return this.reduce(a2D, 'prod', outputDType).reshape(outShape);
}
  /**
   * Sums slices of `x` along axis 0 into `numSegments` output slices, as
   * directed by `segmentIds`.
   */
  unsortedSegmentSum<T extends Tensor>(
      x: T, segmentIds: Tensor1D, numSegments: number): Tensor {
    let axis = 0;
    // Move the segmented axis to the innermost position so the 2D segment
    // shader can reduce over it; null means it is already innermost.
    const permutation = axis_util.getAxesPermutation([axis], x.rank);
    let permutedX = x;
    if (permutation != null) {
      permutedX = x.transpose(permutation);
      axis = axis_util.getInnerMostAxes(1, x.rank)[0];
    }
    const outShape =
        segment_util.computeOutShape(permutedX.shape, axis, numSegments);
    const inSize = util.sizeFromShape([permutedX.shape[axis]]);
    // Flatten to 2D: one row per output element, reducing over columns.
    const a2D = permutedX.as2D(-1, inSize);
    // Accumulation dtype comes from sumOutType, same as `sum`.
    const outputDType = sumOutType(x.dtype);
    let result =
        this.segOpCompute(
            a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments)
            .reshape(outShape);
    if (permutation != null) {
      // Undo the earlier transpose to restore the caller's axis order.
      result = result.transpose(axis_util.getUndoAxesPermutation(permutation));
    }
    return result;
  }
  // Runs one pass of the windowed segment-op shader, recursing until the
  // inner dimension has collapsed to exactly `numSegments` columns.
  private segOpCompute(
      x: Tensor2D, segOpType: 'unsortedSegmentSum', segmentIds: Tensor1D,
      dtype: DataType, numSegments: number): Tensor2D {
    const batchSize = x.shape[0];
    const inSize = x.shape[1];
    const windowSize =
        segment_util.segOpComputeOptimalWindowSize(inSize, numSegments);
    const segOpInfo = {windowSize, inSize, batchSize, numSegments};
    const program = new SegmentOpProgram(segOpInfo, segOpType);
    const output =
        this.compileAndRun<Tensor2D>(program, [x, segmentIds], dtype);
    // No need to run another GPGPU program.
    if (output.shape[1] === numSegments) {
      return output;
    }
    // NOTE(review): rebuilds segment ids for the partially-reduced tensor as
    // repeated [0, numSegments) runs — assumes the shader emits partial sums
    // in that grouped layout; confirm against SegmentOpProgram.
    segmentIds = range(0, numSegments).tile([inSize / windowSize]);
    return this.segOpCompute(output, segOpType, segmentIds, dtype, numSegments);
  }
private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'):
Tensor {
const axes = [axis];
axis_util.assertAxesAreInnerMostDims(
'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes,
x.rank);
if (!env().getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) {
const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
return this.argReduce(a2D, reduceType).reshape(outShape);
}
return this.argReducePacked(x, reduceType);
}
  /** Returns the indices of the minimum values of `x` along `axis`. */
  argMin(x: Tensor, axis: number): Tensor {
    return this.argMinMaxReduce(x, axis, 'min');
  }
  /** Returns the indices of the maximum values of `x` along `axis`. */
  argMax(x: Tensor, axis: number): Tensor {
    return this.argMinMaxReduce(x, axis, 'max');
  }
cumsum(x: Tensor, axis: number, exclusive: boolean, reverse: boolean):
Tensor {
if (axis !== x.rank - 1) {
throw new Error(
`WebGL cumsum shader expects an inner-most axis=${x.rank - 1} ` +
`but got axis=${axis}`);
}
const program = new CumSumProgram(x.shape, exclusive, reverse);
return this.compileAndRun(program, [x]);
}
equal(a: Tensor, b: Tensor): Tensor {
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_packed_gpu.EQUAL, 'bool');
}
const program = new BinaryOpProgram(binaryop_gpu.EQUAL, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
notEqual(a: Tensor, b: Tensor): Tensor {
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_packed_gpu.NOT_EQUAL, 'bool');
}
const program =
new BinaryOpProgram(binaryop_gpu.NOT_EQUAL, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
less(a: Tensor, b: Tensor): Tensor {
if (this.shouldExecuteOnCPU([a, b])) {
return this.cpuBackend.less(a, b);
}
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS, 'bool');
}
const program = new BinaryOpProgram(binaryop_gpu.LESS, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
lessEqual(a: Tensor, b: Tensor): Tensor {
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS_EQUAL, 'bool');
}
const program =
new BinaryOpProgram(binaryop_gpu.LESS_EQUAL, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
greater(a: Tensor, b: Tensor): Tensor {
if (this.shouldExecuteOnCPU([a, b])) {
return this.cpuBackend.greater(a, b);
}
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_packed_gpu.GREATER, 'bool');
}
const program = new BinaryOpProgram(binaryop_gpu.GREATER, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
greaterEqual(a: Tensor, b: Tensor): Tensor {
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(
a, b, binaryop_packed_gpu.GREATER_EQUAL, 'bool');
}
const program =
new BinaryOpProgram(binaryop_gpu.GREATER_EQUAL, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
logicalNot<T extends Tensor>(x: T): T {
const program = new UnaryOpProgram(x.shape, unary_op.LOGICAL_NOT);
return this.compileAndRun(program, [x]);
}
logicalAnd(a: Tensor, b: Tensor): Tensor {
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_AND, 'bool');
}
const program =
new BinaryOpProgram(binaryop_gpu.LOGICAL_AND, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
logicalOr(a: Tensor, b: Tensor): Tensor {
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_OR, 'bool');
}
const program =
new BinaryOpProgram(binaryop_gpu.LOGICAL_OR, a.shape, b.shape);
return this.compileAndRun(program, [a, b], 'bool');
}
select(condition: Tensor, a: Tensor, b: Tensor): Tensor {
const program = new SelectProgram(condition.rank, a.shape, a.rank);
return this.compileAndRun(
program, [condition, a, b], upcastType(a.dtype, b.dtype));
}
where(condition: Tensor): Tensor2D {
warn(
'tf.where() in webgl locks the UI thread. ' +
'Call tf.whereAsync() instead');
const condVals = condition.dataSync();
return whereImpl(condition.shape, condVals);
}
topk<T extends Tensor>(x: T, k: number, sorted: boolean): [T, T] {
const xVals = x.dataSync();
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
}
min(x: Tensor, axes: number[]): Tensor {
axis_util.assertAxesAreInnerMostDims('min', axes, x.rank);
const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
return this.reduce(a2D, 'min', a2D.dtype).reshape(outShape);
}
minimum(a: Tensor, b: Tensor): Tensor {
if (this.shouldExecuteOnCPU([a, b])) {
return this.cpuBackend.minimum(a, b);
}
const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
new BinaryOpPackedProgram(binaryop_packed_gpu.MIN, a.shape, b.shape) :
new BinaryOpProgram(binaryop_gpu.MIN, a.shape, b.shape);
return this.compileAndRun(program, [a, b]);
}
mod(a: Tensor, b: Tensor): Tensor {
const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
new BinaryOpPackedProgram(binaryop_packed_gpu.MOD, a.shape, b.shape) :
new BinaryOpProgram(binaryop_gpu.MOD, a.shape, b.shape);
return this.compileAndRun(program, [a, b]);
}
max(x: Tensor, axes: number[]): Tensor {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.max(x, axes);
}
axis_util.assertAxesAreInnerMostDims('max', axes, x.rank);
const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape);
}
maximum(a: Tensor, b: Tensor): Tensor {
if (this.shouldExecuteOnCPU([a, b])) {
return this.cpuBackend.maximum(a, b);
}
const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
new BinaryOpPackedProgram(binaryop_packed_gpu.MAX, a.shape, b.shape) :
new BinaryOpProgram(binaryop_gpu.MAX, a.shape, b.shape);
return