UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

437 lines 58.1 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import './flags_wasm';
import { DataStorage, deprecationWarn, engine, env, KernelBackend, util } from '@tensorflow/tfjs-core';
import * as wasmFactoryThreadedSimd_import from '../wasm-out/tfjs-backend-wasm-threaded-simd.js';
// @ts-ignore
import { wasmWorkerContents } from '../wasm-out/tfjs-backend-wasm-threaded-simd.worker.js';
import * as wasmFactory_import from '../wasm-out/tfjs-backend-wasm.js';
// This workaround is required for importing in Node.js without using
// the node bundle (for testing). This would not be necessary if we
// flipped esModuleInterop to true, but we likely can't do that since
// google3 does not use it.
const wasmFactoryThreadedSimd = (wasmFactoryThreadedSimd_import.default || wasmFactoryThreadedSimd_import);
const wasmFactory = (wasmFactory_import.default || wasmFactory_import);
/**
 * TensorFlow.js kernel backend whose numeric tensor data is stored on the heap
 * of the Emscripten module produced by `init()`.
 */
export class BackendWasm extends KernelBackend {
    /**
     * @param wasm The module resolved by `init()` (under the `wasm` key), which
     *     carries the `tfjs` namespace of cwrapped native functions.
     */
    constructor(wasm) {
        super();
        this.wasm = wasm;
        // 0 is reserved for null data ids.
        this.dataIdNextNumber = 1;
        // `threadsCount` stays -1 unless `setThreadsCount()` was called before the
        // backend was created; the count the native side settled on is read back
        // so `getThreadsCount()` can report it.
        this.wasm.tfjs.initWithThreadsCount(threadsCount);
        actualThreadsCount = this.wasm.tfjs.getThreadsCount();
        this.dataIdMap = new DataStorage(this, engine());
    }
    /**
     * Stores `values` as a new tensor with a refCount of 1.
     * @returns The newly created data id.
     */
    write(values, shape, dtype) {
        const dataId = { id: this.dataIdNextNumber++ };
        this.move(dataId, values, shape, dtype, 1);
        return dataId;
    }
    /** Returns the number of data ids currently tracked by this backend. */
    numDataIds() {
        return this.dataIdMap.numDataIds();
    }
    /**
     * Measures the wall-clock time of the synchronous call `f()`.
     * @returns An object whose `kernelMs` is the elapsed time.
     */
    async time(f) {
        const start = util.now();
        f();
        const kernelMs = util.now() - start;
        return { kernelMs };
    }
    /**
     * Records `values` under `dataId`, assigning it a fresh internal id.
     *
     * String data is kept as-is on the JS side (`memoryOffset` is null). Any
     * other dtype gets a newly `_malloc`ed region of the WASM heap, which is
     * registered with the native side and filled from `values` when provided.
     */
    move(dataId, values, shape, dtype, refCount) {
        const id = this.dataIdNextNumber++;
        if (dtype === 'string') {
            const stringBytes = values;
            this.dataIdMap.set(dataId, { id, stringBytes, shape, dtype, memoryOffset: null, refCount });
            return;
        }
        const size = util.sizeFromShape(shape);
        const numBytes = size * util.bytesPerElement(dtype);
        // `>>> 0` is needed for above 2GB allocations because wasm._malloc returns
        // a signed int32 instead of an unsigned int32.
        // https://v8.dev/blog/4gb-wasm-memory
        const memoryOffset = this.wasm._malloc(numBytes) >>> 0;
        this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype, refCount });
        this.wasm.tfjs.registerTensor(id, size, memoryOffset);
        if (values != null) {
            this.wasm.HEAPU8.set(new Uint8Array(values.buffer, values.byteOffset, numBytes), memoryOffset);
        }
    }
    /** Asynchronous wrapper around `readSync` that reads the whole tensor. */
    async read(dataId) {
        return this.readSync(dataId);
    }
    /**
     * Synchronously reads elements [start, end) of a tensor.
     *
     * For numeric dtypes the result is a copy of the heap bytes (via
     * `HEAPU8.slice`) wrapped in the typed array matching `dtype`.
     * @param dataId The tensor to read.
     * @param start Index of the first element to read; defaults to 0.
     * @param end Index one past the last element to read; defaults to the
     *     tensor's element count.
     */
    readSync(dataId, start, end) {
        const { memoryOffset, dtype, shape, stringBytes } = this.dataIdMap.get(dataId);
        if (dtype === 'string') {
            // Slice all elements.
            if ((start == null || start === 0) &&
                (end == null || end >= stringBytes.length)) {
                return stringBytes;
            }
            return stringBytes.slice(start, end);
        }
        // Null checks (not `||`) so an explicit `end` of 0 yields an empty
        // slice rather than the whole tensor.
        const begin = start == null ? 0 : start;
        const stop = end == null ? util.sizeFromShape(shape) : end;
        const bytesPerElement = util.bytesPerElement(dtype);
        const bytes = this.wasm.HEAPU8.slice(memoryOffset + begin * bytesPerElement, memoryOffset + stop * bytesPerElement);
        return typedArrayFromBuffer(bytes.buffer, dtype);
    }
    /**
     * Dispose the memory if the dataId has 0 refCount. Return true if the memory
     * is released (or `dataId` is not tracked at all), false otherwise.
     * @param dataId The tensor whose refCount is decremented.
     * @param force Optional, remove the data regardless of refCount.
     */
    disposeData(dataId, force = false) {
        if (this.dataIdMap.has(dataId)) {
            const data = this.dataIdMap.get(dataId);
            data.refCount--;
            if (!force && data.refCount > 0) {
                return false;
            }
            // For string tensors memoryOffset is null, i.e. `_free(null)`.
            this.wasm._free(data.memoryOffset);
            this.wasm.tfjs.disposeData(data.id);
            this.dataIdMap.delete(dataId);
        }
        return true;
    }
    /** Return refCount of a `TensorData`, or 0 if `dataId` is not tracked. */
    refCount(dataId) {
        if (this.dataIdMap.has(dataId)) {
            const tensorData = this.dataIdMap.get(dataId);
            return tensorData.refCount;
        }
        return 0;
    }
    /** Increments the refCount of `dataId`; a no-op if it is not tracked. */
    incRef(dataId) {
        const data = this.dataIdMap.get(dataId);
        if (data != null) {
            data.refCount++;
        }
    }
    floatPrecision() {
        return 32;
    }
    // Returns the memory offset of a tensor. Useful for debugging and unit
    // testing.
    getMemoryOffset(dataId) {
        return this.dataIdMap.get(dataId).memoryOffset;
    }
    /**
     * Tears down the native state, terminates worker threads when the module
     * exposes `PThread`, and drops the reference to the module.
     */
    dispose() {
        this.wasm.tfjs.dispose();
        if ('PThread' in this.wasm) {
            this.wasm.PThread.terminateAllThreads();
        }
        this.wasm = null;
    }
    memory() {
        return { unreliable: false };
    }
    /**
     * Make a tensor info for the output of an op. If `memoryOffset` is not
     * present, this method allocates memory on the WASM heap. If `memoryOffset`
     * is present, the memory was allocated elsewhere (in c++) and we just record
     * the pointer where that memory lives.
     */
    makeOutput(shape, dtype, memoryOffset, values) {
        let dataId;
        if (memoryOffset == null) {
            dataId = this.write(values !== null && values !== void 0 ? values : null, shape, dtype);
        }
        else {
            const id = this.dataIdNextNumber++;
            dataId = { id };
            this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype, refCount: 1 });
            const size = util.sizeFromShape(shape);
            this.wasm.tfjs.registerTensor(id, size, memoryOffset);
        }
        return { dataId, shape, dtype };
    }
    /**
     * Returns a typed-array view (no copy) over the tensor's bytes on the WASM
     * heap.
     *
     * NOTE(review): a view over `HEAPU8.buffer` can be invalidated if the heap
     * grows; presumably callers use it immediately — confirm.
     */
    typedArrayFromHeap({ shape, dtype, dataId }) {
        const buffer = this.wasm.HEAPU8.buffer;
        const { memoryOffset } = this.dataIdMap.get(dataId);
        const size = util.sizeFromShape(shape);
        switch (dtype) {
            case 'float32':
                return new Float32Array(buffer, memoryOffset, size);
            case 'int32':
                return new Int32Array(buffer, memoryOffset, size);
            case 'bool':
                return new Uint8Array(buffer, memoryOffset, size);
            default:
                throw new Error(`Unknown dtype ${dtype}`);
        }
    }
}
/**
 * Builds an Emscripten `instantiateWasm` override that downloads the binary at
 * `path` with `util.fetch` and instantiates it with the supplied imports.
 *
 * NOTE(review): `imports.env.a` is presumably the minified Emscripten `abort`
 * import — confirm against the generated glue code. The body is left as-is
 * because, per the comment below, a build plugin rewrites it.
 */
function createInstantiateWasmFunc(path) {
    // this will be replace by rollup plugin patchWechatWebAssembly in
    // minprogram's output.
    // tslint:disable-next-line:no-any
    return (imports, callback) => {
        util.fetch(path, { credentials: 'same-origin' }).then((response) => {
            if (!response['ok']) {
                imports.env.a(`failed to load wasm binary file at '${path}'`);
            }
            response.arrayBuffer().then(binary => {
                WebAssembly.instantiate(binary, imports).then(output => {
                    callback(output.instance, output.module);
                });
            });
        });
        return {};
    };
}
/**
 * Returns the path of the WASM binary.
 *
 * Precedence: an explicit `wasmPath` (set by the deprecated `setWasmPath`),
 * then an entry in `wasmFileMap`, then `wasmModuleFolder` + binary name.
 * @param simdSupported whether SIMD is supported
 * @param threadsSupported whether multithreading is supported
 * @param wasmModuleFolder the directory containing the WASM binaries.
 */
function getPathToWasmBinary(simdSupported, threadsSupported, wasmModuleFolder) {
    if (wasmPath != null) {
        // If wasmPath is defined, the user has supplied a full path to
        // the vanilla .wasm binary.
        return wasmPath;
    }
    let path = 'tfjs-backend-wasm.wasm';
    if (simdSupported && threadsSupported) {
        path = 'tfjs-backend-wasm-threaded-simd.wasm';
    }
    else if (simdSupported) {
        path = 'tfjs-backend-wasm-simd.wasm';
    }
    if (wasmFileMap != null) {
        if (wasmFileMap[path] != null) {
            return wasmFileMap[path];
        }
    }
    return wasmModuleFolder + path;
}
/**
 * Initializes the wasm module and creates the js <--> wasm bridge.
 *
 * NOTE: We wrap the wasm module in an object with property 'wasm' instead of
 * returning Promise<BackendWasmModule> to avoid freezing Chrome (last tested
 * in Chrome 76).
 *
 * @returns A promise resolving to `{wasm: module}`. It rejects with
 *     `{message}` when Emscripten aborts before initialization completes, or
 *     with the factory promise's own rejection reason.
 */
export async function init() {
    const [simdSupported, threadsSupported] = await Promise.all([
        env().getAsync('WASM_HAS_SIMD_SUPPORT'),
        env().getAsync('WASM_HAS_MULTITHREAD_SUPPORT')
    ]);
    return new Promise((resolve, reject) => {
        const factoryConfig = {};
        /**
         * This function overrides the Emscripten module locateFile utility.
         * @param path The relative path to the file that needs to be loaded.
         * @param prefix The path to the main JavaScript file's directory.
         */
        factoryConfig.locateFile = (path, prefix) => {
            if (path.endsWith('.worker.js')) {
                // Escape '\n' because Blob will turn it into a newline.
                // There should be a setting for this, but 'endings: "native"' does
                // not seem to work.
                const response = wasmWorkerContents.replace(/\n/g, '\\n');
                const blob = new Blob([response], { type: 'application/javascript' });
                return URL.createObjectURL(blob);
            }
            if (path.endsWith('.wasm')) {
                return getPathToWasmBinary(simdSupported, threadsSupported, wasmPathPrefix != null ? wasmPathPrefix : prefix);
            }
            return prefix + path;
        };
        // Use the instantiateWasm override when system fetch is not available.
        // Reference:
        // https://github.com/emscripten-core/emscripten/blob/2bca083cbbd5a4133db61fbd74d04f7feecfa907/tests/manual_wasm_instantiate.html#L170
        if (customFetch) {
            factoryConfig.instantiateWasm = createInstantiateWasmFunc(getPathToWasmBinary(simdSupported, threadsSupported, wasmPathPrefix != null ? wasmPathPrefix : ''));
        }
        let initialized = false;
        factoryConfig.onAbort = () => {
            if (initialized) {
                // Emscripten already called console.warn so no need to double log.
                return;
            }
            if (initAborted) {
                // Emscripten calls `onAbort` twice, resulting in double error
                // messages.
                return;
            }
            initAborted = true;
            const rejectMsg = 'Make sure the server can serve the `.wasm` file relative to the ' +
                'bundled js file. For more details see https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers';
            reject({ message: rejectMsg });
        };
        let wasm;
        // If `wasmPath` has been defined we must initialize the vanilla module.
        if (threadsSupported && simdSupported && wasmPath == null) {
            factoryConfig.mainScriptUrlOrBlob = new Blob([`var WasmBackendModuleThreadedSimd = ` +
                    wasmFactoryThreadedSimd.toString()], { type: 'text/javascript' });
            wasm = wasmFactoryThreadedSimd(factoryConfig);
        }
        else {
            // The wasmFactory works for both vanilla and SIMD binaries.
            wasm = wasmFactory(factoryConfig);
        }
        // The `wasm` promise will resolve to the WASM module created by
        // the factory, but it might have had errors during creation. Most
        // errors are caught by the onAbort callback defined above.
        // However, some errors, such as those occurring from a
        // failed fetch, result in this promise being rejected. These are
        // caught and re-rejected below.
        wasm.then((module) => {
            initialized = true;
            initAborted = false;
            const voidReturnType = null;
            // Using the tfjs namespace to avoid conflict with emscripten's API.
            module.tfjs = {
                init: module.cwrap('init', null, []),
                initWithThreadsCount: module.cwrap('init_with_threads_count', null, ['number']),
                getThreadsCount: module.cwrap('get_threads_count', 'number', []),
                registerTensor: module.cwrap('register_tensor', null, [
                    'number',
                    'number',
                    'number', // memoryOffset
                ]),
                disposeData: module.cwrap('dispose_data', voidReturnType, ['number']),
                dispose: module.cwrap('dispose', voidReturnType, []),
            };
            resolve({ wasm: module });
        })
            .catch(reject);
    });
}
/**
 * Wraps `buffer` in the typed array that matches `dtype`.
 * @throws Error for dtypes other than float32, int32 and bool.
 */
function typedArrayFromBuffer(buffer, dtype) {
    switch (dtype) {
        case 'float32':
            return new Float32Array(buffer);
        case 'int32':
            return new Int32Array(buffer);
        case 'bool':
            return new Uint8Array(buffer);
        default:
            throw new Error(`Unknown dtype ${dtype}`);
    }
}
// Every binary name a file map passed to `setWasmPaths` must provide.
const wasmBinaryNames = [
    'tfjs-backend-wasm.wasm', 'tfjs-backend-wasm-simd.wasm',
    'tfjs-backend-wasm-threaded-simd.wasm'
];
let wasmPath = null;
let wasmPathPrefix = null;
let wasmFileMap = {};
// Set when Emscripten aborts during `init()`; cleared on a successful init.
let initAborted = false;
let customFetch = false;
/**
 * @deprecated Use `setWasmPaths` instead.
 * Sets the path to the `.wasm` file which will be fetched when the wasm
 * backend is initialized. See
 * https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers
 * for more details.
 * @param path wasm file path or url
 * @param usePlatformFetch optional boolean to use platform fetch to download
 *     the wasm file, default to false.
 *
 * @doc {heading: 'Environment', namespace: 'wasm'}
 */
export function setWasmPath(path, usePlatformFetch = false) {
    deprecationWarn('setWasmPath has been deprecated in favor of setWasmPaths and' +
        ' will be removed in a future release.');
    // NOTE(review): this guard fires only after an aborted initialization
    // (`initAborted`), despite the message's wording — confirm intent.
    if (initAborted) {
        throw new Error('The WASM backend was already initialized. Make sure you call ' +
            '`setWasmPath()` before you call `tf.setBackend()` or `tf.ready()`');
    }
    wasmPath = path;
    customFetch = usePlatformFetch;
}
/**
 * Configures the locations of the WASM binaries.
 *
 * ```js
 * setWasmPaths({
 *  'tfjs-backend-wasm.wasm': 'renamed.wasm',
 *  'tfjs-backend-wasm-simd.wasm': 'renamed-simd.wasm',
 *  'tfjs-backend-wasm-threaded-simd.wasm': 'renamed-threaded-simd.wasm'
 * });
 * tf.setBackend('wasm');
 * ```
 *
 * The most recent call wins: a string prefix discards a previously configured
 * file map, and a file map takes effect only if it names every binary.
 *
 * @param prefixOrFileMap This can be either a string or object:
 *  - (string) The path to the directory where the WASM binaries are located.
 *     Note that this prefix will be used to load each binary (vanilla,
 *     SIMD-enabled, threading-enabled, etc.).
 *  - (object) Mapping from names of WASM binaries to custom
 *     full paths specifying the locations of those binaries. This is useful if
 *     your WASM binaries are not all located in the same directory, or if your
 *     WASM binaries have been renamed.
 * @param usePlatformFetch optional boolean to use platform fetch to download
 *     the wasm file, default to false.
 * @throws Error if a file map is missing an entry for any binary; the current
 *     configuration is left untouched in that case.
 *
 * @doc {heading: 'Environment', namespace: 'wasm'}
 */
export function setWasmPaths(prefixOrFileMap, usePlatformFetch = false) {
    if (initAborted) {
        throw new Error('The WASM backend was already initialized. Make sure you call ' +
            '`setWasmPaths()` before you call `tf.setBackend()` or ' +
            '`tf.ready()`');
    }
    if (typeof prefixOrFileMap === 'string') {
        wasmPathPrefix = prefixOrFileMap;
        // getPathToWasmBinary consults the file map before the prefix, so a map
        // left over from an earlier call would silently override this prefix.
        wasmFileMap = {};
    }
    else {
        // Validate before storing so a rejected map cannot replace the active
        // configuration.
        const missingPaths = wasmBinaryNames.filter(name => prefixOrFileMap[name] == null);
        if (missingPaths.length > 0) {
            throw new Error(`There were no entries found for the following binaries: ` +
                `${missingPaths.join(',')}. Please either call setWasmPaths with a ` +
                `map providing a path for each binary, or with a string indicating ` +
                `the directory where all the binaries can be found.`);
        }
        wasmFileMap = prefixOrFileMap;
    }
    customFetch = usePlatformFetch;
}
/** Used in unit tests. Restores every path/fetch setting to its default. */
export function resetWasmPath() {
    wasmPath = null;
    wasmPathPrefix = null;
    wasmFileMap = {};
    customFetch = false;
    initAborted = false;
}
let threadsCount = -1;
let actualThreadsCount = -1;
/**
 * Sets the number of threads that will be used by XNNPACK to create
 * threadpool (default to the number of logical CPU cores).
 *
 * This must be called before calling `tf.setBackend('wasm')`.
 */
export function setThreadsCount(numThreads) {
    threadsCount = numThreads;
}
/**
 * Gets the actual threads count that is used by XNNPACK.
 *
 * It is set after the backend is initialized.
 * @throws Error if no `BackendWasm` has been constructed yet.
 */
export function getThreadsCount() {
    if (actualThreadsCount === -1) {
        throw new Error(`WASM backend not initialized.`);
    }
    return actualThreadsCount;
}
//# sourceMappingURL=data:application/json;base64,