@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
952 lines • 91.7 kB
JavaScript
"use strict";
// Compiler-emitted helper that drives a generator as an async function.
// Reuses an existing `this.__awaiter` when one is present; otherwise defines it.
//   thisArg    - `this` value the generator function is applied with.
//   _arguments - arguments passed to the generator function (defaults to []).
//   P          - promise constructor to use; falls back to the global Promise.
//   generator  - generator function; replaced below by the iterator it returns.
// Returns a promise that resolves with the generator's final value and rejects
// if advancing the generator throws.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with the settled value of the last yielded promise.
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
// Re-throw a rejection inside the generator so its own try/catch can see it.
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Finish when the generator is done; otherwise wrap the yielded value in a P and wait on it.
function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); }
// Instantiate the generator and kick off the first step.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
// Compiler-emitted helper that emulates a generator with a label-driven state machine.
// Reuses an existing `this.__generator` when one is present; otherwise defines it.
//   thisArg - `this` value used when invoking `body`.
//   body    - function called repeatedly with the state object `_`; it returns an
//             instruction array `[opcode, operand]` telling the machine what to do next.
// Returns an object exposing next/throw/return (and Symbol.iterator when available).
var __generator = (this && this.__generator) || function (thisArg, body) {
// `_` is the machine state: `label` is the resume point, `sent()` yields the value (or throws
// the error) passed in by the caller, `trys` holds active try regions and `ops` holds deferred
// instructions. `f` is the re-entrancy flag, `y` a delegated iterator, `t` scratch, `g` the result.
var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
// Builds the iterator method for opcode n (0 = next, 1 = throw, 2 = return).
function verb(n) { return function (v) { return step([n, v]); }; }
function step(op) {
// Guard against re-entrant calls while the body is still running.
if (f) throw new TypeError("Generator is already executing.");
// `_` is set to 0 once the machine has finished, which ends this loop.
while (_) try {
// While delegating to an inner iterator `y`, forward the call to it and return its
// result as long as it is not done.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
if (y = 0, t) op = [op[0] & 2, t.value];
switch (op[0]) {
// 0/1: value or error sent in by the caller; stash it for `_.sent()`.
case 0: case 1: t = op; break;
// 4: hand a value back to the caller and advance the label.
case 4: _.label++; return { value: op[1], done: false };
// 5: start delegating to the iterator in op[1].
case 5: _.label++; y = op[1]; op = [0]; continue;
// 7: resume the instruction that was deferred and leave the current try region.
case 7: op = _.ops.pop(); _.trys.pop(); continue;
default:
// Throw (6) or return (2) with no enclosing try region: the machine is finished.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
// Jump (3) to a label when there is no try region or the target lies inside it.
if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
// Throw (6) before label t[1]: continue at t[1] with the error available via `_.sent()`.
if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
// Otherwise continue at t[2] first and defer the instruction until opcode 7 replays it.
if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
if (t[2]) _.ops.pop();
_.trys.pop(); continue;
}
// Run the body from the current label to obtain the next instruction.
op = body.call(thisArg, _);
// Anything thrown by the body becomes a throw instruction (6) for the next iteration.
} catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
// Opcodes with bit 1 or 4 set carry an error to rethrow; otherwise report completion.
if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
}
};
Object.defineProperty(exports, "__esModule", { value: true });
var canvas_util_1 = require("../canvas_util");
var environment_1 = require("../environment");
var globals_1 = require("../globals");
var log_1 = require("../log");
var array_ops_util = require("../ops/array_ops_util");
var axis_util = require("../ops/axis_util");
var concat_util_1 = require("../ops/concat_util");
var gather_nd_util = require("../ops/gather_nd_util");
var reduce_util = require("../ops/reduce_util");
var scatter_nd_util = require("../ops/scatter_nd_util");
var segment_util = require("../ops/segment_util");
var slice_util_1 = require("../ops/slice_util");
var softmax_1 = require("../ops/softmax");
var tensor_ops_1 = require("../ops/tensor_ops");
var tensor_1 = require("../tensor");
var types_1 = require("../types");
var util = require("../util");
var util_1 = require("../util");
var backend_1 = require("./backend");
var backend_util = require("./backend_util");
var complex_util_1 = require("./complex_util");
var non_max_suppression_impl_1 = require("./non_max_suppression_impl");
var split_shared_1 = require("./split_shared");
var topk_impl_1 = require("./topk_impl");
var argminmax_gpu_1 = require("./webgl/argminmax_gpu");
var avg_pool_backprop_gpu_1 = require("./webgl/avg_pool_backprop_gpu");
var batchnorm_gpu_1 = require("./webgl/batchnorm_gpu");
var batchnorm_packed_gpu_1 = require("./webgl/batchnorm_packed_gpu");
var binaryop_complex_gpu = require("./webgl/binaryop_complex_gpu");
var binaryop_complex_gpu_1 = require("./webgl/binaryop_complex_gpu");
var binaryop_gpu = require("./webgl/binaryop_gpu");
var binaryop_gpu_1 = require("./webgl/binaryop_gpu");
var binaryop_packed_gpu_1 = require("./webgl/binaryop_packed_gpu");
var clip_gpu_1 = require("./webgl/clip_gpu");
var clip_packed_gpu_1 = require("./webgl/clip_packed_gpu");
var complex_abs_gpu_1 = require("./webgl/complex_abs_gpu");
var concat_gpu_1 = require("./webgl/concat_gpu");
var conv_backprop_gpu_1 = require("./webgl/conv_backprop_gpu");
var conv_backprop_gpu_depthwise_1 = require("./webgl/conv_backprop_gpu_depthwise");
var conv_gpu_1 = require("./webgl/conv_gpu");
var conv_gpu_depthwise_1 = require("./webgl/conv_gpu_depthwise");
var conv_packed_gpu_depthwise_1 = require("./webgl/conv_packed_gpu_depthwise");
var crop_and_resize_gpu_1 = require("./webgl/crop_and_resize_gpu");
var cumsum_gpu_1 = require("./webgl/cumsum_gpu");
var depth_to_space_gpu_1 = require("./webgl/depth_to_space_gpu");
var encode_float_gpu_1 = require("./webgl/encode_float_gpu");
var fft_gpu = require("./webgl/fft_gpu");
var fft_gpu_1 = require("./webgl/fft_gpu");
var from_pixels_gpu_1 = require("./webgl/from_pixels_gpu");
var gather_gpu_1 = require("./webgl/gather_gpu");
var gather_nd_gpu_1 = require("./webgl/gather_nd_gpu");
var gpgpu_context_1 = require("./webgl/gpgpu_context");
var gpgpu_math = require("./webgl/gpgpu_math");
var im2col_gpu_1 = require("./webgl/im2col_gpu");
var lrn_gpu_1 = require("./webgl/lrn_gpu");
var lrn_grad_gpu_1 = require("./webgl/lrn_grad_gpu");
var max_pool_backprop_gpu_1 = require("./webgl/max_pool_backprop_gpu");
var mulmat_gpu_1 = require("./webgl/mulmat_gpu");
var mulmat_packed_gpu_1 = require("./webgl/mulmat_packed_gpu");
var multinomial_gpu_1 = require("./webgl/multinomial_gpu");
var onehot_gpu_1 = require("./webgl/onehot_gpu");
var pack_gpu_1 = require("./webgl/pack_gpu");
var pad_gpu_1 = require("./webgl/pad_gpu");
var pad_packed_gpu_1 = require("./webgl/pad_packed_gpu");
var pool_gpu_1 = require("./webgl/pool_gpu");
var reduce_gpu_1 = require("./webgl/reduce_gpu");
var reshape_packed_gpu_1 = require("./webgl/reshape_packed_gpu");
var resize_bilinear_backprop_gpu_1 = require("./webgl/resize_bilinear_backprop_gpu");
var resize_bilinear_gpu_1 = require("./webgl/resize_bilinear_gpu");
var resize_nearest_neighbor_backprop_gpu_1 = require("./webgl/resize_nearest_neighbor_backprop_gpu");
var resize_nearest_neighbor_gpu_1 = require("./webgl/resize_nearest_neighbor_gpu");
var reverse_gpu_1 = require("./webgl/reverse_gpu");
var scatter_gpu_1 = require("./webgl/scatter_gpu");
var segment_gpu_1 = require("./webgl/segment_gpu");
var select_gpu_1 = require("./webgl/select_gpu");
var slice_gpu_1 = require("./webgl/slice_gpu");
var strided_slice_gpu_1 = require("./webgl/strided_slice_gpu");
var tex_util = require("./webgl/tex_util");
var tex_util_1 = require("./webgl/tex_util");
var texture_manager_1 = require("./webgl/texture_manager");
var tile_gpu_1 = require("./webgl/tile_gpu");
var transpose_gpu_1 = require("./webgl/transpose_gpu");
var unary_op = require("./webgl/unaryop_gpu");
var unaryop_gpu_1 = require("./webgl/unaryop_gpu");
var unary_packed_op = require("./webgl/unaryop_packed_gpu");
var unaryop_packed_gpu_1 = require("./webgl/unaryop_packed_gpu");
var unpack_gpu_1 = require("./webgl/unpack_gpu");
var webgl_util = require("./webgl/webgl_util");
var where_impl_1 = require("./where_impl");
/**
 * Looks up the unary-op shader snippet that implements a fused activation.
 *
 * @param activation Activation name; only 'linear' and 'relu' are handled.
 * @param packed When true the snippet is taken from the packed unary ops
 *     module, otherwise from the unpacked one. Defaults to false.
 * @returns The LINEAR or RELU constant of the selected module.
 * @throws Error for any other activation name.
 */
function mapActivationToShaderProgram(activation, packed) {
    if (packed === void 0) { packed = false; }
    switch (activation) {
        case 'linear':
            return packed ? unary_packed_op.LINEAR : unary_op.LINEAR;
        case 'relu':
            return packed ? unary_packed_op.RELU : unary_op.RELU;
        default:
            throw new Error("Activation " + activation + " has not been implemented for the WebGL backend.");
    }
}
// Default size limit used by shouldExecuteOnCPU(): an input qualifies for the
// CPU hand-off only when its `size` is strictly below this value.
var CPU_HANDOFF_SIZE_THRESHOLD = 10;
// batchMatMul() switches to a multiply-then-sum formulation when one outer
// dimension is 1 and the shared dimension is strictly greater than this value.
exports.MATMUL_SHARED_DIM_THRESHOLD = 1000;
var MathBackendWebGL = (function () {
/**
 * Creates the WebGL backend.
 *
 * @param gpgpu Optional GPGPU context to use. When null/undefined a WebGL
 *     context is obtained for the configured WEBGL_VERSION and wrapped in a
 *     new GPGPUContext that is flagged as created locally.
 * @param delayedStorage When false, write() uploads to the GPU right away.
 *     Defaults to true.
 * @throws Error when the environment reports WEBGL_VERSION < 1.
 */
function MathBackendWebGL(gpgpu, delayedStorage) {
    if (delayedStorage === void 0) { delayedStorage = true; }
    this.gpgpu = gpgpu;
    this.delayedStorage = delayedStorage;
    // Bookkeeping for asynchronous reads and deferred disposal.
    this.pendingRead = new WeakMap();
    this.pendingDisposal = new WeakSet();
    this.dataRefCount = new WeakMap();
    this.lruDataGPU = [];
    // Counters surfaced through memory() and time().
    this.numBytesInGPU = 0;
    this.uploadWaitMs = 0;
    this.downloadWaitMs = 0;
    this.binaryCache = {};
    this.disposed = false;
    if (environment_1.ENV.get('WEBGL_VERSION') < 1) {
        throw new Error('WebGL is not supported on this device');
    }
    var ownsContext = gpgpu == null;
    var gl;
    if (ownsContext) {
        gl = canvas_util_1.getWebGLContext(environment_1.ENV.get('WEBGL_VERSION'));
        this.gpgpu = new gpgpu_context_1.GPGPUContext(gl);
    }
    else {
        gl = gpgpu.gl;
    }
    this.canvas = gl.canvas;
    this.gpgpuCreatedLocally = ownsContext;
    this.textureManager = new texture_manager_1.TextureManager(this.gpgpu);
}
/**
 * Creates the bookkeeping entry for a new data id.
 *
 * @param dataId Key identifying the tensor data.
 * @param shape Shape stored on the entry.
 * @param dtype Dtype stored on the entry.
 * @throws Error when the id already has an entry.
 */
MathBackendWebGL.prototype.register = function (dataId, shape, dtype) {
    var alreadyKnown = this.texData.has(dataId);
    if (alreadyKnown) {
        throw new Error('Data buffer is already registered');
    }
    var entry = { shape: shape, dtype: dtype };
    this.texData.set(dataId, entry);
};
/**
 * Replaces the backend's data storage with a fresh DataStorage built around
 * the given data mover.
 */
MathBackendWebGL.prototype.setDataMover = function (dataMover) {
    var storage = new backend_1.DataStorage(dataMover);
    this.texData = storage;
};
/**
 * Uploads pixel data to a texture and converts it into a tensor of shape
 * [height, width, numChannels] with FromPixelsProgram.
 *
 * Video elements are first drawn onto a lazily created 2D canvas, and that
 * canvas is what gets uploaded.
 *
 * @param pixels HTMLVideoElement, HTMLImageElement, HTMLCanvasElement or
 *     ImageData to read from.
 * @param numChannels Size of the last output dimension.
 * @returns The result of running FromPixelsProgram on the uploaded texture.
 * @throws Error when `pixels` is null/undefined, is of an unsupported type,
 *     or is a video element and the 2D canvas cannot be created (not running
 *     in a browser, or the DOM is not ready yet).
 */
MathBackendWebGL.prototype.fromPixels = function (pixels, numChannels) {
    if (pixels == null) {
        throw new Error('pixels passed to tf.fromPixels() can not be null');
    }
    var texShape = [pixels.height, pixels.width];
    var outShape = [pixels.height, pixels.width, numChannels];
    if (!(pixels instanceof HTMLVideoElement) &&
        !(pixels instanceof HTMLImageElement) &&
        !(pixels instanceof HTMLCanvasElement) &&
        !(pixels instanceof ImageData)) {
        throw new Error('pixels passed to tf.fromPixels() must be either an ' +
            "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement or " +
            ("ImageData, but was " + pixels.constructor.name));
    }
    if (pixels instanceof HTMLVideoElement) {
        if (this.fromPixels2DContext == null) {
            if (!environment_1.ENV.get('IS_BROWSER')) {
                // This branch only runs for video elements, so name the right type.
                throw new Error('Can\'t read pixels from HTMLVideoElement outside the browser.');
            }
            if (document.readyState !== 'complete') {
                throw new Error('The DOM is not ready yet. Please call tf.fromPixels() ' +
                    'once the DOM is ready. One way to do that is to add an event ' +
                    'listener for `DOMContentLoaded` on the document object');
            }
            this.fromPixels2DContext =
                document.createElement('canvas').getContext('2d');
        }
        // Draw the current frame onto the scratch canvas and upload that instead.
        this.fromPixels2DContext.canvas.width = pixels.width;
        this.fromPixels2DContext.canvas.height = pixels.height;
        this.fromPixels2DContext.drawImage(pixels, 0, 0, pixels.width, pixels.height);
        pixels = this.fromPixels2DContext.canvas;
    }
    var tempPixelHandle = this.makeTensorHandle(texShape, 'int32');
    try {
        this.texData.get(tempPixelHandle.dataId).usage = tex_util_1.TextureUsage.PIXELS;
        this.gpgpu.uploadPixelDataToTexture(this.getTexture(tempPixelHandle.dataId), pixels);
        var program = new from_pixels_gpu_1.FromPixelsProgram(outShape);
        return this.compileAndRun(program, [tempPixelHandle]);
    }
    finally {
        // Release the temporary handle even if the upload or shader run throws.
        this.disposeData(tempPixelHandle.dataId);
    }
};
/**
 * Registers a brand-new data id for the given shape/dtype and returns a
 * handle `{dataId, shape, dtype}` describing it.
 */
MathBackendWebGL.prototype.makeTensorHandle = function (shape, dtype) {
    var handle = { dataId: {}, shape: shape, dtype: dtype };
    this.register(handle.dataId, shape, dtype);
    return handle;
};
/**
 * Stores CPU values for a data id, releasing any texture it currently holds.
 *
 * @param dataId Id of previously registered data.
 * @param values Values to store; must not be null/undefined.
 * @throws Error when `values` is null, when DEBUG is on and a value cannot be
 *     represented, or when the data has dtype 'complex64'.
 */
MathBackendWebGL.prototype.write = function (dataId, values) {
    if (values == null) {
        throw new Error('MathBackendWebGL.write(): values can not be null');
    }
    if (environment_1.ENV.get('DEBUG')) {
        // Debug mode validates every value before accepting the write.
        for (var idx = 0; idx < values.length; ++idx) {
            var candidate = values[idx];
            if (webgl_util.canBeRepresented(candidate)) {
                continue;
            }
            throw Error("The value " + candidate + " cannot be represented on this device.");
        }
    }
    var entry = this.texData.get(dataId);
    var staleTexture = entry.texture;
    var staleShape = entry.texShape;
    var staleUsage = entry.usage;
    var stalePacked = entry.isPacked;
    if (entry.dtype === 'complex64') {
        throw new Error("Cannot write to a complex64 dtype. " +
            "Please use tf.complex(real, imag).");
    }
    if (staleTexture != null) {
        // The old GPU copy no longer matches the new values.
        this.releaseTexture(dataId, staleTexture, staleShape, staleUsage, stalePacked);
        entry.texture = null;
        entry.texShape = null;
    }
    entry.usage = tex_util_1.TextureUsage.UPLOAD;
    entry.values = values;
    if (this.delayedStorage) {
        return;
    }
    this.uploadToGPU(dataId);
};
/**
 * Synchronously returns the values for a data id.
 *
 * Sliced data is first cloned with a CLONE unary program and the clone is
 * read instead. Values already on the CPU are returned through
 * convertAndCacheOnCPU(). complex64 data is assembled from its real and
 * imaginary component tensors; everything else is downloaded from its
 * texture. When timers are active the download time is added to
 * `downloadWaitMs`.
 *
 * @param dataId Id of previously registered data.
 * @returns The values for the data.
 */
MathBackendWebGL.prototype.readSync = function (dataId) {
    var texData = this.texData.get(dataId);
    var values = texData.values, dtype = texData.dtype, complexTensors = texData.complexTensors, slice = texData.slice, shape = texData.shape;
    if (slice != null) {
        var program = new unaryop_gpu_1.UnaryOpProgram(shape, unary_op.CLONE);
        var res = this.compileAndRun(program, [{ dataId: dataId, shape: shape, dtype: dtype }]);
        try {
            return this.readSync(res.dataId);
        }
        finally {
            // The clone exists only for this read; dispose it so it does not leak.
            res.dispose();
        }
    }
    if (values != null) {
        return this.convertAndCacheOnCPU(dataId);
    }
    if (dtype === 'string') {
        return values;
    }
    var shouldTimeProgram = this.activeTimers != null;
    var start;
    if (shouldTimeProgram) {
        start = performance.now();
    }
    var result;
    if (dtype === 'complex64') {
        var realValues = complexTensors.real.dataSync();
        var imagValues = complexTensors.imag.dataSync();
        result = complex_util_1.mergeRealAndImagArrays(realValues, imagValues);
    }
    else {
        result = this.getValuesFromTexture(dataId);
    }
    if (shouldTimeProgram) {
        this.downloadWaitMs += performance.now() - start;
    }
    return this.convertAndCacheOnCPU(dataId, result);
};
/**
 * Asynchronously returns the values for a data id.
 *
 * Concurrent reads of the same id share one download: later callers are
 * queued on `pendingRead` and resolved together. Sliced data is cloned first
 * and the clone is read instead. A disposal requested while the read is in
 * flight is carried out once the read completes.
 *
 * @param dataId Id of previously registered data.
 * @returns Promise resolving to the values for the data.
 * @throws (rejects) when WEBGL_DOWNLOAD_FLOAT_ENABLED is false and
 *     WEBGL_VERSION is 2.
 */
MathBackendWebGL.prototype.read = function (dataId) {
    return __awaiter(this, void 0, void 0, function () {
        var _a, _b, subscribers_1, texData, texture, values, texShape, isPacked, shape, slice, dtype, program, res, sliceData, width, height, bufferOrTexture, vals, size, batch, rows, cols, dTypeVals, subscribers;
        return __generator(this, function (_c) {
            switch (_c.label) {
                case 0:
                    if (this.pendingRead.has(dataId)) {
                        // A download is already in flight; piggyback on it.
                        subscribers_1 = this.pendingRead.get(dataId);
                        return [2, new Promise(function (resolve) { return subscribers_1.push(resolve); })];
                    }
                    texData = this.texData.get(dataId);
                    texture = texData.texture, values = texData.values, texShape = texData.texShape, isPacked = texData.isPacked, shape = texData.shape, slice = texData.slice, dtype = texData.dtype;
                    if (slice != null) {
                        program = new unaryop_gpu_1.UnaryOpProgram(shape, unary_op.CLONE);
                        res = this.compileAndRun(program, [{ dataId: dataId, shape: shape, dtype: dtype }]);
                        sliceData = this.read(res.dataId);
                        // The clone exists only for this read. If its download is still
                        // pending, disposeData() defers the release until it finishes.
                        res.dispose();
                        return [2, sliceData];
                    }
                    if (values != null) {
                        return [2, this.convertAndCacheOnCPU(dataId)];
                    }
                    // Reject unsupported configurations BEFORE marking the read as
                    // pending; otherwise the stale entry would make every later read of
                    // this id wait forever and would defer its disposal indefinitely.
                    if (!environment_1.ENV.get('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
                        environment_1.ENV.get('WEBGL_VERSION') === 2) {
                        throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " +
                            "WEBGL_VERSION=2 not yet supported.");
                    }
                    this.pendingRead.set(dataId, []);
                    width = texShape[1];
                    height = texShape[0];
                    if (isPacked) {
                        _a = tex_util.getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]), width = _a[0], height = _a[1];
                    }
                    bufferOrTexture = this.gpgpu.maybeCreateBufferFromTexture(texture, height, width);
                    return [4, this.gpgpu.createAndWaitForFence()];
                case 1:
                    _c.sent();
                    if (bufferOrTexture instanceof WebGLTexture) {
                        vals = this.getValuesFromTexture(dataId);
                    }
                    else {
                        size = util.sizeFromShape(shape);
                        if (isPacked) {
                            batch = webgl_util.getBatchDim(shape);
                            rows = 1, cols = 1;
                            if (shape.length) {
                                _b = webgl_util.getRowsCols(shape), rows = _b[0], cols = _b[1];
                            }
                            vals = this.gpgpu
                                .downloadPackedMatrixFromBuffer(bufferOrTexture, batch, rows, cols, texShape[0], texShape[1])
                                .subarray(0, size);
                        }
                        else {
                            vals = this.gpgpu
                                .downloadFloat32MatrixFromBuffer(bufferOrTexture, texShape[0], texShape[1])
                                .subarray(0, size);
                        }
                    }
                    dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
                    subscribers = this.pendingRead.get(dataId);
                    this.pendingRead.delete(dataId);
                    // Resolve everyone who queued up while the download was in flight.
                    subscribers.forEach(function (resolve) { return resolve(dTypeVals); });
                    if (this.pendingDisposal.has(dataId)) {
                        this.pendingDisposal.delete(dataId);
                        this.disposeData(dataId);
                    }
                    return [2, dTypeVals];
            }
        });
    });
};
/**
 * Downloads the values backing a data id from its texture.
 *
 * With WEBGL_DOWNLOAD_FLOAT_ENABLED the texture is read directly (through the
 * packed or unpacked download path). Otherwise the data is first run through
 * EncodeFloatProgram into a temporary DOWNLOAD target and read back
 * byte-encoded. The result is always trimmed to the logical element count.
 *
 * @param dataId Id of data that currently lives in a texture.
 * @returns The downloaded values, `sizeFromShape(shape)` elements long.
 */
MathBackendWebGL.prototype.getValuesFromTexture = function (dataId) {
    var source = this.texData.get(dataId);
    var shape = source.shape;
    var dtype = source.dtype;
    var texture = source.texture;
    var texShape = source.texShape;
    var numElements = util.sizeFromShape(shape);
    if (!environment_1.ENV.get('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
        // Float download is unavailable: encode into a temporary target first.
        var scratch = this.makeTensorHandle(shape, 'float32');
        scratch.size = util_1.sizeFromShape(shape);
        this.texData.get(scratch.dataId).usage = tex_util_1.TextureUsage.DOWNLOAD;
        var encoder = new encode_float_gpu_1.EncodeFloatProgram(shape);
        var pageToCpu = false;
        this.compileAndRun(encoder, [{ shape: shape, dtype: dtype, dataId: dataId }], scratch, null, pageToCpu);
        var scratchData = this.texData.get(scratch.dataId);
        var encoded = this.gpgpu.downloadByteEncodedFloatMatrixFromOutputTexture(scratchData.texture, scratchData.texShape[0], scratchData.texShape[1]);
        var decoded = encoded.subarray(0, numElements);
        this.disposeData(scratch.dataId);
        return decoded;
    }
    if (!this.texData.get(dataId).isPacked) {
        var unpacked = this.gpgpu.downloadFloat32MatrixFromOutputTexture(texture, texShape[0], texShape[1]);
        return unpacked.subarray(0, numElements);
    }
    var batch = webgl_util.getBatchDim(shape);
    var rows = 1;
    var cols = 1;
    if (shape.length) {
        var rowsCols = webgl_util.getRowsCols(shape);
        rows = rowsCols[0];
        cols = rowsCols[1];
    }
    var packed = this.gpgpu.downloadMatrixFromPackedTexture(texture, batch, rows, cols, texShape[0], texShape[1]);
    return packed.subarray(0, numElements);
};
/**
 * Runs `f` while collecting timer entries and returns timing information.
 *
 * @param f Function to execute; it is called synchronously with no arguments.
 * @returns Promise resolving to an object with `uploadWaitMs`,
 *     `downloadWaitMs`, `kernelMs` (sum of the awaited timer queries),
 *     `getExtraProfileInfo()` (a "name: ms" list joined with ", ") and
 *     `wallMs` (always null here).
 */
MathBackendWebGL.prototype.time = function (f) {
return __awaiter(this, void 0, void 0, function () {
var oldActiveTimers, newActiveTimers, outerMostTime, flattenedActiveTimerQueries, flattenedActiveTimerNames, kernelMs, res;
return __generator(this, function (_a) {
switch (_a.label) {
case 0:
// Remember the current timer list so it can be restored after f() has run.
oldActiveTimers = this.activeTimers;
newActiveTimers = [];
outerMostTime = false;
if (this.programTimersStack == null) {
// Outermost time() call: the new list becomes the root of the stack.
this.programTimersStack = newActiveTimers;
outerMostTime = true;
}
else {
// Nested time() call: the new list is appended to the enclosing call's list.
this.activeTimers.push(newActiveTimers);
}
this.activeTimers = newActiveTimers;
f();
// Collect the non-null queries and names recorded while f() ran.
flattenedActiveTimerQueries = util.flatten(this.activeTimers.map(function (d) { return d.query; }))
.filter(function (d) { return d != null; });
flattenedActiveTimerNames = util.flatten(this.activeTimers.map(function (d) { return d.name; }))
.filter(function (d) { return d != null; });
this.activeTimers = oldActiveTimers;
if (outerMostTime) {
this.programTimersStack = null;
}
// Suspend until every collected query has resolved.
return [4, Promise.all(flattenedActiveTimerQueries)];
case 1:
kernelMs = _a.sent();
res = {
uploadWaitMs: this.uploadWaitMs,
downloadWaitMs: this.downloadWaitMs,
kernelMs: util.sum(kernelMs),
getExtraProfileInfo: function () {
return kernelMs.map(function (d, i) { return ({ name: flattenedActiveTimerNames[i], ms: d }); })
.map(function (d) { return d.name + ": " + d.ms; })
.join(', ');
},
wallMs: null
};
// The wait counters are reported once and then reset for the next measurement.
this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
return [2, res];
}
});
});
};
/**
 * Reports the backend's memory state: the tracked `numBytesInGPU` counter and
 * `unreliable: false`.
 */
MathBackendWebGL.prototype.memory = function () {
    var report = {
        unreliable: false,
        numBytesInGPU: this.numBytesInGPU
    };
    return report;
};
/**
 * Starts a timer. Uses a GPU query when the disjoint query timer extension
 * version is above 0; otherwise returns a `{startMs, endMs}` record based on
 * performance.now().
 */
MathBackendWebGL.prototype.startTimer = function () {
    var hasQueryTimer = environment_1.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0;
    return hasQueryTimer ?
        this.gpgpu.beginQuery() :
        { startMs: performance.now(), endMs: null };
};
/**
 * Stops the timer started by startTimer() and returns the same query object.
 * With the query timer extension the GPU query is ended; otherwise `endMs`
 * is stamped with performance.now().
 */
MathBackendWebGL.prototype.endTimer = function (query) {
    var hasQueryTimer = environment_1.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0;
    if (hasQueryTimer) {
        this.gpgpu.endQuery();
    }
    else {
        query.endMs = performance.now();
    }
    return query;
};
/**
 * Resolves to the elapsed time for a finished timer query: the GPU query
 * result when the query timer extension is in use, otherwise
 * `endMs - startMs` of the CPU timer record.
 */
MathBackendWebGL.prototype.getQueryTime = function (query) {
    return __awaiter(this, void 0, void 0, function () {
        return __generator(this, function (_a) {
            var hasQueryTimer = environment_1.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0;
            if (!hasQueryTimer) {
                return [2, query.endMs - query.startMs];
            }
            return [2, this.gpgpu.waitForQueryAndGetTime(query)];
        });
    });
};
/**
 * Releases the resources held for a data id.
 *
 * Does nothing when a disposal is already queued or the id is unknown. While
 * an async read is in flight the disposal is only queued. When the data has a
 * texture, the reference count (keyed by the slice's original id, if any) is
 * decremented; the texture and the entry are released once it is not above 1.
 * Component tensors of complex data are disposed as well.
 */
MathBackendWebGL.prototype.disposeData = function (dataId) {
    if (this.pendingDisposal.has(dataId)) {
        return;
    }
    if (this.pendingRead.has(dataId)) {
        // Defer until the read completes.
        this.pendingDisposal.add(dataId);
        return;
    }
    if (!this.texData.has(dataId)) {
        return;
    }
    var entry = this.texData.get(dataId);
    var texture = entry.texture;
    var texShape = entry.texShape;
    var usage = entry.usage;
    var complexTensors = entry.complexTensors;
    var isPacked = entry.isPacked;
    var slice = entry.slice;
    if (texture != null) {
        var refKey = slice && slice.origDataId || dataId;
        var remaining = this.dataRefCount.get(refKey);
        if (remaining > 1) {
            // Other views still share this texture.
            this.dataRefCount.set(refKey, remaining - 1);
        }
        else {
            this.dataRefCount.delete(refKey);
            this.releaseTexture(dataId, texture, texShape, usage, isPacked);
            this.texData.delete(dataId);
        }
    }
    if (complexTensors != null) {
        complexTensors.real.dispose();
        complexTensors.imag.dispose();
    }
};
/**
 * Returns the texture for a data id after calling uploadToGPU() for it.
 */
MathBackendWebGL.prototype.getTexture = function (dataId) {
    this.uploadToGPU(dataId);
    var entry = this.texData.get(dataId);
    return entry.texture;
};
/**
 * Returns the 'cpu' backend used for forwarding small ops, looking it up on
 * first use, or null when WEBGL_CPU_FORWARD is off.
 */
MathBackendWebGL.prototype.getCPUBackend = function () {
    var forwardingEnabled = environment_1.ENV.get('WEBGL_CPU_FORWARD');
    if (forwardingEnabled) {
        if (this.cpuBackend == null) {
            // Cache the lookup for subsequent calls.
            this.cpuBackend = environment_1.ENV.findBackend('cpu');
        }
        return this.cpuBackend;
    }
    return null;
};
/**
 * Decides whether an op should be forwarded to the CPU backend.
 *
 * @param inputs Tensors taking part in the op.
 * @param sizeThreshold Exclusive upper bound on each input's `size`.
 *     Defaults to CPU_HANDOFF_SIZE_THRESHOLD.
 * @returns true only when a CPU backend is available and every input has no
 *     texture and is smaller than the threshold.
 */
MathBackendWebGL.prototype.shouldExecuteOnCPU = function (inputs, sizeThreshold) {
    var backend = this;
    if (sizeThreshold === void 0) { sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD; }
    function isSmallAndOffGPU(input) {
        return backend.texData.get(input.dataId).texture == null &&
            input.size < sizeThreshold;
    }
    if (this.getCPUBackend() == null) {
        return false;
    }
    return inputs.every(isSmallAndOffGPU);
};
/** Returns the GPGPU context this backend runs on. */
MathBackendWebGL.prototype.getGPGPUContext = function () {
    var context = this.gpgpu;
    return context;
};
/** Returns the canvas recorded when the backend was constructed. */
MathBackendWebGL.prototype.getCanvas = function () {
    var canvas = this.canvas;
    return canvas;
};
/**
 * Builds a complex64 tensor shaped like `real`. Its data entry keeps clones
 * of `real` and `imag` (kept via the engine) as component tensors.
 */
MathBackendWebGL.prototype.complex = function (real, imag) {
    var result = this.makeOutputArray(real.shape, 'complex64');
    var entry = this.texData.get(result.dataId);
    var keptReal = environment_1.ENV.engine.keep(real.clone());
    var keptImag = environment_1.ENV.engine.keep(imag.clone());
    entry.complexTensors = { real: keptReal, imag: keptImag };
    return result;
};
/** Returns a clone of the real component tensor stored for `input`. */
MathBackendWebGL.prototype.real = function (input) {
    var components = this.texData.get(input.dataId).complexTensors;
    return components.real.clone();
};
/** Returns a clone of the imaginary component tensor stored for `input`. */
MathBackendWebGL.prototype.imag = function (input) {
    var components = this.texData.get(input.dataId).complexTensors;
    return components.imag.clone();
};
/**
 * Extracts the region of `x` starting at `begin` with extent `size`.
 *
 * Small CPU-resident inputs are forwarded to the CPU backend. Unpacked,
 * continuous slices are served by shallowSlice() after uploading `x`; every
 * other case runs SliceProgram.
 */
MathBackendWebGL.prototype.slice = function (x, begin, size) {
    if (this.shouldExecuteOnCPU([x])) {
        return this.cpuBackend.slice(x, begin, size);
    }
    var packed = this.texData.get(x.dataId).isPacked;
    var continuous = slice_util_1.isSliceContinous(x.shape, begin, size);
    var canShareTexture = !packed && continuous;
    if (canShareTexture) {
        this.uploadToGPU(x.dataId);
        return this.shallowSlice(x, begin, size);
    }
    var sliceProgram = new slice_gpu_1.SliceProgram(size);
    var setup = sliceProgram.getCustomSetupFunc(begin);
    return this.compileAndRun(sliceProgram, [x], null, setup);
};
/**
 * Creates a tensor whose data entry is a copy of `x`'s entry with a new
 * shape and a `slice` record (`flatOffset`, `origDataId`), and bumps the
 * reference count kept for the original data id.
 */
MathBackendWebGL.prototype.shallowSlice = function (x, begin, size) {
    var sourceData = this.texData.get(x.dataId);
    var view = tensor_1.Tensor.make(size, {}, sourceData.dtype);
    var viewData = this.texData.get(view.dataId);
    Object.assign(viewData, sourceData);
    viewData.shape = size;
    var offset = slice_util_1.computeFlatOffset(begin, x.strides);
    var parentSlice = sourceData.slice;
    if (parentSlice) {
        // Slicing a slice: offsets accumulate relative to the original data.
        offset += parentSlice.flatOffset;
    }
    var rootId = sourceData.slice && sourceData.slice.origDataId || x.dataId;
    viewData.slice = { flatOffset: offset, origDataId: rootId };
    var currentCount = this.dataRefCount.get(viewData.slice.origDataId) || 1;
    this.dataRefCount.set(viewData.slice.origDataId, currentCount + 1);
    return view;
};
/**
 * Strided slice of `x`. Small CPU-resident inputs are forwarded to the CPU
 * backend. When any output dimension is 0 an empty tensor of the output shape
 * is returned; otherwise StridedSliceProgram is run.
 */
MathBackendWebGL.prototype.stridedSlice = function (x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
    if (this.shouldExecuteOnCPU([x])) {
        return this.cpuBackend.stridedSlice(x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    }
    var info = slice_util_1.getStridedSlicedInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    var beginIndex = info[0];
    var size = info[1];
    var shrinkAxis = info[2];
    // Drop the axes that are being shrunk away from the output shape.
    var outShape = size.filter(function (dim, axisIndex) {
        return shrinkAxis.indexOf(axisIndex) === -1;
    });
    var hasEmptyAxis = outShape.indexOf(0) !== -1;
    if (hasEmptyAxis) {
        return tensor_ops_1.tensor([], outShape);
    }
    return this.compileAndRun(new strided_slice_gpu_1.StridedSliceProgram(beginIndex, strides, size, shrinkAxis), [x]);
};
/** Runs ReverseProgram on `x` for the given axis argument. */
MathBackendWebGL.prototype.reverse = function (x, axis) {
    return this.compileAndRun(new reverse_gpu_1.ReverseProgram(x.shape, axis), [x]);
};
/**
 * Concatenates tensors along `axis`.
 *
 * Small CPU-resident inputs are forwarded to the CPU backend and a single
 * tensor is returned unchanged. When there are more tensors than
 * WEBGL_MAX_TEXTURES_IN_SHADER the list is split in half and concatenated
 * recursively. Otherwise each tensor is flattened to 2D, ConcatProgram is
 * run, and the result is reshaped to the computed output shape.
 */
MathBackendWebGL.prototype.concat = function (tensors, axis) {
    if (this.shouldExecuteOnCPU(tensors)) {
        return this.cpuBackend.concat(tensors, axis);
    }
    var count = tensors.length;
    if (count === 1) {
        return tensors[0];
    }
    if (count > environment_1.ENV.get('WEBGL_MAX_TEXTURES_IN_SHADER')) {
        // Too many inputs for one shader: divide and conquer.
        var half = Math.floor(tensors.length / 2);
        var left = this.concat(tensors.slice(0, half), axis);
        var right = this.concat(tensors.slice(half), axis);
        return this.concat([left, right], axis);
    }
    var shapeOf = function (t) { return t.shape; };
    var outShape = concat_util_1.computeOutShape(tensors.map(shapeOf), axis);
    var flattened = tensors.map(function (t) {
        return t.as2D(-1, util_1.sizeFromShape(t.shape.slice(axis)));
    });
    var concatProgram = new concat_gpu_1.ConcatProgram(flattened.map(shapeOf));
    return this.compileAndRun(concatProgram, flattened).reshape(outShape);
};
/** Runs the NEG unary op program on `x`. */
MathBackendWebGL.prototype.neg = function (x) {
    return this.compileAndRun(new unaryop_gpu_1.UnaryOpProgram(x.shape, unary_op.NEG), [x]);
};
/**
 * Batched matrix multiplication of rank-3 tensors `a` and `b`.
 *
 * When one outer dimension is 1 and the shared dimension exceeds
 * MATMUL_SHARED_DIM_THRESHOLD, the product is computed as a broadcast
 * multiply followed by a sum. Otherwise a batch of 1 uses
 * MatMulPackedProgram on the squeezed 2D operands and larger batches use
 * MatMulProgram.
 */
MathBackendWebGL.prototype.batchMatMul = function (a, b, transposeA, transposeB) {
    var outerShapeA = transposeA ? a.shape[2] : a.shape[1];
    var outerShapeB = transposeB ? b.shape[1] : b.shape[2];
    var sharedDim = transposeA ? a.shape[1] : a.shape[2];
    var batch = a.shape[0];
    var isVectorProduct = outerShapeA === 1 || outerShapeB === 1;
    if (isVectorProduct && sharedDim > exports.MATMUL_SHARED_DIM_THRESHOLD) {
        if (transposeA) {
            a = a.transpose([0, 2, 1]);
        }
        if (transposeB) {
            b = b.transpose([0, 2, 1]);
        }
        var lhs;
        var rhs;
        var reduceAxis;
        if (outerShapeB === 1) {
            lhs = a;
            reduceAxis = 2;
            rhs = b.as3D(batch, 1, sharedDim);
        }
        else {
            lhs = a.as3D(batch, sharedDim, 1);
            reduceAxis = 1;
            rhs = b;
        }
        return this.multiply(lhs, rhs).sum(reduceAxis, true);
    }
    var dtype = types_1.upcastType(a.dtype, b.dtype);
    if (batch !== 1) {
        var batchedProgram = new mulmat_gpu_1.MatMulProgram(a.shape, b.shape, transposeA, transposeB);
        var batchedOutput = this.makeOutputArray(batchedProgram.outputShape, dtype);
        return this.compileAndRun(batchedProgram, [a, b], batchedOutput);
    }
    // Single batch: drop the batch dimension and use the packed kernel.
    var a2D = a.as2D(a.shape[1], a.shape[2]);
    var b2D = b.as2D(b.shape[1], b.shape[2]);
    var packedProgram = new mulmat_packed_gpu_1.MatMulPackedProgram(a2D.shape, b2D.shape, [outerShapeA, outerShapeB], transposeA, transposeB);
    var packedOutput = this.makePackedTensor(packedProgram.outputShape, dtype);
    var product = this.compileAndRun(packedProgram, [a2D, b2D], packedOutput);
    return product.reshape([1, product.shape[0], product.shape[1]]);
};
/**
 * Batched matrix multiplication with an optional bias input and an optional
 * activation fused into the shader. A batch of 1 uses MatMulPackedProgram on
 * the squeezed 2D operands; larger batches use MatMulProgram.
 */
MathBackendWebGL.prototype.fusedBatchMatMul = function (a, b, transposeA, transposeB, bias, activation) {
    var outerShapeA = transposeA ? a.shape[2] : a.shape[1];
    var outerShapeB = transposeB ? b.shape[1] : b.shape[2];
    var batch = a.shape[0];
    var dtype = types_1.upcastType(a.dtype, b.dtype);
    var hasBias = !!bias;
    if (batch !== 1) {
        var batchedActivation = activation ? mapActivationToShaderProgram(activation) : null;
        var batchedProgram = new mulmat_gpu_1.MatMulProgram(a.shape, b.shape, transposeA, transposeB, hasBias, batchedActivation);
        var batchedInputs = [a, b];
        if (bias) {
            batchedInputs.push(bias);
        }
        var batchedOutput = this.makeOutputArray(batchedProgram.outputShape, dtype);
        return this.compileAndRun(batchedProgram, batchedInputs, batchedOutput);
    }
    // Single batch: drop the batch dimension and use the packed kernel.
    var a2D = a.as2D(a.shape[1], a.shape[2]);
    var b2D = b.as2D(b.shape[1], b.shape[2]);
    var packedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
    var packedProgram = new mulmat_packed_gpu_1.MatMulPackedProgram(a2D.shape, b2D.shape, [outerShapeA, outerShapeB], transposeA, transposeB, hasBias, packedActivation);
    var packedOutput = this.makePackedTensor(packedProgram.outputShape, dtype);
    var packedInputs = [a2D, b2D];
    if (bias) {
        packedInputs.push(bias);
    }
    var product = this.compileAndRun(packedProgram, packedInputs, packedOutput);
    return product.reshape([1, product.shape[0], product.shape[1]]);
};
/**
 * Element-wise product of `a` and `b`.
 *
 * complex64 inputs are multiplied with two complex programs (real and
 * imaginary part) fed by the four component tensors. Otherwise small
 * CPU-resident inputs go to the CPU backend, the packed binary op is used
 * when WEBGL_PACK_BINARY_OPERATIONS is set, and BinaryOpProgram is the
 * fallback.
 */
MathBackendWebGL.prototype.multiply = function (a, b) {
    if (a.dtype === 'complex64') {
        var aEntry = this.texData.get(a.dataId);
        var bEntry = this.texData.get(b.dataId);
        var realPart = new binaryop_complex_gpu_1.BinaryOpComplexProgram(binaryop_complex_gpu.COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
        var imagPart = new binaryop_complex_gpu_1.BinaryOpComplexProgram(binaryop_complex_gpu.COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
        // Both programs consume the same four inputs: a.real, a.imag, b.real, b.imag.
        var components = [];
        components.push(this.makeComplexComponentTensorHandle(a, aEntry.complexTensors.real));
        components.push(this.makeComplexComponentTensorHandle(a, aEntry.complexTensors.imag));
        components.push(this.makeComplexComponentTensorHandle(b, bEntry.complexTensors.real));
        components.push(this.makeComplexComponentTensorHandle(b, bEntry.complexTensors.imag));
        var realResult = this.compileAndRun(realPart, components);
        var imagResult = this.compileAndRun(imagPart, components);
        var product = this.complex(realResult, imagResult);
        // complex() keeps its own clones, so the intermediates can go.
        realResult.dispose();
        imagResult.dispose();
        return product;
    }
    if (this.shouldExecuteOnCPU([a, b])) {
        return this.cpuBackend.multiply(a, b);
    }
    if (environment_1.ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
        return this.packedBinaryOp(a, b, binaryop_gpu.MUL, a.dtype);
    }
    var mulProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape);
    var target = this.makeOutputArray(mulProgram.outputShape, a.dtype);
    return this.compileAndRun(mulProgram, [a, b], target);
};
/**
 * Batch normalization of `x` with the given mean and variance. `offset` and
 * `scale` are optional and, when present, are appended to the program inputs
 * in that order. The packed program is used when
 * WEBGL_PACK_BATCHNORMALIZATION is set.
 */
MathBackendWebGL.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) {
    var programInputs = [x, mean, variance];
    var offsetShape = null;
    var scaleShape = null;
    if (offset != null) {
        offsetShape = offset.shape;
        programInputs.push(offset);
    }
    if (scale != null) {
        scaleShape = scale.shape;
        programInputs.push(scale);
    }
    var program;
    if (environment_1.ENV.get('WEBGL_PACK_BATCHNORMALIZATION')) {
        program = new batchnorm_packed_gpu_1.BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
    }
    else {
        program = new batchnorm_gpu_1.BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
    }
    return this.compileAndRun(program, programInputs);
};
/** Runs LRNProgram on `x` with the given radius, bias, alpha and beta. */
MathBackendWebGL.prototype.localResponseNormalization4D = function (x, radius, bias, alpha, beta) {
    var lrn = new lrn_gpu_1.LRNProgram(x.shape, radius, bias, alpha, beta);
    var inputs = [x];
    return this.compileAndRun(lrn, inputs);
};
/**
 * Runs LRNGradProgram, built from the input image's shape, with the inputs
 * ordered as [inputImage, outputImage, dy].
 */
MathBackendWebGL.prototype.LRNGrad = function (dy, inputImage, outputImage, depthRadius, bias, alpha, beta) {
    var gradProgram = new lrn_grad_gpu_1.LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta);
    var inputs = [inputImage, outputImage, dy];
    return this.compileAndRun(gradProgram, inputs);
};
/** Runs TileProgram on `x` with the given repetitions. */
MathBackendWebGL.prototype.tile = function (x, reps) {
    return this.compileAndRun(new tile_gpu_1.TileProgram(x.shape, reps), [x]);
};
/**
 * Pads `x` with `constantValue`, using the packed program when
 * WEBGL_PACK_ARRAY_OPERATIONS is set and the unpacked one otherwise.
 */
MathBackendWebGL.prototype.pad = function (x, paddings, constantValue) {
    var padProgram;
    if (environment_1.ENV.get('WEBGL_PACK_ARRAY_OPERATIONS')) {
        padProgram = new pad_packed_gpu_1.PadPackedProgram(x.shape, paddings, constantValue);
    }
    else {
        padProgram = new pad_gpu_1.PadProgram(x.shape, paddings, constantValue);
    }
    return this.compileAndRun(padProgram, [x]);
};
/** Runs TransposeProgram on `x` with the permutation `perm`. */
MathBackendWebGL.prototype.transpose = function (x, perm) {
    return this.compileAndRun(new transpose_gpu_1.TransposeProgram(x.shape, perm), [x]);
};
/**
 * Runs GatherProgram, built from `x.shape`, `indices.size` and `axis`, with
 * `x` and `indices` as inputs.
 */
MathBackendWebGL.prototype.gather = function (x, indices, axis) {
    var gatherProgram = new gather_gpu_1.GatherProgram(x.shape, indices.size, axis);
    var inputs = [x, indices];
    return this.compileAndRun(gatherProgram, inputs);
};
/**
 * batchToSpaceND expressed as reshape → transpose → reshape → slice, with
 * the intermediate shapes supplied by array_ops_util.
 *
 * @throws (via util.assert) when `x.rank` is greater than 4.
 */
MathBackendWebGL.prototype.batchToSpaceND = function (x, blockShape, crops) {
    util.assert(x.rank <= 4, 'batchToSpaceND for rank > 4 with a WebGL backend not implemented yet');
    var blockSize = blockShape.reduce(function (acc, dim) { return acc * dim; });
    var numBlockDims = blockShape.length;
    var reshapedShape = array_ops_util.getReshaped(x.shape, blockShape, blockSize);
    var permutation = array_ops_util.getPermuted(reshapedShape.length, numBlockDims);
    var mergedShape = array_ops_util.getReshapedPermuted(x.shape, blockShape, blockSize);
    var cropBegin = array_ops_util.getSliceBeginCoords(crops, numBlockDims);
    var cropSize = array_ops_util.getSliceSize(mergedShape, crops, numBlockDims);
    var rearranged = x.reshape(reshapedShape).transpose(permutation);
    var merged = rearranged.reshape(mergedShape);
    return merged.slice(cropBegin, cropSize);
};
/**
 * spaceToBatchND expressed as pad → reshape → transpose → reshape, with the
 * intermediate shapes supplied by array_ops_util.
 *
 * @throws (via util.assert) when `x.rank` is greater than 4.
 */
MathBackendWebGL.prototype.spaceToBatchND = function (x, blockShape, paddings) {
    util.assert(x.rank <= 4, 'spaceToBatchND for rank > 4 with a WebGL backend not implemented yet');
    var blockSize = blockShape.reduce(function (acc, dim) { return acc * dim; });
    // Padding list: a leading [0, 0], then the caller's paddings, then [0, 0]
    // for every dimension after the block dimensions.
    var fullPaddings = [[0, 0]];
    Array.prototype.push.apply(fullPaddings, paddings);
    var firstTrailingDim = 1 + blockShape.length;
    for (var dim = firstTrailingDim; dim < x.shape.length; ++dim) {
        fullPaddings.push([0, 0]);
    }
    var padded = x.pad(fullPaddings);
    var splitShape = array_ops_util.getReshaped(padded.shape, blockShape, blockSize, false);
    var permutation = array_ops_util.getPermuted(splitShape.length, blockShape.length, false);
    var outShape = array_ops_util.getReshapedPermuted(padded.shape, blockShape, blockSize, false);
    var rearranged = padded.reshape(splitShape).transpose(permutation);
    return rearranged.reshape(outShape);
};
/**
 * Runs one ReduceProgram pass over a 2D input and recurses on the result
 * until the output has exactly one column.
 *
 * @param x Input whose `shape` is read as [batchSize, inSize].
 * @param reduceType Forwarded to the ReduceProgram constructor.
 * @param dtype Dtype used for every output array that is allocated.
 * @returns The first output whose `shape[1]` is 1.
 */
MathBackendWebGL.prototype.reduce = function (x, reduceType, dtype) {
    var rowCount = x.shape[0];
    var colCount = x.shape[1];
    var reduceInfo = {
        windowSize: reduce_util.computeOptimalWindowSize(colCount),
        inSize: colCount,
        batchSize: rowCount
    };
    var reduceProgram = new reduce_gpu_1.ReduceProgram(reduceInfo, reduceType);
    var outShape = reduceProgram.outputShape;
    var partial = this.makeOutputArray([outShape[0], outShape[1]], dtype);
    this.compileAndRun(reduceProgram, [x], partial);
    // More than one column left: another pass is needed.
    return partial.shape[1] === 1 ? partial : this.reduce(partial, reduceType, dtype);
};
/**
 * Runs one ArgMinMaxProgram pass and recurses until the output has exactly
 * one column.
 *
 * On the first pass (no `bestIndicesA`) the sizes come from `x` and `x` is
 * the only program input. On later passes the sizes come from
 * `bestIndicesA`, which is passed to the program as a second input.
 *
 * @param x Input whose `shape` is read as [batchSize, inSize].
 * @param reduceType Forwarded to the ArgMinMaxProgram constructor.
 * @param bestIndicesA Optional output of the previous pass; defaults to null.
 * @returns The first 'int32' output whose `shape[1]` is 1.
 */
MathBackendWebGL.prototype.argReduce = function (x, reduceType, bestIndicesA) {
    if (bestIndicesA === void 0) { bestIndicesA = null; }
    var isFirstPass = bestIndicesA == null;
    var batchSize = x.shape[0];
    var inSize = x.shape[1];
    if (!isFirstPass) {
        batchSize = bestIndicesA.shape[0];
        inSize = bestIndicesA.shape[1];
    }
    var reduceInfo = {
        windowSize: reduce_util.computeOptimalWindowSize(inSize),
        inSize: inSize,
        batchSize: batchSize
    };
    var argProgram = new argminmax_gpu_1.ArgMinMaxProgram(reduceInfo, reduceType, isFirstPass);
    var outShape = argProgram.outputShape;
    var indices = this.makeOutputArray([outShape[0], outShape[1]], 'int32');
    var programInputs = isFirstPass ? [x] : [x, bestIndicesA];
    this.compileAndRun(argProgram, programInputs, indices);
    if (indices.shape[1] !== 1) {
        return this.argReduce(x, reduceType, indices);
    }
    return indices;
};
/**
 * Sums `x` over `axes` by flattening it to 2D and delegating to
 * `this.reduce` with the 'sum' reduction type.
 *
 * @param x Input tensor; `shape`, `rank` and `dtype` are read here.
 * @param axes Axes to reduce; checked by `assertAxesAreInnerMostDims`.
 * @returns The reduced result reshaped to the computed output shape.
 */
MathBackendWebGL.prototype.sum = function (x, axes) {
    axis_util.assertAxesAreInnerMostDims('sum', axes, x.rank);
    var shapes = axis_util.computeOutAndReduceShapes(x.shape, axes);
    var outShape = shapes[0];
    var reduceSize = util.sizeFromShape(shapes[1]);
    // One row per output element, one column per reduced element.
    var flattened = x.as2D(-1, reduceSize);
    var reduced = this.reduce(flattened, 'sum', types_1.sumOutType(x.dtype));
    return reduced.reshape(outShape);
};
/**
 * Multiplies the elements of `x` over `axes` by flattening it to 2D and
 * delegating to `this.reduce` with the 'prod' reduction type.
 *
 * The 2D flattening below is only meaningful when the reduced axes are the
 * inner-most ones, so the axes are checked first with the same guard the
 * sibling reductions (`sum`, `argMin`, `argMax`, `min`) use. Without it,
 * non-inner-most axes would be reduced over the wrong elements.
 *
 * @param x Input tensor; `shape`, `rank` and `dtype` are read here.
 * @param axes Axes to reduce; checked by `assertAxesAreInnerMostDims`, which
 *     is expected to throw for axes that are not inner-most.
 * @returns The reduced result reshaped to the computed output shape.
 */
MathBackendWebGL.prototype.prod = function (x, axes) {
    axis_util.assertAxesAreInnerMostDims('prod', axes, x.rank);
    var _a = axis_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1];
    var inSize = util.sizeFromShape(reduceShape);
    // One row per output element, one column per reduced element.
    var a2D = x.as2D(-1, inSize);
    var outputDType = types_1.sumOutType(x.dtype);
    return this.reduce(a2D, 'prod', outputDType).reshape(outShape);
};
/**
 * Computes an unsorted segment sum, segmenting along axis 0 of `x`.
 *
 * If `getAxesPermutation` returns a permutation, `x` is transposed with it
 * first, the work is done on the axis reported by `getInnerMostAxes`, and
 * the result is transposed back with the undo permutation.
 *
 * @param x Input tensor; `shape`, `rank` and `dtype` are read here.
 * @param segmentIds Forwarded to `segOpCompute`.
 * @param numSegments Forwarded to `computeOutShape` and `segOpCompute`.
 * @returns The segment-sum result, reshaped to the computed output shape.
 */
MathBackendWebGL.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) {
    var segmentAxis = 0;
    var perm = axis_util.getAxesPermutation([segmentAxis], x.rank);
    var needsTranspose = perm != null;
    var source = x;
    if (needsTranspose) {
        source = x.transpose(perm);
        segmentAxis = axis_util.getInnerMostAxes(1, x.rank)[0];
    }
    var outShape = segment_util.computeOutShape(source.shape, segmentAxis, numSegments);
    var segmentDimSize = util.sizeFromShape([source.shape[segmentAxis]]);
    var flattened = source.as2D(-1, segmentDimSize);
    var outDType = types_1.sumOutType(x.dtype);
    var summed = this.segOpCompute(flattened, 'unsortedSegmentSum', segmentIds, outDType, numSegments);
    var shaped = summed.reshape(outShape);
    if (!needsTranspose) {
        return shaped;
    }
    // Restore the caller's axis order.
    return shaped.transpose(axis_util.getUndoAxesPermutation(perm));
};
/**
 * Runs one SegmentOpProgram pass over a 2D input and recurses until the
 * output has exactly `numSegments` columns.
 *
 * For each follow-up pass the segment ids are replaced by
 * `range(0, numSegments)` tiled `inSize / windowSize` times.
 *
 * @param x Input whose `shape` is read as [batchSize, inSize].
 * @param segOpType Forwarded to the SegmentOpProgram constructor.
 * @param segmentIds Second program input for this pass.
 * @param dtype Dtype used for every output array that is allocated.
 * @param numSegments Number of segments; also the stopping column count.
 * @returns The first output whose `shape[1]` equals `numSegments`.
 */
MathBackendWebGL.prototype.segOpCompute = function (x, segOpType, segmentIds, dtype, numSegments) {
    var rowCount = x.shape[0];
    var colCount = x.shape[1];
    var window = segment_util.segOpComputeOptimalWindowSize(colCount, numSegments);
    var segOpInfo = {
        windowSize: window,
        inSize: colCount,
        batchSize: rowCount,
        numSegments: numSegments
    };
    var segProgram = new segment_gpu_1.SegmentOpProgram(segOpInfo, segOpType);
    var outShape = segProgram.outputShape;
    var partial = this.makeOutputArray([outShape[0], outShape[1]], dtype);
    this.compileAndRun(segProgram, [x, segmentIds], partial);
    if (partial.shape[1] !== numSegments) {
        var nextSegmentIds = tensor_ops_1.range(0, numSegments).tile([colCount / window]);
        return this.segOpCompute(partial, segOpType, nextSegmentIds, dtype, numSegments);
    }
    return partial;
};
/**
 * Flattens `x` to 2D and delegates to `this.argReduce` with the 'min'
 * reduction type, reducing over the single given axis.
 *
 * @param x Input tensor; `shape` and `rank` are read here.
 * @param axis Axis to reduce; checked by `assertAxesAreInnerMostDims`.
 * @returns The `argReduce` result reshaped to the computed output shape.
 */
MathBackendWebGL.prototype.argMin = function (x, axis) {
    var reduceAxes = [axis];
    axis_util.assertAxesAreInnerMostDims('argMin', reduceAxes, x.rank);
    var shapes = axis_util.computeOutAndReduceShapes(x.shape, reduceAxes);
    var outShape = shapes[0];
    var reduceSize = util.sizeFromShape(shapes[1]);
    var flattened = x.as2D(-1, reduceSize);
    var bestIndices = this.argReduce(flattened, 'min');
    return bestIndices.reshape(outShape);
};
/**
 * Flattens `x` to 2D and delegates to `this.argReduce` with the 'max'
 * reduction type, reducing over the single given axis.
 *
 * @param x Input tensor; `shape` and `rank` are read here.
 * @param axis Axis to reduce; checked by `assertAxesAreInnerMostDims`.
 * @returns The `argReduce` result reshaped to the computed output shape.
 */
MathBackendWebGL.prototype.argMax = function (x, axis) {
    var reduceAxes = [axis];
    axis_util.assertAxesAreInnerMostDims('argMax', reduceAxes, x.rank);
    var shapes = axis_util.computeOutAndReduceShapes(x.shape, reduceAxes);
    var outShape = shapes[0];
    var reduceSize = util.sizeFromShape(shapes[1]);
    var flattened = x.as2D(-1, reduceSize);
    var bestIndices = this.argReduce(flattened, 'max');
    return bestIndices.reshape(outShape);
};
/**
 * Runs a CumSumProgram over `x`. Only the last axis (`x.rank - 1`) is
 * accepted.
 *
 * @param x Input tensor; `shape` and `rank` are read here.
 * @param axis Axis to accumulate along; must equal `x.rank - 1`.
 * @param exclusive Forwarded unchanged to the CumSumProgram constructor.
 * @param reverse Forwarded unchanged to the CumSumProgram constructor.
 * @returns Whatever `compileAndRun` returns for the program.
 * @throws Error when `axis` is not `x.rank - 1`.
 */
MathBackendWebGL.prototype.cumsum = function (x, axis, exclusive, reverse) {
    var innerMostAxis = x.rank - 1;
    if (axis === innerMostAxis) {
        var cumsumProgram = new cumsum_gpu_1.CumSumProgram(x.shape, exclusive, reverse);
        return this.compileAndRun(cumsumProgram, [x]);
    }
    throw new Error("WebGL cumsum shader expects an inner-most axis=" + innerMostAxis + " " +
        ("but got axis=" + axis));
};
/**
 * Runs a BinaryOpProgram built with `binaryop_gpu.EQUAL` over `a` and `b`,
 * writing into a 'bool' output array.
 *
 * @param a First operand; only its `shape` is read here.
 * @param b Second operand; only its `shape` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.equal = function (a, b) {
    var equalProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.EQUAL, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(equalProgram.outputShape, 'bool');
    var operands = [a, b];
    return this.compileAndRun(equalProgram, operands, boolOutput);
};
/**
 * Runs a BinaryOpProgram built with `binaryop_gpu.NOT_EQUAL` over `a` and
 * `b`, writing into a 'bool' output array.
 *
 * @param a First operand; only its `shape` is read here.
 * @param b Second operand; only its `shape` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.notEqual = function (a, b) {
    var notEqualProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.NOT_EQUAL, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(notEqualProgram.outputShape, 'bool');
    var operands = [a, b];
    return this.compileAndRun(notEqualProgram, operands, boolOutput);
};
/**
 * Element-wise `a < b`.
 *
 * Delegates to `this.cpuBackend.less` when `shouldExecuteOnCPU([a, b])` is
 * truthy. Otherwise runs a BinaryOpProgram built with `binaryop_gpu.LESS`,
 * writing into a 'bool' output array.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The CPU backend's result, or whatever `compileAndRun` returns.
 */
MathBackendWebGL.prototype.less = function (a, b) {
    var runOnCpu = this.shouldExecuteOnCPU([a, b]);
    if (runOnCpu) {
        return this.cpuBackend.less(a, b);
    }
    var lessProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LESS, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(lessProgram.outputShape, 'bool');
    return this.compileAndRun(lessProgram, [a, b], boolOutput);
};
/**
 * Runs a BinaryOpProgram built with `binaryop_gpu.LESS_EQUAL` over `a` and
 * `b`, writing into a 'bool' output array.
 *
 * @param a First operand; only its `shape` is read here.
 * @param b Second operand; only its `shape` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.lessEqual = function (a, b) {
    var lessEqualProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LESS_EQUAL, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(lessEqualProgram.outputShape, 'bool');
    var operands = [a, b];
    return this.compileAndRun(lessEqualProgram, operands, boolOutput);
};
/**
 * Element-wise `a > b`.
 *
 * Delegates to `this.cpuBackend.greater` when `shouldExecuteOnCPU([a, b])`
 * is truthy. Otherwise runs a BinaryOpProgram built with
 * `binaryop_gpu.GREATER`, writing into a 'bool' output array.
 *
 * @param a First operand.
 * @param b Second operand.
 * @returns The CPU backend's result, or whatever `compileAndRun` returns.
 */
MathBackendWebGL.prototype.greater = function (a, b) {
    var runOnCpu = this.shouldExecuteOnCPU([a, b]);
    if (runOnCpu) {
        return this.cpuBackend.greater(a, b);
    }
    var greaterProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.GREATER, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(greaterProgram.outputShape, 'bool');
    return this.compileAndRun(greaterProgram, [a, b], boolOutput);
};
/**
 * Runs a BinaryOpProgram built with `binaryop_gpu.GREATER_EQUAL` over `a`
 * and `b`, writing into a 'bool' output array.
 *
 * @param a First operand; only its `shape` is read here.
 * @param b Second operand; only its `shape` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.greaterEqual = function (a, b) {
    var greaterEqualProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.GREATER_EQUAL, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(greaterEqualProgram.outputShape, 'bool');
    var operands = [a, b];
    return this.compileAndRun(greaterEqualProgram, operands, boolOutput);
};
/**
 * Runs a UnaryOpProgram built with `unary_op.LOGICAL_NOT` over `x`.
 * No explicit output array is supplied to `compileAndRun`.
 *
 * @param x Input tensor; only its `shape` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.logicalNot = function (x) {
    var notProgram = new unaryop_gpu_1.UnaryOpProgram(x.shape, unary_op.LOGICAL_NOT);
    var programInputs = [x];
    return this.compileAndRun(notProgram, programInputs);
};
/**
 * Runs a BinaryOpProgram built with `binaryop_gpu.LOGICAL_AND` over `a` and
 * `b`, writing into a 'bool' output array.
 *
 * @param a First operand; only its `shape` is read here.
 * @param b Second operand; only its `shape` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.logicalAnd = function (a, b) {
    var andProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LOGICAL_AND, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(andProgram.outputShape, 'bool');
    var operands = [a, b];
    return this.compileAndRun(andProgram, operands, boolOutput);
};
/**
 * Runs a BinaryOpProgram built with `binaryop_gpu.LOGICAL_OR` over `a` and
 * `b`, writing into a 'bool' output array.
 *
 * @param a First operand; only its `shape` is read here.
 * @param b Second operand; only its `shape` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.logicalOr = function (a, b) {
    var orProgram = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LOGICAL_OR, a.shape, b.shape);
    var boolOutput = this.makeOutputArray(orProgram.outputShape, 'bool');
    var operands = [a, b];
    return this.compileAndRun(orProgram, operands, boolOutput);
};
/**
 * Runs a SelectProgram over `condition`, `a` and `b` (in that order).
 *
 * The program is built from the rank of `condition` and the shape and rank
 * of `a`. The output dtype is `upcastType(a.dtype, b.dtype)`.
 *
 * @param condition Condition tensor; only its `rank` is read here.
 * @param a First value tensor; `shape`, `rank` and `dtype` are read.
 * @param b Second value tensor; only its `dtype` is read here.
 * @returns Whatever `compileAndRun` returns for the program.
 */
MathBackendWebGL.prototype.select = function (condition, a, b) {
    var selectProgram = new select_gpu_1.SelectProgram(condition.rank, a.shape, a.rank);
    var resultDType = types_1.upcastType(a.dtype, b.dtype);
    var resultArray = this.makeOutputArray(selectProgram.outputShape, resultDType);
    var programInputs = [condition, a, b];
    return this.compileAndRun(selectProgram, programInputs, resultArray);
};
/**
 * Logs a warning that recommends `tf.whereAsync()`, reads the condition's
 * values with `dataSync()`, and passes them with the condition's shape to
 * `whereImpl`.
 *
 * @param condition Condition tensor; `dataSync()` and `shape` are used.
 * @returns Whatever `whereImpl` returns.
 */
MathBackendWebGL.prototype.where = function (condition) {
    var uiLockWarning = 'tf.where() in webgl locks the UI thread. ' +
        'Call tf.whereAsync() instead';
    log_1.warn(uiLockWarning);
    var conditionValues = condition.dataSync();
    return where_impl_1.whereImpl(condition.shape, conditionValues);
};
/**
 * Reads the values of `x` with `dataSync()` and delegates to `topkImpl`.
 *
 * @param x Input tensor; `dataSync()`, `shape` and `dtype` are used.
 * @param k Forwarded unchanged to `topkImpl`.
 * @param sorted Forwarded unchanged to `topkImpl`.
 * @returns Whatever `topkImpl` returns.
 */
MathBackendWebGL.prototype.topk = function (x, k, sorted) {
    var values = x.dataSync();
    var inputShape = x.shape;
    var inputDType = x.dtype;
    return topk_impl_1.topkImpl(values, inputShape, inputDType, k, sorted);
};
MathBackendWebGL.prototype.min = function (x, axes) {
axis_util.assertAxesAreInnerMostDims('min', axes, x.rank);
var _a = axis_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1];
var inSize =