UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

952 lines 91.7 kB
"use strict"; var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; var __generator = (this && this.__generator) || function (thisArg, body) { var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; function verb(n) { return function (v) { return step([n, v]); }; } function step(op) { if (f) throw new TypeError("Generator is already executing."); while (_) try { if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? 
y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; if (y = 0, t) op = [op[0] & 2, t.value]; switch (op[0]) { case 0: case 1: t = op; break; case 4: _.label++; return { value: op[1], done: false }; case 5: _.label++; y = op[1]; op = [0]; continue; case 7: op = _.ops.pop(); _.trys.pop(); continue; default: if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } if (t[2]) _.ops.pop(); _.trys.pop(); continue; } op = body.call(thisArg, _); } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; } }; Object.defineProperty(exports, "__esModule", { value: true }); var canvas_util_1 = require("../canvas_util"); var environment_1 = require("../environment"); var globals_1 = require("../globals"); var log_1 = require("../log"); var array_ops_util = require("../ops/array_ops_util"); var axis_util = require("../ops/axis_util"); var concat_util_1 = require("../ops/concat_util"); var gather_nd_util = require("../ops/gather_nd_util"); var reduce_util = require("../ops/reduce_util"); var scatter_nd_util = require("../ops/scatter_nd_util"); var segment_util = require("../ops/segment_util"); var slice_util_1 = require("../ops/slice_util"); var softmax_1 = require("../ops/softmax"); var tensor_ops_1 = require("../ops/tensor_ops"); var tensor_1 = require("../tensor"); var types_1 = require("../types"); var util = require("../util"); var util_1 = require("../util"); var backend_1 = require("./backend"); var backend_util = require("./backend_util"); var complex_util_1 = require("./complex_util"); var non_max_suppression_impl_1 = require("./non_max_suppression_impl"); var 
split_shared_1 = require("./split_shared"); var topk_impl_1 = require("./topk_impl"); var argminmax_gpu_1 = require("./webgl/argminmax_gpu"); var avg_pool_backprop_gpu_1 = require("./webgl/avg_pool_backprop_gpu"); var batchnorm_gpu_1 = require("./webgl/batchnorm_gpu"); var batchnorm_packed_gpu_1 = require("./webgl/batchnorm_packed_gpu"); var binaryop_complex_gpu = require("./webgl/binaryop_complex_gpu"); var binaryop_complex_gpu_1 = require("./webgl/binaryop_complex_gpu"); var binaryop_gpu = require("./webgl/binaryop_gpu"); var binaryop_gpu_1 = require("./webgl/binaryop_gpu"); var binaryop_packed_gpu_1 = require("./webgl/binaryop_packed_gpu"); var clip_gpu_1 = require("./webgl/clip_gpu"); var clip_packed_gpu_1 = require("./webgl/clip_packed_gpu"); var complex_abs_gpu_1 = require("./webgl/complex_abs_gpu"); var concat_gpu_1 = require("./webgl/concat_gpu"); var conv_backprop_gpu_1 = require("./webgl/conv_backprop_gpu"); var conv_backprop_gpu_depthwise_1 = require("./webgl/conv_backprop_gpu_depthwise"); var conv_gpu_1 = require("./webgl/conv_gpu"); var conv_gpu_depthwise_1 = require("./webgl/conv_gpu_depthwise"); var conv_packed_gpu_depthwise_1 = require("./webgl/conv_packed_gpu_depthwise"); var crop_and_resize_gpu_1 = require("./webgl/crop_and_resize_gpu"); var cumsum_gpu_1 = require("./webgl/cumsum_gpu"); var depth_to_space_gpu_1 = require("./webgl/depth_to_space_gpu"); var encode_float_gpu_1 = require("./webgl/encode_float_gpu"); var fft_gpu = require("./webgl/fft_gpu"); var fft_gpu_1 = require("./webgl/fft_gpu"); var from_pixels_gpu_1 = require("./webgl/from_pixels_gpu"); var gather_gpu_1 = require("./webgl/gather_gpu"); var gather_nd_gpu_1 = require("./webgl/gather_nd_gpu"); var gpgpu_context_1 = require("./webgl/gpgpu_context"); var gpgpu_math = require("./webgl/gpgpu_math"); var im2col_gpu_1 = require("./webgl/im2col_gpu"); var lrn_gpu_1 = require("./webgl/lrn_gpu"); var lrn_grad_gpu_1 = require("./webgl/lrn_grad_gpu"); var max_pool_backprop_gpu_1 = 
require("./webgl/max_pool_backprop_gpu"); var mulmat_gpu_1 = require("./webgl/mulmat_gpu"); var mulmat_packed_gpu_1 = require("./webgl/mulmat_packed_gpu"); var multinomial_gpu_1 = require("./webgl/multinomial_gpu"); var onehot_gpu_1 = require("./webgl/onehot_gpu"); var pack_gpu_1 = require("./webgl/pack_gpu"); var pad_gpu_1 = require("./webgl/pad_gpu"); var pad_packed_gpu_1 = require("./webgl/pad_packed_gpu"); var pool_gpu_1 = require("./webgl/pool_gpu"); var reduce_gpu_1 = require("./webgl/reduce_gpu"); var reshape_packed_gpu_1 = require("./webgl/reshape_packed_gpu"); var resize_bilinear_backprop_gpu_1 = require("./webgl/resize_bilinear_backprop_gpu"); var resize_bilinear_gpu_1 = require("./webgl/resize_bilinear_gpu"); var resize_nearest_neighbor_backprop_gpu_1 = require("./webgl/resize_nearest_neighbor_backprop_gpu"); var resize_nearest_neighbor_gpu_1 = require("./webgl/resize_nearest_neighbor_gpu"); var reverse_gpu_1 = require("./webgl/reverse_gpu"); var scatter_gpu_1 = require("./webgl/scatter_gpu"); var segment_gpu_1 = require("./webgl/segment_gpu"); var select_gpu_1 = require("./webgl/select_gpu"); var slice_gpu_1 = require("./webgl/slice_gpu"); var strided_slice_gpu_1 = require("./webgl/strided_slice_gpu"); var tex_util = require("./webgl/tex_util"); var tex_util_1 = require("./webgl/tex_util"); var texture_manager_1 = require("./webgl/texture_manager"); var tile_gpu_1 = require("./webgl/tile_gpu"); var transpose_gpu_1 = require("./webgl/transpose_gpu"); var unary_op = require("./webgl/unaryop_gpu"); var unaryop_gpu_1 = require("./webgl/unaryop_gpu"); var unary_packed_op = require("./webgl/unaryop_packed_gpu"); var unaryop_packed_gpu_1 = require("./webgl/unaryop_packed_gpu"); var unpack_gpu_1 = require("./webgl/unpack_gpu"); var webgl_util = require("./webgl/webgl_util"); var where_impl_1 = require("./where_impl"); function mapActivationToShaderProgram(activation, packed) { if (packed === void 0) { packed = false; } if (activation === 'linear') { if (packed) 
{ return unary_packed_op.LINEAR; } return unary_op.LINEAR; } else if (activation === 'relu') { if (packed) { return unary_packed_op.RELU; } return unary_op.RELU; } throw new Error("Activation " + activation + " has not been implemented for the WebGL backend."); } var CPU_HANDOFF_SIZE_THRESHOLD = 10; exports.MATMUL_SHARED_DIM_THRESHOLD = 1000; var MathBackendWebGL = (function () { function MathBackendWebGL(gpgpu, delayedStorage) { if (delayedStorage === void 0) { delayedStorage = true; } this.gpgpu = gpgpu; this.delayedStorage = delayedStorage; this.pendingRead = new WeakMap(); this.pendingDisposal = new WeakSet(); this.dataRefCount = new WeakMap(); this.lruDataGPU = []; this.numBytesInGPU = 0; this.uploadWaitMs = 0; this.downloadWaitMs = 0; this.binaryCache = {}; this.disposed = false; if (environment_1.ENV.get('WEBGL_VERSION') < 1) { throw new Error('WebGL is not supported on this device'); } if (gpgpu == null) { var gl = canvas_util_1.getWebGLContext(environment_1.ENV.get('WEBGL_VERSION')); this.gpgpu = new gpgpu_context_1.GPGPUContext(gl); this.canvas = gl.canvas; this.gpgpuCreatedLocally = true; } else { this.gpgpuCreatedLocally = false; this.canvas = gpgpu.gl.canvas; } this.textureManager = new texture_manager_1.TextureManager(this.gpgpu); } MathBackendWebGL.prototype.register = function (dataId, shape, dtype) { if (this.texData.has(dataId)) { throw new Error('Data buffer is already registered'); } this.texData.set(dataId, { shape: shape, dtype: dtype }); }; MathBackendWebGL.prototype.setDataMover = function (dataMover) { this.texData = new backend_1.DataStorage(dataMover); }; MathBackendWebGL.prototype.fromPixels = function (pixels, numChannels) { if (pixels == null) { throw new Error('pixels passed to tf.fromPixels() can not be null'); } var texShape = [pixels.height, pixels.width]; var outShape = [pixels.height, pixels.width, numChannels]; if (!(pixels instanceof HTMLVideoElement) && !(pixels instanceof HTMLImageElement) && !(pixels instanceof 
HTMLCanvasElement) && !(pixels instanceof ImageData)) { throw new Error('pixels passed to tf.fromPixels() must be either an ' + "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement or " + ("ImageData, but was " + pixels.constructor.name)); } if (pixels instanceof HTMLVideoElement) { if (this.fromPixels2DContext == null) { if (!environment_1.ENV.get('IS_BROWSER')) { throw new Error('Can\'t read pixels from HTMLImageElement outside the browser.'); } if (document.readyState !== 'complete') { throw new Error('The DOM is not ready yet. Please call tf.fromPixels() ' + 'once the DOM is ready. One way to do that is to add an event ' + 'listener for `DOMContentLoaded` on the document object'); } this.fromPixels2DContext = document.createElement('canvas').getContext('2d'); } this.fromPixels2DContext.canvas.width = pixels.width; this.fromPixels2DContext.canvas.height = pixels.height; this.fromPixels2DContext.drawImage(pixels, 0, 0, pixels.width, pixels.height); pixels = this.fromPixels2DContext.canvas; } var tempPixelHandle = this.makeTensorHandle(texShape, 'int32'); this.texData.get(tempPixelHandle.dataId).usage = tex_util_1.TextureUsage.PIXELS; this.gpgpu.uploadPixelDataToTexture(this.getTexture(tempPixelHandle.dataId), pixels); var program = new from_pixels_gpu_1.FromPixelsProgram(outShape); var res = this.compileAndRun(program, [tempPixelHandle]); this.disposeData(tempPixelHandle.dataId); return res; }; MathBackendWebGL.prototype.makeTensorHandle = function (shape, dtype) { var dataId = {}; this.register(dataId, shape, dtype); return { dataId: dataId, shape: shape, dtype: dtype }; }; MathBackendWebGL.prototype.write = function (dataId, values) { if (values == null) { throw new Error('MathBackendWebGL.write(): values can not be null'); } if (environment_1.ENV.get('DEBUG')) { for (var i = 0; i < values.length; i++) { var num = values[i]; if (!webgl_util.canBeRepresented(num)) { throw Error("The value " + num + " cannot be represented on this device."); } } } var texData = 
this.texData.get(dataId); var texture = texData.texture, texShape = texData.texShape, usage = texData.usage, dtype = texData.dtype, isPacked = texData.isPacked; if (dtype === 'complex64') { throw new Error("Cannot write to a complex64 dtype. " + "Please use tf.complex(real, imag)."); } if (texture != null) { this.releaseTexture(dataId, texture, texShape, usage, isPacked); texData.texture = null; texData.texShape = null; } texData.usage = tex_util_1.TextureUsage.UPLOAD; texData.values = values; if (!this.delayedStorage) { this.uploadToGPU(dataId); } }; MathBackendWebGL.prototype.readSync = function (dataId) { var texData = this.texData.get(dataId); var values = texData.values, dtype = texData.dtype, complexTensors = texData.complexTensors, slice = texData.slice, shape = texData.shape; if (slice != null) { var program = new unaryop_gpu_1.UnaryOpProgram(shape, unary_op.CLONE); var res = this.compileAndRun(program, [{ dataId: dataId, shape: shape, dtype: dtype }]); return this.readSync(res.dataId); } if (values != null) { return this.convertAndCacheOnCPU(dataId); } if (dtype === 'string') { return values; } var shouldTimeProgram = this.activeTimers != null; var start; if (shouldTimeProgram) { start = performance.now(); } var result; if (dtype === 'complex64') { var realValues = complexTensors.real.dataSync(); var imagValues = complexTensors.imag.dataSync(); result = complex_util_1.mergeRealAndImagArrays(realValues, imagValues); } else { result = this.getValuesFromTexture(dataId); } if (shouldTimeProgram) { this.downloadWaitMs += performance.now() - start; } return this.convertAndCacheOnCPU(dataId, result); }; MathBackendWebGL.prototype.read = function (dataId) { return __awaiter(this, void 0, void 0, function () { var _a, _b, subscribers_1, texData, texture, values, texShape, isPacked, shape, slice, dtype, program, res, width, height, bufferOrTexture, vals, size, batch, rows, cols, dTypeVals, subscribers; return __generator(this, function (_c) { switch (_c.label) { 
case 0: if (this.pendingRead.has(dataId)) { subscribers_1 = this.pendingRead.get(dataId); return [2, new Promise(function (resolve) { return subscribers_1.push(resolve); })]; } texData = this.texData.get(dataId); texture = texData.texture, values = texData.values, texShape = texData.texShape, isPacked = texData.isPacked, shape = texData.shape, slice = texData.slice, dtype = texData.dtype; if (slice != null) { program = new unaryop_gpu_1.UnaryOpProgram(shape, unary_op.CLONE); res = this.compileAndRun(program, [{ dataId: dataId, shape: shape, dtype: dtype }]); return [2, this.read(res.dataId)]; } if (values != null) { return [2, this.convertAndCacheOnCPU(dataId)]; } this.pendingRead.set(dataId, []); if (!environment_1.ENV.get('WEBGL_DOWNLOAD_FLOAT_ENABLED') && environment_1.ENV.get('WEBGL_VERSION') === 2) { throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " + "WEBGL_VERSION=2 not yet supported."); } width = texShape[1]; height = texShape[0]; if (isPacked) { _a = tex_util.getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]), width = _a[0], height = _a[1]; } bufferOrTexture = this.gpgpu.maybeCreateBufferFromTexture(texture, height, width); return [4, this.gpgpu.createAndWaitForFence()]; case 1: _c.sent(); if (bufferOrTexture instanceof WebGLTexture) { vals = this.getValuesFromTexture(dataId); } else { size = util.sizeFromShape(shape); if (isPacked) { batch = webgl_util.getBatchDim(shape); rows = 1, cols = 1; if (shape.length) { _b = webgl_util.getRowsCols(shape), rows = _b[0], cols = _b[1]; } vals = this.gpgpu .downloadPackedMatrixFromBuffer(bufferOrTexture, batch, rows, cols, texShape[0], texShape[1]) .subarray(0, size); } else { vals = this.gpgpu .downloadFloat32MatrixFromBuffer(bufferOrTexture, texShape[0], texShape[1]) .subarray(0, size); } } dTypeVals = this.convertAndCacheOnCPU(dataId, vals); subscribers = this.pendingRead.get(dataId); this.pendingRead.delete(dataId); subscribers.forEach(function (resolve) { return 
resolve(dTypeVals); }); if (this.pendingDisposal.has(dataId)) { this.pendingDisposal.delete(dataId); this.disposeData(dataId); } return [2, dTypeVals]; } }); }); }; MathBackendWebGL.prototype.getValuesFromTexture = function (dataId) { var _a; var _b = this.texData.get(dataId), shape = _b.shape, dtype = _b.dtype, texture = _b.texture, texShape = _b.texShape; var size = util.sizeFromShape(shape); if (environment_1.ENV.get('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { if (this.texData.get(dataId).isPacked) { var batch = webgl_util.getBatchDim(shape); var rows = 1, cols = 1; if (shape.length) { _a = webgl_util.getRowsCols(shape), rows = _a[0], cols = _a[1]; } return this.gpgpu .downloadMatrixFromPackedTexture(texture, batch, rows, cols, texShape[0], texShape[1]) .subarray(0, size); } else { return this.gpgpu .downloadFloat32MatrixFromOutputTexture(texture, texShape[0], texShape[1]) .subarray(0, size); } } var tmpTarget = this.makeTensorHandle(shape, 'float32'); tmpTarget.size = util_1.sizeFromShape(shape); this.texData.get(tmpTarget.dataId).usage = tex_util_1.TextureUsage.DOWNLOAD; var program = new encode_float_gpu_1.EncodeFloatProgram(shape); var pageToCpu = false; this.compileAndRun(program, [{ shape: shape, dtype: dtype, dataId: dataId }], tmpTarget, null, pageToCpu); var tmpData = this.texData.get(tmpTarget.dataId); var vals = this.gpgpu .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture, tmpData.texShape[0], tmpData.texShape[1]) .subarray(0, size); this.disposeData(tmpTarget.dataId); return vals; }; MathBackendWebGL.prototype.time = function (f) { return __awaiter(this, void 0, void 0, function () { var oldActiveTimers, newActiveTimers, outerMostTime, flattenedActiveTimerQueries, flattenedActiveTimerNames, kernelMs, res; return __generator(this, function (_a) { switch (_a.label) { case 0: oldActiveTimers = this.activeTimers; newActiveTimers = []; outerMostTime = false; if (this.programTimersStack == null) { this.programTimersStack = newActiveTimers; 
outerMostTime = true; } else { this.activeTimers.push(newActiveTimers); } this.activeTimers = newActiveTimers; f(); flattenedActiveTimerQueries = util.flatten(this.activeTimers.map(function (d) { return d.query; })) .filter(function (d) { return d != null; }); flattenedActiveTimerNames = util.flatten(this.activeTimers.map(function (d) { return d.name; })) .filter(function (d) { return d != null; }); this.activeTimers = oldActiveTimers; if (outerMostTime) { this.programTimersStack = null; } return [4, Promise.all(flattenedActiveTimerQueries)]; case 1: kernelMs = _a.sent(); res = { uploadWaitMs: this.uploadWaitMs, downloadWaitMs: this.downloadWaitMs, kernelMs: util.sum(kernelMs), getExtraProfileInfo: function () { return kernelMs.map(function (d, i) { return ({ name: flattenedActiveTimerNames[i], ms: d }); }) .map(function (d) { return d.name + ": " + d.ms; }) .join(', '); }, wallMs: null }; this.uploadWaitMs = 0; this.downloadWaitMs = 0; return [2, res]; } }); }); }; MathBackendWebGL.prototype.memory = function () { return { unreliable: false, numBytesInGPU: this.numBytesInGPU }; }; MathBackendWebGL.prototype.startTimer = function () { if (environment_1.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { return this.gpgpu.beginQuery(); } return { startMs: performance.now(), endMs: null }; }; MathBackendWebGL.prototype.endTimer = function (query) { if (environment_1.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { this.gpgpu.endQuery(); return query; } query.endMs = performance.now(); return query; }; MathBackendWebGL.prototype.getQueryTime = function (query) { return __awaiter(this, void 0, void 0, function () { var timerQuery; return __generator(this, function (_a) { if (environment_1.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { return [2, this.gpgpu.waitForQueryAndGetTime(query)]; } timerQuery = query; return [2, timerQuery.endMs - timerQuery.startMs]; }); }); }; MathBackendWebGL.prototype.disposeData = function (dataId) 
{ if (this.pendingDisposal.has(dataId)) { return; } if (this.pendingRead.has(dataId)) { this.pendingDisposal.add(dataId); return; } if (this.texData.has(dataId)) { var _a = this.texData.get(dataId), texture = _a.texture, texShape = _a.texShape, usage = _a.usage, complexTensors = _a.complexTensors, isPacked = _a.isPacked, slice = _a.slice; if (texture != null) { var key = slice && slice.origDataId || dataId; var refCount = this.dataRefCount.get(key); if (refCount > 1) { this.dataRefCount.set(key, refCount - 1); } else { this.dataRefCount.delete(key); this.releaseTexture(dataId, texture, texShape, usage, isPacked); this.texData.delete(dataId); } } if (complexTensors != null) { complexTensors.real.dispose(); complexTensors.imag.dispose(); } } }; MathBackendWebGL.prototype.getTexture = function (dataId) { this.uploadToGPU(dataId); return this.texData.get(dataId).texture; }; MathBackendWebGL.prototype.getCPUBackend = function () { if (!environment_1.ENV.get('WEBGL_CPU_FORWARD')) { return null; } if (this.cpuBackend == null) { this.cpuBackend = environment_1.ENV.findBackend('cpu'); } return this.cpuBackend; }; MathBackendWebGL.prototype.shouldExecuteOnCPU = function (inputs, sizeThreshold) { var _this = this; if (sizeThreshold === void 0) { sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD; } return this.getCPUBackend() != null && inputs.every(function (input) { return _this.texData.get(input.dataId).texture == null && input.size < sizeThreshold; }); }; MathBackendWebGL.prototype.getGPGPUContext = function () { return this.gpgpu; }; MathBackendWebGL.prototype.getCanvas = function () { return this.canvas; }; MathBackendWebGL.prototype.complex = function (real, imag) { var result = this.makeOutputArray(real.shape, 'complex64'); var resultData = this.texData.get(result.dataId); resultData.complexTensors = { real: environment_1.ENV.engine.keep(real.clone()), imag: environment_1.ENV.engine.keep(imag.clone()) }; return result; }; MathBackendWebGL.prototype.real = function (input) { 
var resultData = this.texData.get(input.dataId); return resultData.complexTensors.real.clone(); }; MathBackendWebGL.prototype.imag = function (input) { var resultData = this.texData.get(input.dataId); return resultData.complexTensors.imag.clone(); }; MathBackendWebGL.prototype.slice = function (x, begin, size) { if (this.shouldExecuteOnCPU([x])) { return this.cpuBackend.slice(x, begin, size); } var isPacked = this.texData.get(x.dataId).isPacked; var isContinous = slice_util_1.isSliceContinous(x.shape, begin, size); if (isPacked || !isContinous) { var program = new slice_gpu_1.SliceProgram(size); var customSetup = program.getCustomSetupFunc(begin); return this.compileAndRun(program, [x], null, customSetup); } this.uploadToGPU(x.dataId); return this.shallowSlice(x, begin, size); }; MathBackendWebGL.prototype.shallowSlice = function (x, begin, size) { var xTexData = this.texData.get(x.dataId); var t = tensor_1.Tensor.make(size, {}, xTexData.dtype); var newTexData = this.texData.get(t.dataId); Object.assign(newTexData, xTexData); newTexData.shape = size; var flatOffset = slice_util_1.computeFlatOffset(begin, x.strides); if (xTexData.slice) { flatOffset += xTexData.slice.flatOffset; } newTexData.slice = { flatOffset: flatOffset, origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId }; var refCount = this.dataRefCount.get(newTexData.slice.origDataId) || 1; this.dataRefCount.set(newTexData.slice.origDataId, refCount + 1); return t; }; MathBackendWebGL.prototype.stridedSlice = function (x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) { if (this.shouldExecuteOnCPU([x])) { return this.cpuBackend.stridedSlice(x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); } var _a = slice_util_1.getStridedSlicedInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask), beginIndex = _a[0], size = _a[1], shrinkAxis = _a[2]; var shape = size.filter(function (v, 
index) { return shrinkAxis.indexOf(index) === -1; }); if (shape.some(function (axis) { return axis === 0; })) { return tensor_ops_1.tensor([], shape); } var program = new strided_slice_gpu_1.StridedSliceProgram(beginIndex, strides, size, shrinkAxis); return this.compileAndRun(program, [x]); }; MathBackendWebGL.prototype.reverse = function (x, axis) { var program = new reverse_gpu_1.ReverseProgram(x.shape, axis); return this.compileAndRun(program, [x]); }; MathBackendWebGL.prototype.concat = function (tensors, axis) { if (this.shouldExecuteOnCPU(tensors)) { return this.cpuBackend.concat(tensors, axis); } if (tensors.length === 1) { return tensors[0]; } if (tensors.length > environment_1.ENV.get('WEBGL_MAX_TEXTURES_IN_SHADER')) { var midIndex = Math.floor(tensors.length / 2); var leftSide = this.concat(tensors.slice(0, midIndex), axis); var rightSide = this.concat(tensors.slice(midIndex), axis); return this.concat([leftSide, rightSide], axis); } var outShape = concat_util_1.computeOutShape(tensors.map(function (t) { return t.shape; }), axis); var tensors2D = tensors.map(function (t) { return t.as2D(-1, util_1.sizeFromShape(t.shape.slice(axis))); }); var program = new concat_gpu_1.ConcatProgram(tensors2D.map(function (t) { return t.shape; })); var res = this.compileAndRun(program, tensors2D); return res.reshape(outShape); }; MathBackendWebGL.prototype.neg = function (x) { var program = new unaryop_gpu_1.UnaryOpProgram(x.shape, unary_op.NEG); return this.compileAndRun(program, [x]); }; MathBackendWebGL.prototype.batchMatMul = function (a, b, transposeA, transposeB) { var outerShapeA = transposeA ? a.shape[2] : a.shape[1]; var outerShapeB = transposeB ? b.shape[1] : b.shape[2]; var sharedDim = transposeA ? 
a.shape[1] : a.shape[2]; var _a = a.shape, batch = _a[0]; if ((outerShapeA === 1 || outerShapeB === 1) && sharedDim > exports.MATMUL_SHARED_DIM_THRESHOLD) { if (transposeA) { a = a.transpose([0, 2, 1]); } if (transposeB) { b = b.transpose([0, 2, 1]); } var a3D = outerShapeB === 1 ? a : a.as3D(batch, sharedDim, 1); var axis = outerShapeB === 1 ? 2 : 1; var b3D = outerShapeB === 1 ? b.as3D(batch, 1, sharedDim) : b; return this.multiply(a3D, b3D).sum(axis, true); } var dtype = types_1.upcastType(a.dtype, b.dtype); if (batch === 1) { var aSqueezed = a.as2D(a.shape[1], a.shape[2]); var bSqueezed = b.as2D(b.shape[1], b.shape[2]); var program = new mulmat_packed_gpu_1.MatMulPackedProgram(aSqueezed.shape, bSqueezed.shape, [outerShapeA, outerShapeB], transposeA, transposeB); var output = this.makePackedTensor(program.outputShape, dtype); var result = this.compileAndRun(program, [aSqueezed, bSqueezed], output); return result.reshape([1, result.shape[0], result.shape[1]]); } else { var program = new mulmat_gpu_1.MatMulProgram(a.shape, b.shape, transposeA, transposeB); var output = this.makeOutputArray(program.outputShape, dtype); return this.compileAndRun(program, [a, b], output); } }; MathBackendWebGL.prototype.fusedBatchMatMul = function (a, b, transposeA, transposeB, bias, activation) { var outerShapeA = transposeA ? a.shape[2] : a.shape[1]; var outerShapeB = transposeB ? b.shape[1] : b.shape[2]; var _a = a.shape, batch = _a[0]; var dtype = types_1.upcastType(a.dtype, b.dtype); if (batch === 1) { var aSqueezed = a.as2D(a.shape[1], a.shape[2]); var bSqueezed = b.as2D(b.shape[1], b.shape[2]); var program = new mulmat_packed_gpu_1.MatMulPackedProgram(aSqueezed.shape, bSqueezed.shape, [outerShapeA, outerShapeB], transposeA, transposeB, !!bias, activation ? 
mapActivationToShaderProgram(activation, true) : null); var output = this.makePackedTensor(program.outputShape, dtype); var inputs = [aSqueezed, bSqueezed]; if (bias) { inputs.push(bias); } var result = this.compileAndRun(program, inputs, output); return result.reshape([1, result.shape[0], result.shape[1]]); } else { var program = new mulmat_gpu_1.MatMulProgram(a.shape, b.shape, transposeA, transposeB, !!bias, activation ? mapActivationToShaderProgram(activation) : null); var inputs = [a, b]; if (bias) { inputs.push(bias); } var output = this.makeOutputArray(program.outputShape, dtype); return this.compileAndRun(program, inputs, output); } }; MathBackendWebGL.prototype.multiply = function (a, b) { if (a.dtype === 'complex64') { var aData = this.texData.get(a.dataId); var bData = this.texData.get(b.dataId); var realProgram = new binaryop_complex_gpu_1.BinaryOpComplexProgram(binaryop_complex_gpu.COMPLEX_MULTIPLY.REAL, a.shape, b.shape); var imagProgram = new binaryop_complex_gpu_1.BinaryOpComplexProgram(binaryop_complex_gpu.COMPLEX_MULTIPLY.IMAG, a.shape, b.shape); var inputs = [ this.makeComplexComponentTensorHandle(a, aData.complexTensors.real), this.makeComplexComponentTensorHandle(a, aData.complexTensors.imag), this.makeComplexComponentTensorHandle(b, bData.complexTensors.real), this.makeComplexComponentTensorHandle(b, bData.complexTensors.imag) ]; var real = this.compileAndRun(realProgram, inputs); var imag = this.compileAndRun(imagProgram, inputs); var complex = this.complex(real, imag); real.dispose(); imag.dispose(); return complex; } if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.multiply(a, b); } if (environment_1.ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_gpu.MUL, a.dtype); } var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, a.dtype); return this.compileAndRun(program, [a, b], output); }; 
/**
 * Batch normalization. Program inputs are [x, mean, variance, offset?, scale?].
 * When present, offset always precedes scale. This matches the
 * (offsetShape, scaleShape) argument order of both program constructors.
 * WEBGL_PACK_BATCHNORMALIZATION selects the packed variant.
 */
MathBackendWebGL.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) {
    var hasOffset = offset != null;
    var hasScale = scale != null;
    var programInputs = [x, mean, variance];
    if (hasOffset) {
        programInputs.push(offset);
    }
    if (hasScale) {
        programInputs.push(scale);
    }
    var offsetShape = hasOffset ? offset.shape : null;
    var scaleShape = hasScale ? scale.shape : null;
    var usePacked = environment_1.ENV.get('WEBGL_PACK_BATCHNORMALIZATION');
    var program = usePacked ?
        new batchnorm_packed_gpu_1.BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) :
        new batchnorm_gpu_1.BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
    return this.compileAndRun(program, programInputs);
};
/** Local response normalization of a 4D tensor, computed by LRNProgram. */
MathBackendWebGL.prototype.localResponseNormalization4D = function (x, radius, bias, alpha, beta) {
    return this.compileAndRun(new lrn_gpu_1.LRNProgram(x.shape, radius, bias, alpha, beta), [x]);
};
/** Gradient of LRN; program inputs are ordered [inputImage, outputImage, dy]. */
MathBackendWebGL.prototype.LRNGrad = function (dy, inputImage, outputImage, depthRadius, bias, alpha, beta) {
    var gradProgram = new lrn_grad_gpu_1.LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta);
    return this.compileAndRun(gradProgram, [inputImage, outputImage, dy]);
};
/** Repeats x `reps` times along each dimension. */
MathBackendWebGL.prototype.tile = function (x, reps) {
    return this.compileAndRun(new tile_gpu_1.TileProgram(x.shape, reps), [x]);
};
/** Constant padding; WEBGL_PACK_ARRAY_OPERATIONS selects the packed program. */
MathBackendWebGL.prototype.pad = function (x, paddings, constantValue) {
    var padProgram;
    if (environment_1.ENV.get('WEBGL_PACK_ARRAY_OPERATIONS')) {
        padProgram = new pad_packed_gpu_1.PadPackedProgram(x.shape, paddings, constantValue);
    }
    else {
        padProgram = new pad_gpu_1.PadProgram(x.shape, paddings, constantValue);
    }
    return this.compileAndRun(padProgram, [x]);
};
/** Permutes the dimensions of x according to `perm`. */
MathBackendWebGL.prototype.transpose = function (x, perm) {
    return this.compileAndRun(new transpose_gpu_1.TransposeProgram(x.shape, perm), [x]);
};
/** Gathers slices of x along `axis` at the positions listed in `indices`. */
MathBackendWebGL.prototype.gather = function (x, indices, axis) {
    var gatherProgram = new gather_gpu_1.GatherProgram(x.shape, indices.size, axis);
    return this.compileAndRun(gatherProgram, [x, indices]);
};
/**
 * batchToSpaceND built from existing tensor ops: reshape -> transpose ->
 * reshape -> slice (crop). Only rank <= 4 is supported.
 */
MathBackendWebGL.prototype.batchToSpaceND = function (x, blockShape, crops) {
    util.assert(x.rank <= 4, 'batchToSpaceND for rank > 4 with a WebGL backend not implemented yet');
    var blockSize = blockShape.reduce(function (acc, dim) { return acc * dim; });
    var numBlockDims = blockShape.length;
    var splitShape = array_ops_util.getReshaped(x.shape, blockShape, blockSize);
    var permutation = array_ops_util.getPermuted(splitShape.length, numBlockDims);
    var mergedShape = array_ops_util.getReshapedPermuted(x.shape, blockShape, blockSize);
    var cropBegin = array_ops_util.getSliceBeginCoords(crops, numBlockDims);
    var cropSize = array_ops_util.getSliceSize(mergedShape, crops, numBlockDims);
    var split = x.reshape(splitShape);
    var permuted = split.transpose(permutation);
    var merged = permuted.reshape(mergedShape);
    return merged.slice(cropBegin, cropSize);
};
/**
 * spaceToBatchND built from existing tensor ops: pad -> reshape ->
 * transpose -> reshape. Only rank <= 4 is supported.
 */
MathBackendWebGL.prototype.spaceToBatchND = function (x, blockShape, paddings) {
    util.assert(x.rank <= 4, 'spaceToBatchND for rank > 4 with a WebGL backend not implemented yet');
    var blockSize = blockShape.reduce(function (acc, dim) { return acc * dim; });
    // [0, 0] for the leading dimension, then the caller's paddings (expected to
    // cover the block dimensions), then [0, 0] for each remaining dimension.
    var fullPaddings = [[0, 0]].concat(paddings);
    var dim = 1 + blockShape.length;
    while (dim < x.shape.length) {
        fullPaddings.push([0, 0]);
        ++dim;
    }
    var paddedX = x.pad(fullPaddings);
    var splitShape = array_ops_util.getReshaped(paddedX.shape, blockShape, blockSize, false);
    var permutation = array_ops_util.getPermuted(splitShape.length, blockShape.length, false);
    var batchedShape = array_ops_util.getReshapedPermuted(paddedX.shape, blockShape, blockSize, false);
    var split = paddedX.reshape(splitShape);
    var permuted = split.transpose(permutation);
    return permuted.reshape(batchedShape);
};
/**
 * Reduces each row of a 2D tensor with `reduceType`. Passes of ReduceProgram
 * are repeated until the output has a single column.
 */
MathBackendWebGL.prototype.reduce = function (x, reduceType, dtype) {
    var current = x;
    for (;;) {
        var rowCount = current.shape[0];
        var colCount = current.shape[1];
        var passInfo = {
            windowSize: reduce_util.computeOptimalWindowSize(colCount),
            inSize: colCount,
            batchSize: rowCount
        };
        var passProgram = new reduce_gpu_1.ReduceProgram(passInfo, reduceType);
        var reduced = this.makeOutputArray([passProgram.outputShape[0], passProgram.outputShape[1]], dtype);
        this.compileAndRun(passProgram, [current], reduced);
        if (reduced.shape[1] === 1) {
            return reduced;
        }
        current = reduced;
    }
};
/**
 * Row-wise arg-reduction of 2D x into int32 indices. Each pass after the
 * first also takes the previous pass's indices as input, and its problem size
 * is read from those indices. Passes repeat until one column remains.
 * `bestIndicesA` optionally seeds the first pass.
 */
MathBackendWebGL.prototype.argReduce = function (x, reduceType, bestIndicesA) {
    var indices = bestIndicesA === void 0 ? null : bestIndicesA;
    for (;;) {
        var isFirstPass = indices == null;
        var sizeSource = isFirstPass ? x : indices;
        var rowCount = sizeSource.shape[0];
        var colCount = sizeSource.shape[1];
        var passInfo = {
            windowSize: reduce_util.computeOptimalWindowSize(colCount),
            inSize: colCount,
            batchSize: rowCount
        };
        var passProgram = new argminmax_gpu_1.ArgMinMaxProgram(passInfo, reduceType, isFirstPass);
        var bestIndices = this.makeOutputArray([passProgram.outputShape[0], passProgram.outputShape[1]], 'int32');
        this.compileAndRun(passProgram, isFirstPass ? [x] : [x, indices], bestIndices);
        if (bestIndices.shape[1] === 1) {
            return bestIndices;
        }
        indices = bestIndices;
    }
};
/** Sums x over `axes`, which must be the innermost dimensions. */
MathBackendWebGL.prototype.sum = function (x, axes) {
    axis_util.assertAxesAreInnerMostDims('sum', axes, x.rank);
    var shapes = axis_util.computeOutAndReduceShapes(x.shape, axes);
    var reducedLength = util.sizeFromShape(shapes[1]);
    var rows = x.as2D(-1, reducedLength);
    var summed = this.reduce(rows, 'sum', types_1.sumOutType(x.dtype));
    return summed.reshape(shapes[0]);
};
/** Multiplies the elements of x over `axes`; the output dtype follows sumOutType. */
MathBackendWebGL.prototype.prod = function (x, axes) {
    var shapes = axis_util.computeOutAndReduceShapes(x.shape, axes);
    var reducedLength = util.sizeFromShape(shapes[1]);
    var rows = x.as2D(-1, reducedLength);
    var product = this.reduce(rows, 'prod', types_1.sumOutType(x.dtype));
    return product.reshape(shapes[0]);
};
/**
 * Sums slices of x along axis 0 into `numSegments` buckets chosen by
 * segmentIds. When needed, axis 0 is first moved innermost by a transpose,
 * and the transpose is undone on the result.
 */
MathBackendWebGL.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) {
    var permutation = axis_util.getAxesPermutation([0], x.rank);
    var isPermuted = permutation != null;
    var source = isPermuted ? x.transpose(permutation) : x;
    var segAxis = isPermuted ? axis_util.getInnerMostAxes(1, x.rank)[0] : 0;
    var outShape = segment_util.computeOutShape(source.shape, segAxis, numSegments);
    var rowLength = util.sizeFromShape([source.shape[segAxis]]);
    var rows = source.as2D(-1, rowLength);
    var segmented = this.segOpCompute(rows, 'unsortedSegmentSum', segmentIds, types_1.sumOutType(x.dtype), numSegments)
        .reshape(outShape);
    return isPermuted ?
        segmented.transpose(axis_util.getUndoAxesPermutation(permutation)) :
        segmented;
};
/**
 * Runs SegmentOpProgram passes over 2D x until the output has exactly
 * `numSegments` columns. After each pass, the segment ids for the next pass
 * are 0..numSegments-1 tiled inSize / windowSize times.
 */
MathBackendWebGL.prototype.segOpCompute = function (x, segOpType, segmentIds, dtype, numSegments) {
    var current = x;
    var currentIds = segmentIds;
    for (;;) {
        var rowCount = current.shape[0];
        var colCount = current.shape[1];
        var windowSize = segment_util.segOpComputeOptimalWindowSize(colCount, numSegments);
        var passInfo = {
            windowSize: windowSize,
            inSize: colCount,
            batchSize: rowCount,
            numSegments: numSegments
        };
        var passProgram = new segment_gpu_1.SegmentOpProgram(passInfo, segOpType);
        var partial = this.makeOutputArray([passProgram.outputShape[0], passProgram.outputShape[1]], dtype);
        this.compileAndRun(passProgram, [current, currentIds], partial);
        if (partial.shape[1] === numSegments) {
            return partial;
        }
        currentIds = tensor_ops_1.range(0, numSegments).tile([colCount / windowSize]);
        current = partial;
    }
};
MathBackendWebGL.prototype.argMin = function (x, axis) { var axes = [axis]; axis_util.assertAxesAreInnerMostDims('argMin', axes, x.rank); var _a = axis_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; var inSize = util.sizeFromShape(reduceShape); var a2D = x.as2D(-1, inSize); return this.argReduce(a2D, 'min').reshape(outShape); }; MathBackendWebGL.prototype.argMax = function (x, axis) { var axes = [axis]; axis_util.assertAxesAreInnerMostDims('argMax', axes, x.rank); var _a = axis_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; var inSize = util.sizeFromShape(reduceShape); var a2D = x.as2D(-1, inSize); return this.argReduce(a2D, 'max').reshape(outShape); }; MathBackendWebGL.prototype.cumsum = function (x, axis, exclusive, reverse) { if (axis !== x.rank - 1) { throw new Error("WebGL cumsum shader expects an inner-most axis=" + (x.rank - 1) + " " + ("but got axis=" + axis)); } var program = new cumsum_gpu_1.CumSumProgram(x.shape, exclusive, reverse); return this.compileAndRun(program, [x]); }; MathBackendWebGL.prototype.equal = function (a, b) { var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.EQUAL, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.notEqual = function (a, b) { var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.NOT_EQUAL, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.less = function (a, b) { if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.less(a, b); } var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LESS, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.lessEqual = function (a, b) { 
var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LESS_EQUAL, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.greater = function (a, b) { if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.greater(a, b); } var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.GREATER, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.greaterEqual = function (a, b) { var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.GREATER_EQUAL, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.logicalNot = function (x) { var program = new unaryop_gpu_1.UnaryOpProgram(x.shape, unary_op.LOGICAL_NOT); return this.compileAndRun(program, [x]); }; MathBackendWebGL.prototype.logicalAnd = function (a, b) { var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LOGICAL_AND, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.logicalOr = function (a, b) { var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.LOGICAL_OR, a.shape, b.shape); var output = this.makeOutputArray(program.outputShape, 'bool'); return this.compileAndRun(program, [a, b], output); }; MathBackendWebGL.prototype.select = function (condition, a, b) { var program = new select_gpu_1.SelectProgram(condition.rank, a.shape, a.rank); var output = this.makeOutputArray(program.outputShape, types_1.upcastType(a.dtype, b.dtype)); return this.compileAndRun(program, [condition, a, b], output); }; MathBackendWebGL.prototype.where = function (condition) { log_1.warn('tf.where() in webgl locks the UI thread. 
' + 'Call tf.whereAsync() instead'); var condVals = condition.dataSync(); return where_impl_1.whereImpl(condition.shape, condVals); }; MathBackendWebGL.prototype.topk = function (x, k, sorted) { var xVals = x.dataSync(); return topk_impl_1.topkImpl(xVals, x.shape, x.dtype, k, sorted); }; MathBackendWebGL.prototype.min = function (x, axes) { axis_util.assertAxesAreInnerMostDims('min', axes, x.rank); var _a = axis_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; var inSize =