UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

63 lines 11.2 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, MaxPoolWithArgmax, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';

// JS wrapper around the native `MaxPoolWithArgmax` export. Assigned by
// `setup` and invoked by `maxPoolWithArgmax`.
let wasmMaxPoolWithArgmax;

/**
 * Binds the native `MaxPoolWithArgmax` function exported by the WASM module.
 *
 * The wrapper is declared with a `null` return type (no return value). Its
 * argument-type list must stay in exactly the same order as the arguments
 * passed in `maxPoolWithArgmax`.
 *
 * @param {object} backend The WASM backend; `backend.wasm.cwrap` is used to
 *     build the wrapper.
 */
function setup(backend) {
  wasmMaxPoolWithArgmax = backend.wasm.cwrap('MaxPoolWithArgmax', null, [
    'number',   // xId
    'number',   // pooledId
    'number',   // indexesId
    'number',   // dtype (CppDType)
    'boolean',  // includeBatchInIndex
    'number',   // batchSize
    'number',   // inChannels
    'number',   // inHeight
    'number',   // inWidth
    'number',   // outHeight
    'number',   // outWidth
    'number',   // strideHeight
    'number',   // strideWidth
    'number',   // dilationHeight
    'number',   // dilationWidth
    'number',   // effectiveFilterHeight
    'number',   // effectiveFilterWidth
    'number',   // padTop
    'number',   // padLeft
  ]);
}

/**
 * WASM kernel for the `MaxPoolWithArgmax` op.
 *
 * Computes the pooling geometry on the JS side, allocates the two outputs and
 * delegates the actual computation to the native function bound in `setup`.
 *
 * @param {{inputs: {x: object}, backend: object, attrs: object}} args
 *     `inputs.x` must be a rank-4 tensor. `attrs` supplies `filterSize`,
 *     `strides`, `pad` and `includeBatchInIndex`.
 * @returns {object[]} `[pooled, indexes]`. `pooled` has `x.dtype` and
 *     `indexes` is `int32`; both have shape `convInfo.outShape`.
 * @throws {Error} (via `util.assert`) if `x` is not rank 4.
 */
export function maxPoolWithArgmax(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { filterSize, strides, pad, includeBatchInIndex } = attrs;

  util.assert(
      x.shape.length === 4,
      () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);

  // Dilation is fixed at 1 for this kernel. The same constant feeds both the
  // validity check and the pooling geometry so the two cannot disagree.
  // NOTE(review): with dilations of [1, 1] this assertion is expected to
  // always pass; it is kept for parity with the other backends.
  const dilations = [1, 1];
  util.assert(
      backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
      () => 'Error in maxPool: Either strides or dilations must be 1. ' +
          `Got strides ${strides} and dilations '${dilations}'`);

  const convInfo = backend_util.computePool2DInfo(
      x.shape, filterSize, strides, dilations, pad);

  const pooled = backend.makeOutput(convInfo.outShape, x.dtype);
  const indexes = backend.makeOutput(convInfo.outShape, 'int32');

  // The native side addresses tensors by the ids stored in dataIdMap.
  const xId = backend.dataIdMap.get(x.dataId).id;
  const pooledId = backend.dataIdMap.get(pooled.dataId).id;
  const indexesId = backend.dataIdMap.get(indexes.dataId).id;

  // Argument order must match the type list declared in `setup`.
  wasmMaxPoolWithArgmax(
      xId,
      pooledId,
      indexesId,
      CppDType[x.dtype],
      includeBatchInIndex,
      convInfo.batchSize,
      convInfo.inChannels,
      convInfo.inHeight,
      convInfo.inWidth,
      convInfo.outHeight,
      convInfo.outWidth,
      convInfo.strideHeight,
      convInfo.strideWidth,
      convInfo.dilationHeight,
      convInfo.dilationWidth,
      convInfo.effectiveFilterHeight,
      convInfo.effectiveFilterWidth,
      convInfo.padInfo.top,
      convInfo.padInfo.left);

  return [pooled, indexes];
}

/**
 * Kernel config pairing `maxPoolWithArgmax` (and its `setup`) with the
 * `MaxPoolWithArgmax` kernel name for the `wasm` backend. Presumably consumed
 * by the backend's kernel registration; confirm against the caller.
 */
export const maxPoolWithArgmaxConfig = {
  kernelName: MaxPoolWithArgmax,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: maxPoolWithArgmax,
};