// @tensorflow/tfjs-core
// Version:
// Hardware-accelerated JavaScript library for machine intelligence
// 980 lines • 120 kB
// JavaScript
"use strict";
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// TypeScript-emitted helper (tslib `__extends`): makes constructor `d` inherit
// from constructor `b`, both its static members and its prototype chain.
// An `__extends` already present on `this` is reused instead of redefined.
var __extends = (this && this.__extends) || (function () {
var extendStatics = function (d, b) {
// On first call, replace this function with the best available strategy for
// copying statics: Object.setPrototypeOf, then `__proto__` assignment, then a
// manual copy of `b`'s own enumerable properties onto `d`.
extendStatics = Object.setPrototypeOf ||
({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||
function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; };
return extendStatics(d, b);
};
return function (d, b) {
extendStatics(d, b);
// Intermediate constructor so that `d.prototype.constructor === d`.
function __() { this.constructor = d; }
// A null base yields a prototype with no parent (Object.create(null)).
d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
};
})();
// TypeScript-emitted helper (tslib `__awaiter`): runs `generator` (the compiled
// body of an `async` function) and returns a promise of type P, which falls
// back to the global Promise when P is not supplied.
// Each yielded value is wrapped with `new P(...)`:
//   - its fulfillment resumes the generator via next();
//   - its rejection resumes it via throw().
// The returned promise resolves with the generator's final value, or rejects
// with any error it throws. An `__awaiter` already on `this` is reused.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
return new (P || (P = Promise))(function (resolve, reject) {
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Resolve when the generator is done; otherwise await the yielded value.
function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); }
// The first step runs synchronously inside the promise executor.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
// TypeScript-emitted helper (tslib `__generator`): emulates a generator object
// for ES5 output.
// `body` is a state machine. It receives the state object `_`:
//   - `_.label` is the current resume point;
//   - `_.sent()` returns the value passed in by the caller, or rethrows it
//     when it came from throw().
// It returns an instruction `[opcode, value]`. Opcodes, per the inline markers
// used in this file and the tslib implementation:
//   0 next, 1 throw, 2 return, 3 break (jump to label), 4 yield,
//   5 yield* (delegate to an inner iterator), 6 caught exception,
//   7 end of finally.
// `_.trys` holds the active try regions as [tryLoc, catchLoc, finallyLoc,
// endLoc]. A `__generator` already on `this` is reused.
var __generator = (this && this.__generator) || function (thisArg, body) {
var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
function verb(n) { return function (v) { return step([n, v]); }; }
function step(op) {
// `f` guards against re-entrant calls while the body is running.
if (f) throw new TypeError("Generator is already executing.");
// `_` is set to 0 once the generator has completed, ending the loop.
while (_) try {
// If delegating (yield*), forward the operation to the inner iterator `y`.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
if (y = 0, t) op = [op[0] & 2, t.value];
switch (op[0]) {
case 0: case 1: t = op; break;
case 4: _.label++; return { value: op[1], done: false };
case 5: _.label++; y = op[1]; op = [0]; continue;
case 7: op = _.ops.pop(); _.trys.pop(); continue;
default:
// Route break/return/exception through the innermost enclosing try region.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
if (t[2]) _.ops.pop();
_.trys.pop(); continue;
}
op = body.call(thisArg, _);
} catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
// Opcodes 1 (throw) and 6 (uncaught exception) escape as real exceptions.
if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
}
};
Object.defineProperty(exports, "__esModule", { value: true });
// Import webgl flags.
require("./flags_webgl");
var device_util = require("../../device_util");
var engine_1 = require("../../engine");
var environment_1 = require("../../environment");
var globals_1 = require("../../globals");
var log_1 = require("../../log");
var array_ops_1 = require("../../ops/array_ops");
var array_ops_util = require("../../ops/array_ops_util");
var axis_util = require("../../ops/axis_util");
var complex_ops_1 = require("../../ops/complex_ops");
var concat_util_1 = require("../../ops/concat_util");
var gather_nd_util = require("../../ops/gather_nd_util");
var reduce_util = require("../../ops/reduce_util");
var scatter_nd_util = require("../../ops/scatter_nd_util");
var segment_util = require("../../ops/segment_util");
var slice_util = require("../../ops/slice_util");
var softmax_1 = require("../../ops/softmax");
var tensor_ops_1 = require("../../ops/tensor_ops");
var types_1 = require("../../types");
var util = require("../../util");
var util_1 = require("../../util");
var backend_1 = require("../backend");
var backend_util = require("../backend_util");
var complex_util_1 = require("../complex_util");
var non_max_suppression_impl_1 = require("../non_max_suppression_impl");
var split_shared_1 = require("../split_shared");
var tile_impl_1 = require("../tile_impl");
var topk_impl_1 = require("../topk_impl");
var where_impl_1 = require("../where_impl");
var addn_gpu_1 = require("./addn_gpu");
var addn_packed_gpu_1 = require("./addn_packed_gpu");
var argminmax_gpu_1 = require("./argminmax_gpu");
var argminmax_packed_gpu_1 = require("./argminmax_packed_gpu");
var avg_pool_backprop_gpu_1 = require("./avg_pool_backprop_gpu");
var batchnorm_gpu_1 = require("./batchnorm_gpu");
var batchnorm_packed_gpu_1 = require("./batchnorm_packed_gpu");
var binaryop_complex_gpu = require("./binaryop_complex_gpu");
var binaryop_complex_gpu_1 = require("./binaryop_complex_gpu");
var binaryop_gpu = require("./binaryop_gpu");
var binaryop_gpu_1 = require("./binaryop_gpu");
var binaryop_packed_gpu = require("./binaryop_packed_gpu");
var binaryop_packed_gpu_1 = require("./binaryop_packed_gpu");
var canvas_util_1 = require("./canvas_util");
var clip_gpu_1 = require("./clip_gpu");
var clip_packed_gpu_1 = require("./clip_packed_gpu");
var complex_abs_gpu_1 = require("./complex_abs_gpu");
var concat_gpu_1 = require("./concat_gpu");
var concat_packed_gpu_1 = require("./concat_packed_gpu");
var conv_backprop_gpu_1 = require("./conv_backprop_gpu");
var conv_backprop_gpu_depthwise_1 = require("./conv_backprop_gpu_depthwise");
var conv_gpu_1 = require("./conv_gpu");
var conv_gpu_depthwise_1 = require("./conv_gpu_depthwise");
var conv_packed_gpu_depthwise_1 = require("./conv_packed_gpu_depthwise");
var crop_and_resize_gpu_1 = require("./crop_and_resize_gpu");
var cumsum_gpu_1 = require("./cumsum_gpu");
var decode_matrix_gpu_1 = require("./decode_matrix_gpu");
var decode_matrix_packed_gpu_1 = require("./decode_matrix_packed_gpu");
var depth_to_space_gpu_1 = require("./depth_to_space_gpu");
var diag_gpu_1 = require("./diag_gpu");
var encode_float_gpu_1 = require("./encode_float_gpu");
var encode_float_packed_gpu_1 = require("./encode_float_packed_gpu");
var encode_matrix_gpu_1 = require("./encode_matrix_gpu");
var encode_matrix_packed_gpu_1 = require("./encode_matrix_packed_gpu");
var fft_gpu = require("./fft_gpu");
var fft_gpu_1 = require("./fft_gpu");
var fill_gpu_1 = require("./fill_gpu");
var gather_gpu_1 = require("./gather_gpu");
var gather_nd_gpu_1 = require("./gather_nd_gpu");
var gpgpu_context_1 = require("./gpgpu_context");
var gpgpu_math = require("./gpgpu_math");
var im2col_packed_gpu_1 = require("./im2col_packed_gpu");
var lrn_gpu_1 = require("./lrn_gpu");
var lrn_grad_gpu_1 = require("./lrn_grad_gpu");
var lrn_packed_gpu_1 = require("./lrn_packed_gpu");
var max_pool_backprop_gpu_1 = require("./max_pool_backprop_gpu");
var mulmat_packed_gpu_1 = require("./mulmat_packed_gpu");
var multinomial_gpu_1 = require("./multinomial_gpu");
var onehot_gpu_1 = require("./onehot_gpu");
var pack_gpu_1 = require("./pack_gpu");
var pad_gpu_1 = require("./pad_gpu");
var pad_packed_gpu_1 = require("./pad_packed_gpu");
var pool_gpu_1 = require("./pool_gpu");
var reduce_gpu_1 = require("./reduce_gpu");
var reshape_packed_gpu_1 = require("./reshape_packed_gpu");
var resize_bilinear_backprop_gpu_1 = require("./resize_bilinear_backprop_gpu");
var resize_bilinear_gpu_1 = require("./resize_bilinear_gpu");
var resize_bilinear_packed_gpu_1 = require("./resize_bilinear_packed_gpu");
var resize_nearest_neighbor_backprop_gpu_1 = require("./resize_nearest_neighbor_backprop_gpu");
var resize_nearest_neighbor_gpu_1 = require("./resize_nearest_neighbor_gpu");
var reverse_gpu_1 = require("./reverse_gpu");
var reverse_packed_gpu_1 = require("./reverse_packed_gpu");
var scatter_gpu_1 = require("./scatter_gpu");
var segment_gpu_1 = require("./segment_gpu");
var select_gpu_1 = require("./select_gpu");
var slice_gpu_1 = require("./slice_gpu");
var slice_packed_gpu_1 = require("./slice_packed_gpu");
var strided_slice_gpu_1 = require("./strided_slice_gpu");
var tex_util = require("./tex_util");
var tex_util_1 = require("./tex_util");
var texture_manager_1 = require("./texture_manager");
var tile_gpu_1 = require("./tile_gpu");
var transpose_gpu_1 = require("./transpose_gpu");
var transpose_packed_gpu_1 = require("./transpose_packed_gpu");
var unary_op = require("./unaryop_gpu");
var unaryop_gpu_1 = require("./unaryop_gpu");
var unary_packed_op = require("./unaryop_packed_gpu");
var unaryop_packed_gpu_1 = require("./unaryop_packed_gpu");
var unpack_gpu_1 = require("./unpack_gpu");
var webgl_util = require("./webgl_util");
// One cache object per WebGL version. It is shared by every backend instance
// that creates its own WebGL context; presumably it holds compiled programs
// (its consumers are not visible here).
var binaryCaches = {};
/**
 * Returns the cache object for `webGLVersion`, creating an empty one the first
 * time that version is requested.
 *
 * @param webGLVersion WebGL version number used as the cache key.
 * @returns The (possibly newly created) cache object for that version.
 */
function getBinaryCache(webGLVersion) {
    if (!(webGLVersion in binaryCaches)) {
        binaryCaches[webGLVersion] = {};
    }
    return binaryCaches[webGLVersion];
}
exports.getBinaryCache = getBinaryCache;
/**
 * Maps a fused-activation name to the shader snippet implementing it.
 *
 * @param activation One of 'linear', 'relu', 'elu', 'relu6' or 'prelu'.
 * @param packed When truthy, return the packed-texture variant of the snippet.
 * @returns The matching unary (or, for 'prelu', binary) op snippet.
 * @throws Error for any other activation name.
 */
function mapActivationToShaderProgram(activation, packed) {
    if (packed === void 0) { packed = false; }
    switch (activation) {
        case 'linear':
            return packed ? unary_packed_op.LINEAR : unary_op.LINEAR;
        case 'relu':
            return packed ? unary_packed_op.RELU : unary_op.RELU;
        case 'elu':
            return packed ? unary_packed_op.ELU : unary_op.ELU;
        case 'relu6':
            return packed ? unary_packed_op.RELU6 : unary_op.RELU6;
        case 'prelu':
            // PReLU needs the weights as a second operand, so it is a binary op.
            return packed ? binaryop_packed_gpu.PRELU : binaryop_gpu.PRELU;
        default:
            throw new Error("Activation " + activation + " has not been implemented for the WebGL backend.");
    }
}
// Empirically determined constant used to determine size threshold for handing
// off execution to the CPU.
var CPU_HANDOFF_SIZE_THRESHOLD = 128;
// Empirically determined constant used to decide the number of MB on GPU
// before we warn about high memory use. The MB are this constant * screen area
// * devicePixelRatio / 1024 / 1024.
var BEFORE_PAGING_CONSTANT = 600;
/**
 * Returns the GPU memory budget, in MB, above which the backend should warn
 * about high memory use.
 *
 * Returns 1024 (1 GB) when the environment's global object has no `screen`.
 * Otherwise it returns screenHeight * screenWidth * devicePixelRatio *
 * BEFORE_PAGING_CONSTANT / 1024 / 1024.
 *
 * The pixel ratio is read from the same global object as `screen`, not from a
 * bare `window` binding. This avoids a ReferenceError where `screen` exists but
 * `window` is not defined. A missing ratio defaults to 1; previously it
 * produced NaN, which silently disabled the warning.
 */
function numMBBeforeWarning() {
    var globalObj = environment_1.env().global;
    if (globalObj.screen == null) {
        return 1024; // 1 GB.
    }
    var pixelRatio = globalObj.devicePixelRatio != null ? globalObj.devicePixelRatio : 1;
    return (globalObj.screen.height * globalObj.screen.width * pixelRatio) *
        BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
// Empirically determined minimal shared dimension in matmul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
// batchMatMul takes the mul().sum() path only when one operand's outer
// dimension is 1 AND the shared dimension is strictly greater than this value.
// It is exported, so it is also visible outside this module.
exports.MATMUL_SHARED_DIM_THRESHOLD = 1000;
var MathBackendWebGL = /** @class */ (function (_super) {
__extends(MathBackendWebGL, _super);
/**
 * WebGL implementation of the kernel backend.
 *
 * When `gpgpu` is omitted, a WebGL context is created here and the shared
 * per-version binary cache is used. Otherwise the supplied GPGPU context is
 * adopted and gets its own empty binary cache.
 *
 * @param gpgpu Optional existing GPGPU context to run on.
 * @throws Error if the HAS_WEBGL flag is false.
 */
function MathBackendWebGL(gpgpu) {
    var self = _super.call(this) || this;
    // Data ids with an in-flight async read, mapped to the resolve callbacks of
    // later readers waiting on that same read.
    self.pendingRead = new WeakMap();
    // Data ids whose disposal was requested while a read was still pending.
    self.pendingDisposal = new WeakSet();
    // Number of 'shallow' sliced tensors pointing at the same data id.
    self.dataRefCount = new WeakMap();
    self.numBytesInGPU = 0;
    // Accumulated time (including blocking) spent uploading data to WebGL.
    self.uploadWaitMs = 0;
    // Accumulated time (including blocking) spent downloading data from WebGL.
    self.downloadWaitMs = 0;
    self.warnedAboutMemory = false;
    self.pendingDeletes = 0;
    self.disposed = false;
    if (!environment_1.env().getBool('HAS_WEBGL')) {
        throw new Error('WebGL is not supported on this device');
    }
    var ownsContext = gpgpu == null;
    if (ownsContext) {
        var gl = canvas_util_1.getWebGLContext(environment_1.env().getNumber('WEBGL_VERSION'));
        self.binaryCache = getBinaryCache(environment_1.env().getNumber('WEBGL_VERSION'));
        self.gpgpu = new gpgpu_context_1.GPGPUContext(gl);
        self.canvas = gl.canvas;
    }
    else {
        self.gpgpu = gpgpu;
        self.binaryCache = {};
        self.canvas = gpgpu.gl.canvas;
    }
    self.gpgpuCreatedLocally = ownsContext;
    self.textureManager = new texture_manager_1.TextureManager(self.gpgpu);
    self.numMBBeforeWarning = numMBBeforeWarning();
    self.texData = new backend_1.DataStorage(self, engine_1.ENGINE);
    return self;
}
/**
 * Number of live data ids: those held here, plus those held by the CPU
 * backend (if one has been looked up), minus ids whose disposal is deferred
 * behind a pending read.
 */
MathBackendWebGL.prototype.numDataIds = function () {
    var ownCount = this.texData.numDataIds();
    var cpuCount = this.cpuBackend ? this.cpuBackend.numDataIds() : 0;
    return ownCount + cpuCount - this.pendingDeletes;
};
/**
 * Registers new CPU-side data and returns a fresh data id for it.
 * The record is marked for upload (TextureUsage.UPLOAD).
 *
 * @throws Error when complex64 values are supplied; complex data must be built
 *     with tf.complex(real, imag). A null-valued complex64 slot is allowed.
 */
MathBackendWebGL.prototype.write = function (values, shape, dtype) {
    if (environment_1.env().getBool('DEBUG')) {
        this.checkNumericalProblems(values);
    }
    var writesComplexValues = dtype === 'complex64' && values != null;
    if (writesComplexValues) {
        throw new Error("Cannot write to a complex64 dtype. " +
            "Please use tf.complex(real, imag).");
    }
    var newId = {};
    var record = { shape: shape, dtype: dtype, values: values, usage: tex_util_1.TextureUsage.UPLOAD };
    this.texData.set(newId, record);
    return newId;
};
/**
 * Stores CPU-side `values` under an existing `dataId`, replacing any previous
 * record, and marks them for upload.
 *
 * @throws Error for complex64 data, regardless of `values`.
 */
MathBackendWebGL.prototype.move = function (dataId, values, shape, dtype) {
    if (environment_1.env().getBool('DEBUG')) {
        this.checkNumericalProblems(values);
    }
    if (dtype !== 'complex64') {
        var record = { shape: shape, dtype: dtype, values: values, usage: tex_util_1.TextureUsage.UPLOAD };
        this.texData.set(dataId, record);
        return;
    }
    throw new Error("Cannot write to a complex64 dtype. " +
        "Please use tf.complex(real, imag).");
};
/**
 * Synchronously returns the values of `dataId`.
 *
 * - A shallow slice is first materialised with a clone program, and the clone
 *   is read instead.
 * - Data that already has CPU values goes through convertAndCacheOnCPU.
 * - Otherwise the data is downloaded: complex64 merges the real/imag component
 *   tensors, everything else reads the texture.
 *
 * Download time is added to downloadWaitMs while a time() call is active.
 */
MathBackendWebGL.prototype.readSync = function (dataId) {
    var info = this.texData.get(dataId);
    var cpuValues = info.values;
    var dtype = info.dtype;
    var shape = info.shape;
    var parts = info.complexTensors;
    if (info.slice != null) {
        var cloneProgram = info.isPacked ?
            new unaryop_packed_gpu_1.UnaryOpPackedProgram(shape, unary_op.CLONE) :
            new unaryop_gpu_1.UnaryOpProgram(shape, unary_op.CLONE);
        var cloned = this.runWebGLProgram(cloneProgram, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype);
        var clonedValues = this.readSync(cloned.dataId);
        this.disposeData(cloned.dataId);
        return clonedValues;
    }
    if (cpuValues != null) {
        return this.convertAndCacheOnCPU(dataId);
    }
    if (dtype === 'string') {
        return cpuValues;
    }
    var isTiming = this.activeTimers != null;
    var startMs = isTiming ? util.now() : undefined;
    var downloaded;
    if (dtype !== 'complex64') {
        downloaded = this.getValuesFromTexture(dataId);
    }
    else {
        downloaded = complex_util_1.mergeRealAndImagArrays(parts.real.dataSync(), parts.imag.dataSync());
    }
    if (isTiming) {
        this.downloadWaitMs += util.now() - startMs;
    }
    return this.convertAndCacheOnCPU(dataId, downloaded);
};
/**
 * Asynchronously downloads the values of `dataId` and returns a promise of
 * them (compiled `async` method).
 *
 * - Concurrent reads of the same id share one download. Later callers are
 *   queued in `pendingRead` and resolved with the same values.
 * - Shallow slices are materialised with a clone program, and the clone is
 *   read instead.
 * - When WEBGL_BUFFER_SUPPORTED, non-complex data is first copied into a buffer.
 * - A fence is awaited before downloading non-complex data.
 * - A disposeData() requested while the read was pending is performed at the
 *   end.
 *
 * NOTE(review): if anything after `pendingRead.set(dataId, [])` throws or
 * rejects (the fence, or a complex component's data()), the entry is never
 * deleted. Later read() calls for this id would then wait forever, and a
 * deferred disposal would never run.
 */
MathBackendWebGL.prototype.read = function (dataId) {
return __awaiter(this, void 0, void 0, function () {
var subscribers_1, texData, values, shape, slice, dtype, complexTensors, isPacked, program, res, data, buffer, tmpDownloadTarget, tmpData, vals, ps, realValues, imagValues, size, dTypeVals, subscribers;
var _a;
return __generator(this, function (_b) {
switch (_b.label) {
case 0:
// A read of this id is already in flight: wait for its result.
if (this.pendingRead.has(dataId)) {
subscribers_1 = this.pendingRead.get(dataId);
return [2 /*return*/, new Promise(function (resolve) { return subscribers_1.push(resolve); })];
}
texData = this.texData.get(dataId);
values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, complexTensors = texData.complexTensors, isPacked = texData.isPacked;
if (slice != null) {
program = void 0;
if (isPacked) {
program = new unaryop_packed_gpu_1.UnaryOpPackedProgram(shape, unary_op.CLONE);
}
else {
program = new unaryop_gpu_1.UnaryOpProgram(shape, unary_op.CLONE);
}
res = this.runWebGLProgram(program, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype);
// `data` is a promise. Presumably the clone's read() has already registered
// in pendingRead synchronously, so the disposeData below is deferred until
// that read finishes; confirm the clone has no CPU values.
data = this.read(res.dataId);
this.disposeData(res.dataId);
return [2 /*return*/, data];
}
if (values != null) {
return [2 /*return*/, this.convertAndCacheOnCPU(dataId)];
}
if (!environment_1.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
environment_1.env().getNumber('WEBGL_VERSION') === 2) {
throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " +
"WEBGL_VERSION=2 not yet supported.");
}
buffer = null;
if (dtype !== 'complex64' && environment_1.env().get('WEBGL_BUFFER_SUPPORTED')) {
// Possibly copy the texture into a buffer before inserting a fence.
tmpDownloadTarget = this.decode(dataId);
tmpData = this.texData.get(tmpDownloadTarget.dataId);
buffer = (_a = this.gpgpu).createBufferFromTexture.apply(_a, [tmpData.texture].concat(tex_util.getDenseTexShape(shape)));
}
// From here on, other readers of this id queue up instead of re-downloading.
this.pendingRead.set(dataId, []);
if (!(dtype !== 'complex64')) return [3 /*break*/, 2];
// Create a fence and wait for it to resolve.
return [4 /*yield*/, this.gpgpu.createAndWaitForFence()];
case 1:
// Create a fence and wait for it to resolve.
_b.sent();
_b.label = 2;
case 2:
// complex64: read both component tensors in parallel, then interleave them.
if (!(dtype === 'complex64')) return [3 /*break*/, 4];
return [4 /*yield*/, Promise.all([complexTensors.real.data(), complexTensors.imag.data()])];
case 3:
ps = _b.sent();
realValues = ps[0];
imagValues = ps[1];
vals = complex_util_1.mergeRealAndImagArrays(realValues, imagValues);
return [3 /*break*/, 5];
case 4:
// Non-complex: download from the buffer if one was made, else from the texture.
if (buffer == null) {
vals = this.getValuesFromTexture(dataId);
}
else {
size = util.sizeFromShape(shape);
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
}
_b.label = 5;
case 5:
if (tmpDownloadTarget != null) {
this.disposeData(tmpDownloadTarget.dataId);
}
dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
subscribers = this.pendingRead.get(dataId);
this.pendingRead.delete(dataId);
// Notify all pending reads.
subscribers.forEach(function (resolve) { return resolve(dTypeVals); });
// Perform a disposal that was requested while this read was pending.
if (this.pendingDisposal.has(dataId)) {
this.pendingDisposal.delete(dataId);
this.disposeData(dataId);
this.pendingDeletes--;
}
return [2 /*return*/, dTypeVals];
}
});
});
};
/**
 * Throws on the first value that webgl_util.canBeRepresented rejects.
 *
 * If the device is capable of float32 rendering, the message suggests enabling
 * it. A null/undefined `values` is accepted silently.
 */
MathBackendWebGL.prototype.checkNumericalProblems = function (values) {
    if (values == null) {
        return;
    }
    for (var idx = 0; idx < values.length; ++idx) {
        var value = values[idx];
        if (webgl_util.canBeRepresented(value)) {
            continue;
        }
        var message = environment_1.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE') ?
            "The value " + value + " cannot be represented with your " +
                "current settings. Consider enabling float32 rendering: " +
                "'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'" :
            "The value " + value + " cannot be represented on this device.";
        throw Error(message);
    }
};
/**
 * Downloads the GPU contents of `dataId` and returns the first
 * sizeFromShape(shape) elements (a subarray of the downloaded typed array).
 *
 * - With WEBGL_DOWNLOAD_FLOAT_ENABLED, the data is decoded and read as floats.
 * - Otherwise an encode-float program (packed or unpacked) is run first, and
 *   its byte-encoded output is downloaded.
 *
 * Temporary outputs are disposed before returning.
 */
MathBackendWebGL.prototype.getValuesFromTexture = function (dataId) {
    var info = this.texData.get(dataId);
    var shape = info.shape;
    var dtype = info.dtype;
    var isPacked = info.isPacked;
    var size = util.sizeFromShape(shape);
    if (environment_1.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
        var decoded = this.decode(dataId);
        var decodedInfo = this.texData.get(decoded.dataId);
        var gpgpu = this.gpgpu;
        var downloadArgs = [decodedInfo.texture].concat(tex_util.getDenseTexShape(shape));
        var floats = gpgpu.downloadMatrixFromPackedTexture.apply(gpgpu, downloadArgs).subarray(0, size);
        this.disposeData(decoded.dataId);
        return floats;
    }
    var usePacked = environment_1.env().getBool('WEBGL_PACK') && isPacked === true;
    var encodeShape = usePacked ? webgl_util.getShapeAs3D(shape) : shape;
    var encodeProgram;
    if (usePacked) {
        encodeProgram = new encode_float_packed_gpu_1.EncodeFloatPackedProgram(encodeShape);
    }
    else {
        encodeProgram = new encode_float_gpu_1.EncodeFloatProgram(encodeShape);
    }
    var encoded = this.runWebGLProgram(encodeProgram, [{ shape: encodeShape, dtype: dtype, dataId: dataId }], 'float32');
    var encodedInfo = this.texData.get(encoded.dataId);
    var rows = encodedInfo.texShape[0];
    var cols = encodedInfo.texShape[1];
    var result = this.gpgpu
        .downloadByteEncodedFloatMatrixFromOutputTexture(encodedInfo.texture, rows, cols)
        .subarray(0, size);
    this.disposeData(encoded.dataId);
    return result;
};
/**
 * Runs `f` once, synchronously, while collecting the timers started by the
 * programs it launches. Resolves with a timing summary once all timer queries
 * settle.
 *
 * A nested call pushes its own timer list into the enclosing call's list. Only
 * the outermost call owns `programTimersStack`.
 *
 * The previous `activeTimers` and `programTimersStack` are now restored in a
 * `finally`. Before, an exception from `f` left the backend in timing mode
 * permanently: readSync kept timing downloads, and later time() calls were
 * treated as nested in a dead parent.
 *
 * @param f Work to profile.
 * @returns Promise of {uploadWaitMs, downloadWaitMs, kernelMs,
 *     getExtraProfileInfo, wallMs}. wallMs is null and filled in by the engine.
 *     Rejects with whatever `f` throws.
 */
MathBackendWebGL.prototype.time = function (f) {
    return __awaiter(this, void 0, void 0, function () {
        var oldActiveTimers, newActiveTimers, outerMostTime, flattenedActiveTimerQueries, flattenedActiveTimerNames, kernelMs, res;
        return __generator(this, function (_a) {
            switch (_a.label) {
                case 0:
                    oldActiveTimers = this.activeTimers;
                    newActiveTimers = [];
                    outerMostTime = false;
                    if (this.programTimersStack == null) {
                        this.programTimersStack = newActiveTimers;
                        outerMostTime = true;
                    }
                    else {
                        this.activeTimers.push(newActiveTimers);
                    }
                    this.activeTimers = newActiveTimers;
                    // No yield occurs inside this try, so a plain try/finally is safe
                    // within this generator state.
                    try {
                        f();
                        flattenedActiveTimerQueries = util.flatten(this.activeTimers.map(function (d) { return d.query; }))
                            .filter(function (d) { return d != null; });
                        flattenedActiveTimerNames = util.flatten(this.activeTimers.map(function (d) { return d.name; }))
                            .filter(function (d) { return d != null; });
                    }
                    finally {
                        // Always restore the enclosing timer state, even if f() threw.
                        this.activeTimers = oldActiveTimers;
                        if (outerMostTime) {
                            this.programTimersStack = null;
                        }
                    }
                    return [4 /*yield*/, Promise.all(flattenedActiveTimerQueries)];
                case 1:
                    kernelMs = _a.sent();
                    res = {
                        uploadWaitMs: this.uploadWaitMs,
                        downloadWaitMs: this.downloadWaitMs,
                        kernelMs: util.sum(kernelMs),
                        getExtraProfileInfo: function () {
                            return kernelMs.map(function (d, i) { return ({ name: flattenedActiveTimerNames[i], ms: d }); })
                                .map(function (d) { return d.name + ": " + d.ms; })
                                .join(', ');
                        },
                        wallMs: null // will be filled by the engine
                    };
                    this.uploadWaitMs = 0;
                    this.downloadWaitMs = 0;
                    return [2 /*return*/, res];
            }
        });
    });
};
/** Reports GPU memory usage as tracked by this backend (numBytesInGPU). */
MathBackendWebGL.prototype.memory = function () {
    var info = { unreliable: false, numBytesInGPU: this.numBytesInGPU };
    return info;
};
/**
 * Starts a timer. Uses a GPU query when the disjoint-timer-query extension is
 * available; otherwise returns a CPU-side {startMs, endMs} record.
 */
MathBackendWebGL.prototype.startTimer = function () {
    var hasQueryTimer = environment_1.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0;
    return hasQueryTimer ? this.gpgpu.beginQuery() : { startMs: util.now(), endMs: null };
};
/**
 * Stops a timer started by startTimer and returns the same `query` object.
 * For CPU-side timers, `endMs` is filled in.
 */
MathBackendWebGL.prototype.endTimer = function (query) {
    if (environment_1.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
        this.gpgpu.endQuery();
    }
    else {
        query.endMs = util.now();
    }
    return query;
};
/**
 * Resolves with the elapsed milliseconds of a finished timer.
 *
 * GPU queries are delegated to gpgpu.waitForQueryAndGetTime. CPU records
 * resolve with endMs - startMs.
 */
MathBackendWebGL.prototype.getQueryTime = function (query) {
    return __awaiter(this, void 0, void 0, function () {
        return __generator(this, function (_a) {
            var usesGpuQuery = environment_1.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0;
            return [2 /*return*/, usesGpuQuery ?
                    this.gpgpu.waitForQueryAndGetTime(query) :
                    query.endMs - query.startMs];
        });
    });
};
/**
 * Frees the data behind `dataId`.
 *
 * - If a read of it is in flight, disposal is deferred until read() finishes,
 *   and pendingDeletes is bumped meanwhile.
 * - Unknown or already-disposed ids are ignored.
 * - Otherwise the GPU data is released, any complex component tensors are
 *   disposed, and the record is removed.
 */
MathBackendWebGL.prototype.disposeData = function (dataId) {
    var alreadyDeferred = this.pendingDisposal.has(dataId);
    if (alreadyDeferred) {
        return;
    }
    var readInFlight = this.pendingRead.has(dataId);
    if (readInFlight) {
        this.pendingDisposal.add(dataId);
        this.pendingDeletes++;
        return;
    }
    var known = this.texData.has(dataId);
    if (!known) {
        return;
    }
    this.releaseGPUData(dataId);
    var parts = this.texData.get(dataId).complexTensors;
    if (parts != null) {
        parts.real.dispose();
        parts.imag.dispose();
    }
    this.texData.delete(dataId);
};
/**
 * Drops this record's claim on its GPU texture.
 *
 * Shallow slices share their root's texture, so reference counts are keyed by
 * the root data id. The texture is returned to the texture manager, and
 * numBytesInGPU reduced, only when the last reference goes away.
 *
 * The record's texture, texShape, isPacked and slice fields are always reset.
 */
MathBackendWebGL.prototype.releaseGPUData = function (dataId) {
    var entry = this.texData.get(dataId);
    var texture = entry.texture;
    var dtype = entry.dtype;
    var texShape = entry.texShape;
    var usage = entry.usage;
    var isPacked = entry.isPacked;
    var refKey = (entry.slice && entry.slice.origDataId) || dataId;
    var refs = this.dataRefCount.get(refKey);
    // A missing count means this is the only reference.
    var isLastReference = !(refs > 1);
    if (isLastReference) {
        this.dataRefCount.delete(refKey);
        if (texture != null) {
            this.numBytesInGPU -= this.computeBytes(texShape, dtype);
            this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
        }
    }
    else {
        this.dataRefCount.set(refKey, refs - 1);
    }
    entry.texture = null;
    entry.texShape = null;
    entry.isPacked = false;
    entry.slice = null;
};
/** Ensures `dataId` is uploaded to the GPU, then returns its texture. */
MathBackendWebGL.prototype.getTexture = function (dataId) {
    this.uploadToGPU(dataId);
    var entry = this.texData.get(dataId);
    return entry.texture;
};
/**
 * Returns the raw texture-data record stored for `dataId`. Intended for unit
 * tests.
 */
MathBackendWebGL.prototype.getDataInfo = function (dataId) {
    var record = this.texData.get(dataId);
    return record;
};
/**
 * Returns the CPU backend used for small-op forwarding. It is looked up lazily
 * and memoised.
 *
 * @returns The CPU backend, or null when WEBGL_CPU_FORWARD is off.
 */
MathBackendWebGL.prototype.getCPUBackend = function () {
    if (environment_1.env().getBool('WEBGL_CPU_FORWARD')) {
        if (this.cpuBackend == null) {
            this.cpuBackend = engine_1.ENGINE.findBackend('cpu');
        }
        return this.cpuBackend;
    }
    return null;
};
/*
 Decides whether an op should be forwarded to the CPU backend. That happens
 when CPU forwarding is available and every input is both CPU-resident (no
 texture yet) and smaller than `sizeThreshold` elements. WebGL kernels opt into
 running this check and forwarding when appropriate.
 TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
 sustainable strategy for optimizing backend execution of ops.
*/
MathBackendWebGL.prototype.shouldExecuteOnCPU = function (inputs, sizeThreshold) {
    if (sizeThreshold === void 0) { sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD; }
    if (this.getCPUBackend() == null) {
        return false;
    }
    var texData = this.texData;
    return inputs.every(function (input) {
        var isOnCPU = texData.get(input.dataId).texture == null;
        return isOnCPU && input.size < sizeThreshold;
    });
};
/** Returns the GPGPU context this backend runs on. */
MathBackendWebGL.prototype.getGPGPUContext = function () {
    var context = this.gpgpu;
    return context;
};
/**
 * Builds a complex64 output from `real` and `imag`.
 *
 * The backend owns kept clones of both parts. disposeData() disposes them
 * together with the complex tensor.
 */
MathBackendWebGL.prototype.complex = function (real, imag) {
    var output = this.makeOutput(real.shape, 'complex64');
    var outputInfo = this.texData.get(output.dataId);
    var realCopy = engine_1.ENGINE.keep(real.clone());
    var imagCopy = engine_1.ENGINE.keep(imag.clone());
    outputInfo.complexTensors = { real: realCopy, imag: imagCopy };
    return output;
};
/** Returns a clone of the real component of complex tensor `input`. */
MathBackendWebGL.prototype.real = function (input) {
    var parts = this.texData.get(input.dataId).complexTensors;
    return parts.real.clone();
};
/** Returns a clone of the imaginary component of complex tensor `input`. */
MathBackendWebGL.prototype.imag = function (input) {
    var parts = this.texData.get(input.dataId).complexTensors;
    return parts.imag.clone();
};
/**
 * Slices `x` starting at `begin` with extent `size`.
 *
 * - Small CPU-resident inputs go to the CPU backend.
 * - Empty slices return an empty tensor of x's dtype.
 * - Contiguous slices of unpacked data share x's texture (shallowSlice).
 * - Anything else runs a (packed or unpacked) slice program.
 */
MathBackendWebGL.prototype.slice = function (x, begin, size) {
    if (this.shouldExecuteOnCPU([x])) {
        return this.cpuBackend.slice(x, begin, size);
    }
    if (util.sizeFromShape(size) === 0) {
        return tensor_ops_1.tensor([], size, x.dtype);
    }
    var xIsPacked = this.texData.get(x.dataId).isPacked;
    var contiguous = slice_util.isSliceContinous(x.shape, begin, size);
    if (!xIsPacked && contiguous) {
        this.uploadToGPU(x.dataId);
        return this.shallowSlice(x, begin, size);
    }
    var sliceProgram = environment_1.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        new slice_packed_gpu_1.SlicePackedProgram(size) :
        new slice_gpu_1.SliceProgram(size);
    return this.compileAndRun(sliceProgram, [x], null, sliceProgram.getCustomSetupFunc(begin));
};
/**
 * Creates a view onto x's texture without copying.
 *
 * The new record copies every field of x's record, then overrides shape and
 * dtype. It stores the accumulated flat offset and the root data id, and bumps
 * the root's reference count. An untracked root counts as 1.
 */
MathBackendWebGL.prototype.shallowSlice = function (x, begin, size) {
    var source = this.texData.get(x.dataId);
    var view = this.makeOutput(size, x.dtype);
    var viewInfo = this.texData.get(view.dataId);
    Object.assign(viewInfo, source);
    viewInfo.shape = size;
    viewInfo.dtype = x.dtype;
    var parentSlice = source.slice;
    // Slicing a slice: offsets accumulate and the ref count stays on the root.
    var offset = slice_util.computeFlatOffset(begin, x.strides);
    if (parentSlice) {
        offset += parentSlice.flatOffset;
    }
    var rootId = (parentSlice && parentSlice.origDataId) || x.dataId;
    viewInfo.slice = { flatOffset: offset, origDataId: rootId };
    var currentRefs = this.dataRefCount.get(rootId) || 1;
    this.dataRefCount.set(rootId, currentRefs + 1);
    return view;
};
/**
 * Strided slice of `x` from `begin` to `end` with step `strides`.
 *
 * Small CPU-resident inputs are forwarded to the CPU backend. If any output
 * axis has length 0, an empty tensor is returned without running a program.
 * That empty result now keeps `x.dtype`, matching `slice`. Previously no dtype
 * was passed, so it was inferred from the empty values array instead of taken
 * from `x`.
 */
MathBackendWebGL.prototype.stridedSlice = function (x, begin, end, strides) {
    if (this.shouldExecuteOnCPU([x])) {
        return this.cpuBackend.stridedSlice(x, begin, end, strides);
    }
    var outShape = slice_util.computeOutShape(begin, end, strides);
    if (outShape.some(function (axis) { return axis === 0; })) {
        return tensor_ops_1.tensor([], outShape, x.dtype);
    }
    var program = new strided_slice_gpu_1.StridedSliceProgram(begin, strides, outShape);
    return this.compileAndRun(program, [x]);
};
/**
 * Reverses `x` along `axis`. Uses the packed program when
 * WEBGL_PACK_ARRAY_OPERATIONS is set.
 */
MathBackendWebGL.prototype.reverse = function (x, axis) {
    var reverseProgram;
    if (environment_1.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
        reverseProgram = new reverse_packed_gpu_1.ReversePackedProgram(x.shape, axis);
    }
    else {
        reverseProgram = new reverse_gpu_1.ReverseProgram(x.shape, axis);
    }
    return this.compileAndRun(reverseProgram, [x]);
};
/**
 * Concatenates `tensors` along `axis`.
 *
 * - Complex inputs concatenate their real and imaginary parts separately.
 * - Small CPU-resident inputs go to the CPU backend.
 * - A single input is returned as-is.
 * - More inputs than WEBGL_MAX_TEXTURES_IN_SHADER are split in half
 *   recursively.
 * - With WEBGL_PACK_ARRAY_OPERATIONS and rank > 1, the packed program is used.
 */
MathBackendWebGL.prototype.concat = function (tensors, axis) {
    if (tensors[0].dtype === 'complex64') {
        var realParts = tensors.map(function (t) { return complex_ops_1.real(t); });
        var imagParts = tensors.map(function (t) { return complex_ops_1.imag(t); });
        return complex_ops_1.complex(this.concat(realParts, axis), this.concat(imagParts, axis));
    }
    if (this.shouldExecuteOnCPU(tensors)) {
        return this.cpuBackend.concat(tensors, axis);
    }
    if (tensors.length === 1) {
        return tensors[0];
    }
    var maxTextures = environment_1.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER');
    if (tensors.length > maxTextures) {
        var half = Math.floor(tensors.length / 2);
        var left = this.concat(tensors.slice(0, half), axis);
        var right = this.concat(tensors.slice(half), axis);
        return this.concat([left, right], axis);
    }
    var shapes = tensors.map(function (t) { return t.shape; });
    if (environment_1.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && tensors[0].rank > 1) {
        return this.compileAndRun(new concat_packed_gpu_1.ConcatPackedProgram(shapes, axis), tensors);
    }
    // Any n-D concat can be expressed as a 2-D concat along axis 1. Collapse the
    // axes before `axis` into rows and the remaining axes into columns.
    // Concatenate those matrices, then reshape back to the n-D output shape.
    var outShape = concat_util_1.computeOutShape(shapes, axis);
    var matrices = tensors.map(function (t) { return t.as2D(-1, util_1.sizeFromShape(t.shape.slice(axis))); });
    var concatProgram = new concat_gpu_1.ConcatProgram(matrices.map(function (t) { return t.shape; }));
    var joined = this.compileAndRun(concatProgram, matrices);
    return joined.reshape(outShape);
};
/**
 * Element-wise negation. Small CPU-resident inputs go to the CPU backend. Uses
 * the packed path when WEBGL_PACK_UNARY_OPERATIONS is set.
 */
MathBackendWebGL.prototype.neg = function (x) {
    if (this.shouldExecuteOnCPU([x])) {
        return this.cpuBackend.neg(x);
    }
    return environment_1.env().getBool('WEBGL_PACK_UNARY_OPERATIONS') ?
        this.packedUnaryOp(x, unary_op.NEG, x.dtype) :
        this.compileAndRun(new unaryop_gpu_1.UnaryOpProgram(x.shape, unary_op.NEG), [x]);
};
/**
 * Batched matrix product of rank-3 `a` and `b`, optionally transposing the
 * inner two axes of either.
 *
 * When one side is a vector (outer dimension 1) and the shared dimension
 * exceeds MATMUL_SHARED_DIM_THRESHOLD, it computes multiply().sum() instead.
 * The original authors note sum() parallelises better for that case.
 * Otherwise it runs the packed matmul program with the upcast dtype.
 */
MathBackendWebGL.prototype.batchMatMul = function (a, b, transposeA, transposeB) {
    var batch = a.shape[0];
    var rowsA = transposeA ? a.shape[2] : a.shape[1];
    var colsB = transposeB ? b.shape[1] : b.shape[2];
    var sharedDim = transposeA ? a.shape[1] : a.shape[2];
    var hasVectorSide = rowsA === 1 || colsB === 1;
    if (hasVectorSide && sharedDim > exports.MATMUL_SHARED_DIM_THRESHOLD) {
        var lhs = transposeA ? a.transpose([0, 2, 1]) : a;
        var rhs = transposeB ? b.transpose([0, 2, 1]) : b;
        var reduceLastAxis = colsB === 1;
        var lhs3D = reduceLastAxis ? lhs : lhs.as3D(batch, sharedDim, 1);
        var rhs3D = reduceLastAxis ? rhs.as3D(batch, 1, sharedDim) : rhs;
        return this.multiply(lhs3D, rhs3D).sum(reduceLastAxis ? 2 : 1, true /* keepDims */);
    }
    var outDtype = types_1.upcastType(a.dtype, b.dtype);
    var matMulProgram = new mulmat_packed_gpu_1.MatMulPackedProgram(a.shape, [batch, rowsA, colsB], transposeA, transposeB);
    return this.compileAndRun(matMulProgram, [a, b], outDtype);
};
/**
 * Batched matmul with an optional bias add and activation fused into one
 * packed program.
 *
 * @param args {a, b, transposeA, transposeB, bias?, activation?,
 *     preluActivationWeights?}. Bias and PReLU weights are appended to the
 *     program inputs, in that order, when present.
 */
MathBackendWebGL.prototype.fusedBatchMatMul = function (args) {
    var a = args.a;
    var b = args.b;
    var transposeA = args.transposeA;
    var transposeB = args.transposeB;
    var bias = args.bias;
    var activation = args.activation;
    var preluActivationWeights = args.preluActivationWeights;
    var rowsA = transposeA ? a.shape[2] : a.shape[1];
    var colsB = transposeB ? b.shape[1] : b.shape[2];
    var batch = a.shape[0];
    var outDtype = types_1.upcastType(a.dtype, b.dtype);
    var activationSnippet = activation ? mapActivationToShaderProgram(activation, true) : null;
    var fusedProgram = new mulmat_packed_gpu_1.MatMulPackedProgram(a.shape, [batch, rowsA, colsB], transposeA, transposeB, bias != null, activationSnippet, preluActivationWeights != null);
    var programInputs = [a, b];
    if (bias) {
        programInputs.push(bias);
    }
    if (preluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }
    return this.compileAndRun(fusedProgram, programInputs, outDtype);
};
/**
 * Element-wise multiplication.
 *
 * Non-complex inputs: delegated to the CPU backend when shouldExecuteOnCPU
 * says so; otherwise the MUL binary shader runs, packed or unpacked per the
 * WEBGL_PACK_BINARY_OPERATIONS flag. The output dtype is a.dtype.
 *
 * complex64 inputs: two complex shader programs (REAL and IMAG parts of the
 * product) run over the component tensors of a and b. Their results are
 * combined with this.complex() and then disposed.
 */
MathBackendWebGL.prototype.multiply = function (a, b) {
    if (a.dtype !== 'complex64') {
        if (this.shouldExecuteOnCPU([a, b])) {
            return this.cpuBackend.multiply(a, b);
        }
        if (environment_1.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
            return this.packedBinaryOp(a, b, binaryop_gpu.MUL, a.dtype);
        }
        var mulProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape);
        return this.compileAndRun(mulProgram, [a, b], a.dtype);
    }
    var aParts = this.texData.get(a.dataId).complexTensors;
    var bParts = this.texData.get(b.dataId).complexTensors;
    var complexMul = binaryop_complex_gpu.COMPLEX_MULTIPLY;
    var realProgram = new binaryop_complex_gpu_1.BinaryOpComplexProgram(complexMul.REAL, a.shape, b.shape);
    var imagProgram = new binaryop_complex_gpu_1.BinaryOpComplexProgram(complexMul.IMAG, a.shape, b.shape);
    // Both programs receive the same four inputs, in this fixed order:
    // a.real, a.imag, b.real, b.imag.
    var componentInputs = [
        this.makeComplexComponentTensorInfo(a, aParts.real),
        this.makeComplexComponentTensorInfo(a, aParts.imag),
        this.makeComplexComponentTensorInfo(b, bParts.real),
        this.makeComplexComponentTensorInfo(b, bParts.imag)
    ];
    var realPart = this.compileAndRun(realProgram, componentInputs);
    var imagPart = this.compileAndRun(imagProgram, componentInputs);
    var product = this.complex(realPart, imagPart);
    // The temporary real/imag results are disposed once handed to complex().
    realPart.dispose();
    imagPart.dispose();
    return product;
};
/**
 * Batch normalization of x using mean/variance, with optional offset and
 * scale tensors.
 *
 * Inputs passed to the shader are [x, mean, variance], followed by offset
 * (if given) and then scale (if given). The shape of a missing optional
 * tensor is passed to the program as null. The packed or unpacked program
 * is chosen by the WEBGL_PACK_NORMALIZATION flag.
 */
MathBackendWebGL.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) {
    var hasOffset = offset != null;
    var hasScale = scale != null;
    var offsetShape = hasOffset ? offset.shape : null;
    var scaleShape = hasScale ? scale.shape : null;
    var inputs = [x, mean, variance];
    // Offset must be appended before scale; keep this order in sync with
    // the shader programs' input lists.
    if (hasOffset) {
        inputs.push(offset);
    }
    if (hasScale) {
        inputs.push(scale);
    }
    var NormProgram = environment_1.env().getBool('WEBGL_PACK_NORMALIZATION') ?
        batchnorm_packed_gpu_1.BatchNormPackedProgram :
        batchnorm_gpu_1.BatchNormProgram;
    var program = new NormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
    return this.compileAndRun(program, inputs);
};
/**
 * Local response normalization of x. Uses the packed LRN program when
 * WEBGL_PACK_NORMALIZATION is enabled, and the unpacked one otherwise.
 */
MathBackendWebGL.prototype.localResponseNormalization4D = function (x, radius, bias, alpha, beta) {
    var program;
    if (environment_1.env().getBool('WEBGL_PACK_NORMALIZATION')) {
        program = new lrn_packed_gpu_1.LRNPackedProgram(x.shape, radius, bias, alpha, beta);
    }
    else {
        program = new lrn_gpu_1.LRNProgram(x.shape, radius, bias, alpha, beta);
    }
    return this.compileAndRun(program, [x]);
};
/**
 * Gradient of local response normalization.
 *
 * Runs LRNGradProgram, sized from inputImage's shape, with inputs ordered
 * as [inputImage, outputImage, dy].
 */
MathBackendWebGL.prototype.LRNGrad = function (dy, inputImage, outputImage, depthRadius, bias, alpha, beta) {
    var gradProgram = new lrn_grad_gpu_1.LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta);
    var gradInputs = [inputImage, outputImage, dy];
    return this.compileAndRun(gradProgram, gradInputs);
};
/**
 * Repeats x `reps` times along each dimension.
 *
 * Non-string tensors run TileProgram on the GPU. String tensors take a
 * host-side path: their encoded values are read back synchronously,
 * decoded, wrapped in a TensorBuffer and tiled by tile_impl.
 */
MathBackendWebGL.prototype.tile = function (x, reps) {
    if (x.dtype !== 'string') {
        return this.compileAndRun(new tile_gpu_1.TileProgram(x.shape, reps), [x]);
    }
    var encoded = this.readSync(x.dataId);
    var decoded = encoded.map(function (bytes) { return util.decodeString(bytes); });
    var hostBuffer = array_ops_1.buffer(x.shape, x.dtype, decoded);
    return tile_impl_1.tile(hostBuffer, reps);
};
/**
 * Pads x according to `paddings`, filling new elements with constantValue.
 * The packed or unpacked program is chosen by WEBGL_PACK_ARRAY_OPERATIONS.
 */
MathBackendWebGL.prototype.pad = function (x, paddings, constantValue) {
    var usePacked = environment_1.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS');
    var program;
    if (usePacked) {
        program = new pad_packed_gpu_1.PadPackedProgram(x.shape, paddings, constantValue);
    }
    else {
        program = new pad_gpu_1.PadProgram(x.shape, paddings, constantValue);
    }
    return this.compileAndRun(program, [x]);
};
/**
 * Permutes the dimensions of x according to `perm`.
 *
 * Delegates to the CPU backend when shouldExecuteOnCPU([x]) says so;
 * otherwise runs the packed or unpacked transpose program, as selected by
 * WEBGL_PACK_ARRAY_OPERATIONS.
 */
MathBackendWebGL.prototype.transpose = function (x, perm) {
    if (this.shouldExecuteOnCPU([x])) {
        return this.cpuBackend.transpose(x, perm);
    }
    var TransposeCtor = environment_1.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        transpose_packed_gpu_1.TransposePackedProgram :
        transpose_gpu_1.TransposeProgram;
    return this.compileAndRun(new TransposeCtor(x.shape, perm), [x]);
};
/**
 * Gathers slices of x along `axis` at the positions given by `indices`.
 *
 * Runs GatherProgram on the GPU, sized by indices.size, unless
 * shouldExecuteOnCPU([x, indices]) routes the call to the CPU backend.
 */
MathBackendWebGL.prototype.gather = function (x, indices, axis) {
    if (!this.shouldExecuteOnCPU([x, indices])) {
        var gatherProgram = new gather_gpu_1.GatherProgram(x.shape, indices.size, axis);
        return this.compileAndRun(gatherProgram, [x, indices]);
    }
    return this.cpuBackend.gather(x, indices, axis);
};
/**
 * Rearranges data from the batch dimension into spatial blocks, then crops.
 *
 * Built from existing tensor ops rather than a dedicated shader:
 * reshape -> transpose -> reshape -> slice, with all shapes computed by
 * array_ops_util helpers. Only supports x.rank <= 4 on this backend.
 */
MathBackendWebGL.prototype.batchToSpaceND = function (x, blockShape, crops) {
    util.assert(x.rank <= 4, function () {
        return 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
            'implemented yet';
    });
    var numBlockDims = blockShape.length;
    // Product of the block sizes (reduce without a seed, as in the original).
    var blockProduct = blockShape.reduce(function (acc, dim) { return acc * dim; });
    var splitShape = array_ops_util.getReshaped(x.shape, blockShape, blockProduct);
    var permutation = array_ops_util.getPermuted(splitShape.length, numBlockDims);
    var mergedShape = array_ops_util.getReshapedPermuted(x.shape, blockShape, blockProduct);
    var cropBegin = array_ops_util.getSliceBeginCoords(crops, numBlockDims);
    var cropSize = array_ops_util.getSliceSize(mergedShape, crops, numBlockDims);
    var interleaved = x.reshape(splitShape).transpose(permutation);
    return interleaved.reshape(mergedShape).slice(cropBegin, cropSize);
};
/**
 * Pads the block dimensions of x, then moves spatial blocks into the batch
 * dimension.
 *
 * Built from existing tensor ops: pad -> reshape -> transpose -> reshape,
 * with shapes computed by array_ops_util helpers. Only supports x.rank <= 4
 * on this backend.
 */
MathBackendWebGL.prototype.spaceToBatchND = function (x, blockShape, paddings) {
    util.assert(x.rank <= 4, function () {
        return 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
            'implemented yet';
    });
    var blockProduct = blockShape.reduce(function (acc, dim) { return acc * dim; });
    // Full padding spec for pad(): dimension 0 is unpadded, the next
    // blockShape.length dimensions use `paddings`, and every remaining
    // trailing dimension is unpadded as well.
    var fullPaddings = [[0, 0]].concat(paddings);
    var trailingDims = x.shape.length - (1 + blockShape.length);
    for (var k = 0; k < trailingDims; ++k) {
        fullPaddings.push([0, 0]);
    }
    var padded = x.pad(fullPaddings);
    var splitShape = array_ops_util.getReshaped(padded.shape, blockShape, blockProduct, false);
    var permutation = array_ops_util.getPermuted(splitShape.length, blockShape.length, false);
    var batchedShape = array_ops_util.getReshapedPermuted(padded.shape, blockShape, blockProduct, false);
    var permuted = padded.reshape(splitShape).transpose(permutation);
    return permuted.reshape(batchedShape);
};
/**
 * Reduces along the second dimension of x with `reduceType`, using one or
 * more ReduceProgram passes.
 *
 * Assumes x is 2-D [batchSize, inSize]; callers such as sum() pass
 * x.as2D(...). After each pass, if the output's second dimension is still
 * larger than 1, that output is fed back in for another pass.
 *
 * @param dtype Output dtype passed to every pass.
 */
MathBackendWebGL.prototype.reduce = function (x, reduceType, dtype) {
    var rowLength = x.shape[1];
    var reduceInfo = {
        windowSize: reduce_util.computeOptimalWindowSize(rowLength),
        inSize: rowLength,
        batchSize: x.shape[0]
    };
    var partial = this.compileAndRun(new reduce_gpu_1.ReduceProgram(reduceInfo, reduceType), [x], dtype);
    // Stop once a single column remains; no need for another GPGPU program.
    return partial.shape[1] === 1 ? partial : this.reduce(partial, reduceType, dtype);
};
/**
 * Unpacked arg-reduction (as selected by reduceType) over the second
 * dimension of x, computed in multiple ArgMinMaxProgram passes.
 *
 * The first pass reads only x. Later passes also receive the previous
 * pass's int32 index tensor (bestIndicesA), which determines the sizes,
 * while x stays the original values. Recursion stops when the output's
 * second dimension is 1.
 *
 * @param bestIndicesA Index tensor from the previous pass, or null/undefined
 *     on the first call.
 */
MathBackendWebGL.prototype.argReduce = function (x, reduceType, bestIndicesA) {
    if (bestIndicesA === void 0) { bestIndicesA = null; }
    var isFirstPass = bestIndicesA == null;
    var sizeSource = isFirstPass ? x : bestIndicesA;
    var batchSize = sizeSource.shape[0];
    var inSize = sizeSource.shape[1];
    var reduceInfo = {
        windowSize: reduce_util.computeOptimalWindowSize(inSize),
        inSize: inSize,
        batchSize: batchSize
    };
    var program = new argminmax_gpu_1.ArgMinMaxProgram(reduceInfo, reduceType, isFirstPass);
    var inputs = isFirstPass ? [x] : [x, bestIndicesA];
    var indices = this.compileAndRun(program, inputs, 'int32');
    if (indices.shape[1] !== 1) {
        return this.argReduce(x, reduceType, indices);
    }
    return indices;
};
/**
 * Packed arg-reduction (as selected by reduceType) over the last dimension,
 * computed in multiple ArgMinMaxPackedProgram passes.
 *
 * The first pass uses x's shape and reads only x. Later passes use the
 * previous int32 index tensor's shape and receive [x, bestIndicesA].
 * Another pass runs while the output still has the same rank as x; once
 * the rank differs, the output is returned.
 *
 * @param bestIndicesA Index tensor from the previous pass, or null/undefined
 *     on the first call.
 */
MathBackendWebGL.prototype.argReducePacked = function (x, reduceType, bestIndicesA) {
    if (bestIndicesA === void 0) { bestIndicesA = null; }
    var isFirstPass = bestIndicesA == null;
    var inShape = isFirstPass ? x.shape : bestIndicesA.shape;
    var windowSize = reduce_util.computeOptimalWindowSize(inShape[inShape.length - 1]);
    var program = new argminmax_packed_gpu_1.ArgMinMaxPackedProgram(inShape, windowSize, reduceType, isFirstPass);
    var indices = this.compileAndRun(program, isFirstPass ? [x] : [x, bestIndicesA], 'int32');
    if (indices.rank !== x.rank) {
        return indices;
    }
    return this.argReducePacked(x, reduceType, indices);
};
/**
 * Sums x over `axes`, which must be the innermost dimensions (asserted).
 *
 * x is flattened to 2-D [-1, product of reduced dims] and reduced with
 * 'sum' at dtype sumOutType(x.dtype). The result is then reshaped to the
 * output shape from computeOutAndReduceShapes.
 */
MathBackendWebGL.prototype.sum = function (x, axes) {
    axis_util.assertAxesAreInnerMostDims('sum', axes, x.rank);
    var shapes = axis_util.computeOutAndReduceShapes(x.shape, axes);
    var outShape = shapes[0];
    var reduceShape = shapes[1];
    var flat = x.as2D(-1, util.sizeFromShape(reduceShape));
    var summed = this.reduce(flat, 'sum', types_1.sumOutType(x.dtype));
    return summed.reshape(outShape);
};
MathBackendWebGL.prototype.prod = function (x, axes) {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.prod(x, axes);
}
var _a = axis_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1];
var inSize = util.sizeFromShape(reduceShape);
var a2D = x.as2D(-1, inSize);
var outputDType = types_1.sumOutType(