@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, KNN classifier.
437 lines • 58.1 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import './flags_wasm';
import { DataStorage, deprecationWarn, engine, env, KernelBackend, util } from '@tensorflow/tfjs-core';
import * as wasmFactoryThreadedSimd_import from '../wasm-out/tfjs-backend-wasm-threaded-simd.js';
// @ts-ignore
import { wasmWorkerContents } from '../wasm-out/tfjs-backend-wasm-threaded-simd.worker.js';
import * as wasmFactory_import from '../wasm-out/tfjs-backend-wasm.js';
// Depending on how the module is loaded (e.g. in Node.js tests that bypass the
// node bundle), each generated factory shows up either as the namespace's
// `default` export or as the namespace object itself, so accept both.
// Turning on esModuleInterop would make this unnecessary, but google3 does not
// use that flag.
const wasmFactory = wasmFactory_import.default || wasmFactory_import;
const wasmFactoryThreadedSimd =
  wasmFactoryThreadedSimd_import.default || wasmFactoryThreadedSimd_import;
/**
 * TensorFlow.js kernel backend whose tensor storage lives on the heap of an
 * Emscripten-compiled WebAssembly module.
 *
 * Non-string tensors are `_malloc`'d on the WASM heap and registered with the
 * native side under a numeric id. String tensors stay in JS as `stringBytes`.
 */
export class BackendWasm extends KernelBackend {
  /**
   * @param wasm The module resolved by `init()`. Its `tfjs` namespace of
   *     cwrap'd native entry points must already be attached.
   */
  constructor(wasm) {
    super();
    this.wasm = wasm;
    // 0 is reserved for null data ids.
    this.dataIdNextNumber = 1;
    // Pass the requested thread count (module-level, see setThreadsCount) to
    // the native side, then record the count it actually chose.
    this.wasm.tfjs.initWithThreadsCount(threadsCount);
    actualThreadsCount = this.wasm.tfjs.getThreadsCount();
    this.dataIdMap = new DataStorage(this, engine());
  }

  /**
   * Allocates storage for a new tensor and copies `values` into it.
   * @returns A fresh data id whose entry starts with a reference count of 1.
   */
  write(values, shape, dtype) {
    const dataId = { id: this.dataIdNextNumber++ };
    this.move(dataId, values, shape, dtype, 1);
    return dataId;
  }

  /** Number of data ids currently tracked by this backend. */
  numDataIds() {
    return this.dataIdMap.numDataIds();
  }

  /**
   * Times a synchronous function using `util.now()`.
   * @returns `{kernelMs}`: the elapsed time measured around the call to `f`.
   */
  async time(f) {
    const start = util.now();
    f();
    const kernelMs = util.now() - start;
    return { kernelMs };
  }

  /**
   * Stores data for `dataId` under a newly assigned native id.
   *
   * Strings are kept in JS. Every other dtype gets
   * `size * bytesPerElement(dtype)` bytes allocated on the WASM heap. That
   * memory is registered with the native side and, when `values` is provided,
   * filled with a byte copy of it.
   */
  move(dataId, values, shape, dtype, refCount) {
    const id = this.dataIdNextNumber++;
    if (dtype === 'string') {
      const stringBytes = values;
      this.dataIdMap.set(dataId, { id, stringBytes, shape, dtype, memoryOffset: null, refCount });
      return;
    }
    const size = util.sizeFromShape(shape);
    const numBytes = size * util.bytesPerElement(dtype);
    // `>>> 0` is needed for above 2GB allocations because wasm._malloc returns
    // a signed int32 instead of an unsigned int32.
    // https://v8.dev/blog/4gb-wasm-memory
    const memoryOffset = this.wasm._malloc(numBytes) >>> 0;
    this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype, refCount });
    this.wasm.tfjs.registerTensor(id, size, memoryOffset);
    if (values != null) {
      this.wasm.HEAPU8.set(new Uint8Array(values.buffer, values.byteOffset, numBytes), memoryOffset);
    }
  }

  /** Async wrapper that reads the whole tensor via `readSync`. */
  async read(dataId) {
    return this.readSync(dataId);
  }

  /**
   * Reads the elements in the half-open range `[start, end)`.
   *
   * @param start First element index. Defaults to 0 when null or undefined.
   * @param end One past the last element index. Defaults to the tensor size
   *     when null or undefined. An explicit 0 yields an empty result.
   * @returns For strings, the stored `stringBytes` (or a slice of it). For
   *     numeric dtypes, a new typed array backed by a copy of the heap bytes.
   */
  readSync(dataId, start, end) {
    const { memoryOffset, dtype, shape, stringBytes } = this.dataIdMap.get(dataId);
    if (dtype === 'string') {
      // Slice all elements.
      if ((start == null || start === 0) &&
          (end == null || end >= stringBytes.length)) {
        return stringBytes;
      }
      return stringBytes.slice(start, end);
    }
    // Default only missing bounds. With `||`, an explicit `end` of 0 (an empty
    // range) was silently widened to the whole tensor.
    const startIndex = start == null ? 0 : start;
    const endIndex = end == null ? util.sizeFromShape(shape) : end;
    const bytesPerElement = util.bytesPerElement(dtype);
    // `slice` copies, so the result does not alias the WASM heap.
    const bytes = this.wasm.HEAPU8.slice(memoryOffset + startIndex * bytesPerElement, memoryOffset + endIndex * bytesPerElement);
    return typedArrayFromBuffer(bytes.buffer, dtype);
  }

  /**
   * Dispose the memory if the dataId has 0 refCount. Return true if the memory
   * is released (or the dataId is not tracked), false otherwise.
   * @param dataId
   * @param force Optional, remove the data regardless of refCount
   */
  disposeData(dataId, force = false) {
    if (this.dataIdMap.has(dataId)) {
      const data = this.dataIdMap.get(dataId);
      data.refCount--;
      if (!force && data.refCount > 0) {
        return false;
      }
      this.wasm._free(data.memoryOffset);
      this.wasm.tfjs.disposeData(data.id);
      this.dataIdMap.delete(dataId);
    }
    return true;
  }

  /** Return refCount of a `TensorData`, or 0 if the dataId is not tracked. */
  refCount(dataId) {
    if (this.dataIdMap.has(dataId)) {
      const tensorData = this.dataIdMap.get(dataId);
      return tensorData.refCount;
    }
    return 0;
  }

  /** Increments the reference count of `dataId`; no-op if it is not tracked. */
  incRef(dataId) {
    const data = this.dataIdMap.get(dataId);
    if (data != null) {
      data.refCount++;
    }
  }

  /** Bit width of floating-point values handled by this backend. */
  floatPrecision() {
    return 32;
  }

  // Returns the memory offset of a tensor. Useful for debugging and unit
  // testing.
  getMemoryOffset(dataId) {
    return this.dataIdMap.get(dataId).memoryOffset;
  }

  /**
   * Releases native state. Terminates worker threads when the module exposes
   * `PThread` (presumably only the threaded build), then drops the reference
   * to the module.
   */
  dispose() {
    this.wasm.tfjs.dispose();
    if ('PThread' in this.wasm) {
      this.wasm.PThread.terminateAllThreads();
    }
    this.wasm = null;
  }

  /** Memory info reported to the engine. */
  memory() {
    return { unreliable: false };
  }

  /**
   * Make a tensor info for the output of an op. If `memoryOffset` is not
   * present, this method allocates memory on the WASM heap. If `memoryOffset`
   * is present, the memory was allocated elsewhere (in c++) and we just record
   * the pointer where that memory lives.
   */
  makeOutput(shape, dtype, memoryOffset, values) {
    let dataId;
    if (memoryOffset == null) {
      dataId = this.write(values == null ? null : values, shape, dtype);
    } else {
      const id = this.dataIdNextNumber++;
      dataId = { id };
      this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype, refCount: 1 });
      const size = util.sizeFromShape(shape);
      this.wasm.tfjs.registerTensor(id, size, memoryOffset);
    }
    return { dataId, shape, dtype };
  }

  /**
   * Returns a typed-array view (not a copy) over the tensor's bytes on the
   * WASM heap.
   * NOTE(review): views over `HEAPU8.buffer` may become invalid if the heap
   * grows. Confirm before holding on to one.
   * @throws Error for dtypes other than float32, int32 and bool.
   */
  typedArrayFromHeap({ shape, dtype, dataId }) {
    const buffer = this.wasm.HEAPU8.buffer;
    const { memoryOffset } = this.dataIdMap.get(dataId);
    const size = util.sizeFromShape(shape);
    switch (dtype) {
      case 'float32':
        return new Float32Array(buffer, memoryOffset, size);
      case 'int32':
        return new Int32Array(buffer, memoryOffset, size);
      case 'bool':
        return new Uint8Array(buffer, memoryOffset, size);
      default:
        throw new Error(`Unknown dtype ${dtype}`);
    }
  }
}
/**
 * Builds an Emscripten `instantiateWasm` hook. The hook downloads the binary
 * at `path` with `util.fetch` and instantiates it with
 * `WebAssembly.instantiate`. `init()` installs it only when `customFetch` is
 * set.
 *
 * @param path URL or path of the `.wasm` binary to fetch.
 * @returns A function `(imports, callback)` that starts the download and
 *     immediately returns `{}`. Once `WebAssembly.instantiate` resolves, it
 *     invokes `callback(instance, module)`.
 *
 * NOTE(review): the inner `arrayBuffer()` / `WebAssembly.instantiate()`
 * promises are neither returned nor caught. A failure there leaves `callback`
 * uncalled and only surfaces as an unhandled rejection, so initialization may
 * never settle.
 */
function createInstantiateWasmFunc(path) {
// This will be replaced by the rollup plugin patchWechatWebAssembly in the
// mini program's output, so keep this function's text stable.
// tslint:disable-next-line:no-any
return (imports, callback) => {
util.fetch(path, { credentials: 'same-origin' }).then((response) => {
if (!response['ok']) {
// NOTE(review): `imports.env.a` is presumably Emscripten's minified `abort`
// import; confirm against the generated glue code. Execution stops here only
// if that call throws. Otherwise the response body is still read and
// instantiated below.
imports.env.a(`failed to load wasm binary file at '${path}'`);
}
response.arrayBuffer().then(binary => {
WebAssembly.instantiate(binary, imports).then(output => {
callback(output.instance, output.module);
});
});
});
// The result arrives later through `callback`; nothing is returned synchronously.
return {};
};
}
/**
 * Resolves where the WASM binary for the detected feature set should be loaded
 * from.
 *
 * Resolution order:
 *   1. `wasmPath` (from the deprecated `setWasmPath`) is returned verbatim.
 *   2. Otherwise the binary name is chosen from the feature flags, and an
 *      entry for it in `wasmFileMap` (from `setWasmPaths(object)`) wins.
 *   3. Otherwise `wasmModuleFolder` is prepended to the binary name.
 *
 * @param simdSupported whether SIMD is supported
 * @param threadsSupported whether multithreading is supported (it only
 *     matters when SIMD is also supported)
 * @param wasmModuleFolder the directory containing the WASM binaries.
 */
function getPathToWasmBinary(simdSupported, threadsSupported, wasmModuleFolder) {
  if (wasmPath != null) {
    return wasmPath;
  }

  let binaryName;
  if (!simdSupported) {
    binaryName = 'tfjs-backend-wasm.wasm';
  } else if (threadsSupported) {
    binaryName = 'tfjs-backend-wasm-threaded-simd.wasm';
  } else {
    binaryName = 'tfjs-backend-wasm-simd.wasm';
  }

  const overridePath = wasmFileMap == null ? null : wasmFileMap[binaryName];
  return overridePath != null ? overridePath : wasmModuleFolder + binaryName;
}
/**
 * Initializes the wasm module and creates the js <--> wasm bridge.
 *
 * The threaded+SIMD factory is used when both features are detected and no
 * explicit `wasmPath` was set. Otherwise the vanilla factory is used; it
 * serves both the vanilla and the SIMD-only binaries.
 *
 * NOTE: We wrap the wasm module in a object with property 'wasm' instead of
 * returning Promise<BackendWasmModule> to avoid freezing Chrome (last tested
 * in Chrome 76).
 *
 * @returns A promise for `{wasm: module}`, where `module.tfjs` holds the
 *     cwrap'd native entry points.
 * @throws (via rejection) An `Error` when Emscripten aborts during
 *     initialization. If the factory promise itself rejects, its rejection
 *     reason is passed through.
 */
export async function init() {
  const [simdSupported, threadsSupported] = await Promise.all([
    env().getAsync('WASM_HAS_SIMD_SUPPORT'),
    env().getAsync('WASM_HAS_MULTITHREAD_SUPPORT')
  ]);
  // `new Promise` is needed here to adapt Emscripten's `onAbort` callback.
  return new Promise((resolve, reject) => {
    const factoryConfig = {};
    /**
     * This function overrides the Emscripten module locateFile utility.
     * @param path The relative path to the file that needs to be loaded.
     * @param prefix The path to the main JavaScript file's directory.
     */
    factoryConfig.locateFile = (path, prefix) => {
      if (path.endsWith('.worker.js')) {
        // Escape '\n' because Blob will turn it into a newline.
        // There should be a setting for this, but 'endings: "native"' does
        // not seem to work.
        const response = wasmWorkerContents.replace(/\n/g, '\\n');
        const blob = new Blob([response], { type: 'application/javascript' });
        return URL.createObjectURL(blob);
      }
      if (path.endsWith('.wasm')) {
        return getPathToWasmBinary(simdSupported, threadsSupported, wasmPathPrefix != null ? wasmPathPrefix : prefix);
      }
      return prefix + path;
    };
    // Use the instantiateWasm override when system fetch is not available.
    // Reference:
    // https://github.com/emscripten-core/emscripten/blob/2bca083cbbd5a4133db61fbd74d04f7feecfa907/tests/manual_wasm_instantiate.html#L170
    if (customFetch) {
      factoryConfig.instantiateWasm =
        createInstantiateWasmFunc(getPathToWasmBinary(simdSupported, threadsSupported, wasmPathPrefix != null ? wasmPathPrefix : ''));
    }
    let initialized = false;
    factoryConfig.onAbort = () => {
      if (initialized) {
        // Emscripten already called console.warn so no need to double log.
        return;
      }
      if (initAborted) {
        // Emscripten calls `onAbort` twice, resulting in double error
        // messages.
        return;
      }
      initAborted = true;
      const rejectMsg = 'Make sure the server can serve the `.wasm` file relative to the ' +
        'bundled js file. For more details see https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers';
      // Reject with a real Error, not a plain object, so callers also get a
      // stack trace. `message` carries the same text as before.
      reject(new Error(rejectMsg));
    };
    let wasm;
    // If `wasmPath` has been defined we must initialize the vanilla module.
    if (threadsSupported && simdSupported && wasmPath == null) {
      factoryConfig.mainScriptUrlOrBlob = new Blob([`var WasmBackendModuleThreadedSimd = ` +
        wasmFactoryThreadedSimd.toString()], { type: 'text/javascript' });
      wasm = wasmFactoryThreadedSimd(factoryConfig);
    } else {
      // The wasmFactory works for both vanilla and SIMD binaries.
      wasm = wasmFactory(factoryConfig);
    }
    // The `wasm` promise will resolve to the WASM module created by
    // the factory, but it might have had errors during creation. Most
    // errors are caught by the onAbort callback defined above.
    // However, some errors, such as those occurring from a
    // failed fetch, result in this promise being rejected. These are
    // caught and re-rejected below.
    wasm.then((module) => {
      initialized = true;
      initAborted = false;
      const voidReturnType = null;
      // Using the tfjs namespace to avoid conflict with emscripten's API.
      module.tfjs = {
        init: module.cwrap('init', null, []),
        initWithThreadsCount: module.cwrap('init_with_threads_count', null, ['number']),
        getThreadsCount: module.cwrap('get_threads_count', 'number', []),
        registerTensor: module.cwrap('register_tensor', null, [
          'number',
          'number',
          'number', // memoryOffset
        ]),
        disposeData: module.cwrap('dispose_data', voidReturnType, ['number']),
        dispose: module.cwrap('dispose', voidReturnType, []),
      };
      resolve({ wasm: module });
    })
      .catch(reject);
  });
}
/**
 * Wraps a whole ArrayBuffer in the typed array that matches `dtype`.
 * The result shares `buffer`; nothing is copied.
 *
 * @param buffer Raw bytes to view.
 * @param dtype One of 'float32', 'int32' or 'bool' (bool is viewed as bytes).
 * @returns A Float32Array, Int32Array or Uint8Array over `buffer`.
 * @throws Error for any other dtype.
 */
function typedArrayFromBuffer(buffer, dtype) {
  let ArrayType;
  if (dtype === 'float32') {
    ArrayType = Float32Array;
  } else if (dtype === 'int32') {
    ArrayType = Int32Array;
  } else if (dtype === 'bool') {
    ArrayType = Uint8Array;
  } else {
    throw new Error(`Unknown dtype ${dtype}`);
  }
  return new ArrayType(buffer);
}
// File names of every WASM binary variant this backend can load.
// `setWasmPaths(object)` requires an entry for each of them.
const wasmBinaryNames = [
  'tfjs-backend-wasm.wasm',
  'tfjs-backend-wasm-simd.wasm',
  'tfjs-backend-wasm-threaded-simd.wasm',
];
// Full path to the vanilla binary, set by the deprecated `setWasmPath()`.
let wasmPath = null;
// Directory prefix shared by all binaries, set by `setWasmPaths(string)`.
let wasmPathPrefix = null;
// Per-binary path overrides, set by `setWasmPaths(object)`.
let wasmFileMap = {};
// When true, `init()` installs an `instantiateWasm` hook that loads the binary
// via `util.fetch`.
let customFetch = false;
// Set by `init()`'s `onAbort` handler and cleared once initialization succeeds.
let initAborted = false;
/**
 * @deprecated Use `setWasmPaths` instead.
 * Sets the path to the `.wasm` file which will be fetched when the wasm
 * backend is initialized. See
 * https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers
 * for more details.
 *
 * A deprecation warning is emitted on every call, even when the call then
 * throws.
 *
 * @param path wasm file path or url
 * @param usePlatformFetch optional boolean to use platform fetch to download
 *     the wasm file, default to false.
 * @throws Error if a previous initialization attempt aborted.
 *
 * @doc {heading: 'Environment', namespace: 'wasm'}
 */
export function setWasmPath(path, usePlatformFetch = false) {
  const deprecationMessage =
    'setWasmPath has been deprecated in favor of setWasmPaths and' +
    ' will be removed in a future release.';
  deprecationWarn(deprecationMessage);

  if (!initAborted) {
    wasmPath = path;
    customFetch = usePlatformFetch;
    return;
  }
  throw new Error('The WASM backend was already initialized. Make sure you call ' +
    '`setWasmPath()` before you call `tf.setBackend()` or `tf.ready()`');
}
/**
 * Configures the locations of the WASM binaries.
 *
 * ```js
 * setWasmPaths({
 *   'tfjs-backend-wasm.wasm': 'renamed.wasm',
 *   'tfjs-backend-wasm-simd.wasm': 'renamed-simd.wasm',
 *   'tfjs-backend-wasm-threaded-simd.wasm': 'renamed-threaded-simd.wasm'
 * });
 * tf.setBackend('wasm');
 * ```
 *
 * @param prefixOrFileMap This can be either a string or object:
 *  - (string) The path to the directory where the WASM binaries are located.
 *     Note that this prefix will be used to load each binary (vanilla,
 *     SIMD-enabled, threading-enabled, etc.).
 *  - (object) Mapping from names of WASM binaries to custom
 *     full paths specifying the locations of those binaries. This is useful if
 *     your WASM binaries are not all located in the same directory, or if your
 *     WASM binaries have been renamed.
 * @param usePlatformFetch optional boolean to use platform fetch to download
 *     the wasm file, default to false.
 * @throws Error if a previous initialization attempt aborted, or if the map
 *     lacks an entry for any known binary. In either case no configuration is
 *     changed.
 *
 * @doc {heading: 'Environment', namespace: 'wasm'}
 */
export function setWasmPaths(prefixOrFileMap, usePlatformFetch = false) {
  if (initAborted) {
    throw new Error('The WASM backend was already initialized. Make sure you call ' +
      '`setWasmPaths()` before you call `tf.setBackend()` or ' +
      '`tf.ready()`');
  }
  if (typeof prefixOrFileMap === 'string') {
    wasmPathPrefix = prefixOrFileMap;
  } else {
    // Validate before touching module state. Previously the map was stored
    // first, so a rejected, incomplete map stayed active for the next init().
    const missingPaths = wasmBinaryNames.filter(name => prefixOrFileMap[name] == null);
    if (missingPaths.length > 0) {
      throw new Error(`There were no entries found for the following binaries: ` +
        `${missingPaths.join(',')}. Please either call setWasmPaths with a ` +
        `map providing a path for each binary, or with a string indicating ` +
        `the directory where all the binaries can be found.`);
    }
    wasmFileMap = prefixOrFileMap;
  }
  customFetch = usePlatformFetch;
}
/**
 * Used in unit tests. Restores every path-related setting, the platform-fetch
 * flag, and the init-aborted flag to their initial values.
 */
export function resetWasmPath() {
  initAborted = false;
  customFetch = false;
  wasmFileMap = {};
  wasmPathPrefix = null;
  wasmPath = null;
}
// Thread count reported back by the native side once a backend has been
// constructed; -1 until then.
let actualThreadsCount = -1;
// Thread count requested via `setThreadsCount()` and passed to the native side
// when a backend is constructed. -1 means "not set by the user" (presumably
// the native default applies; confirm in the C++ init_with_threads_count).
let threadsCount = -1;

/**
 * Requests the number of threads XNNPACK uses for its threadpool (by default,
 * the number of logical CPU cores).
 *
 * Takes effect only if called before `tf.setBackend('wasm')`.
 *
 * @param numThreads Desired thread count.
 */
export function setThreadsCount(numThreads) {
  threadsCount = numThreads;
}
/**
 * Returns the thread count XNNPACK is actually using.
 *
 * The value is recorded when the backend is initialized.
 *
 * @returns The thread count reported by the native side.
 * @throws Error if no backend has been initialized yet.
 */
export function getThreadsCount() {
  const reportedCount = actualThreadsCount;
  if (reportedCount !== -1) {
    return reportedCount;
  }
  throw new Error(`WASM backend not initialized.`);
}
//# sourceMappingURL=data:application/json;base64,