UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

56 lines 7.46 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { Multinomial } from '@tensorflow/tfjs-core';
import { softmax } from './Softmax';

/**
 * Native Multinomial entry point, bound by `setup` through `cwrap`.
 * Called as (probabilitiesId, batchSize, numEvents, numSamples, seed, outId);
 * it is wrapped with a `null` return type, so nothing is returned to JS.
 */
let wasmMultinomial;

/**
 * Binds the exported WASM Multinomial function. Registered as `setupFunc`
 * in `multinomialConfig` below.
 *
 * @param {object} backend The WASM backend whose `wasm` module provides `cwrap`.
 */
function setup(backend) {
  wasmMultinomial = backend.wasm.cwrap(Multinomial, null, [
    'number', // probabilitiesId
    'number', // batchSize
    'number', // numEvents
    'number', // numSamples
    'number', // seed
    'number', // outId
  ]);
}

/**
 * WASM kernel for the Multinomial op. Allocates an int32 output of shape
 * [batchSize, numSamples] and delegates the sampling to the native kernel.
 *
 * @param {{inputs: {logits: object}, backend: object,
 *     attrs: {numSamples: number, seed: number, normalized: boolean}}} args
 *   `logits` must be float32. When `normalized` is false, a softmax over the
 *   last axis is computed first; otherwise `logits` is passed through as-is.
 * @returns {object} TensorInfo for the int32 [batchSize, numSamples] output.
 * @throws {Error} If `logits` does not have dtype float32.
 */
export function multinomial(args) {
  const { inputs, backend, attrs } = args;
  const { logits } = inputs;
  const { numSamples, seed, normalized } = attrs;

  if (logits.dtype !== 'float32') {
    throw new Error(`Tensor logits must have dtype float32, got ${logits.dtype}`);
  }

  // Unnormalized logits go through softmax along the last axis. That creates
  // an intermediate tensor owned by this kernel, which must be freed below.
  const probabilities = normalized
    ? logits
    : softmax({
      inputs: { logits },
      backend,
      attrs: { dim: logits.shape.length - 1 },
    });

  try {
    // NOTE(review): assumes a rank-2 [batchSize, numEvents] input, presumably
    // guaranteed by the calling op; confirm against tfjs-core.
    const [batchSize, numEvents] = probabilities.shape;
    const out = backend.makeOutput([batchSize, numSamples], 'int32');
    wasmMultinomial(
      backend.dataIdMap.get(probabilities.dataId).id,
      batchSize,
      numEvents,
      numSamples,
      seed,
      backend.dataIdMap.get(out.dataId).id,
    );
    return out;
  } finally {
    // Free the softmax intermediate even if the lookup or native call throws.
    // A caller-supplied (normalized) input is never disposed here.
    if (!normalized) {
      backend.disposeData(probabilities.dataId);
    }
  }
}

/** Kernel registration record binding `multinomial` to the 'wasm' backend. */
export const multinomialConfig = {
  kernelName: Multinomial,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: multinomial,
};
// NOTE(review): the inline source map below was generated for the previous
// compiled output and its mappings no longer match this code; regenerate it
// from the build.
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTXVsdGlub21pYWwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2FzbS9zcmMva2VybmVscy9NdWx0aW5vbWlhbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQTJCLFdBQVcsRUFBa0QsTUFBTSx1QkFBdUIsQ0FBQztBQUc3SCxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRWxDLElBQUksZUFFd0QsQ0FBQztBQUU3RCxTQUFTLEtBQUssQ0FBQyxPQUFvQjtJQUNqQyxlQUFlLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsV0FBVyxFQUFFLElBQUksRUFBRTtRQUN0RCxRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVEsRUFBRyxRQUFRO0tBQ3BCLENBQUMsQ0FBQztBQUNMLENBQUM7QUFFRCxNQUFNLFVBQVUsV0FBVyxDQUFDLElBSTNCO0lBQ0MsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxNQUFNLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDeEIsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsVUFBVSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRTdDLElBQUksTUFBTSxDQUFDLEtBQUssS0FBSyxTQUFTLEVBQUU7UUFDOUIsTUFBTSxJQUFJLEtBQUssQ0FDWCw4Q0FBOEMsTUFBTSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDbkU7SUFFRCxNQUFNLGFBQWEsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsT0FBTyxDQUFDO1FBQ2xELE1BQU0sRUFBRSxFQUFDLE1BQU0sRUFBQztRQUNoQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsR0FBRyxFQUFFLE1BQU0sQ0FBQyxLQUFLLENBQUMsTUFBTSxHQUFHLENBQUMsRUFBQztLQUN0QyxDQUFDLENBQUM7SUFFSCxNQUFNLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxHQUFHLGFBQWEsQ0FBQyxLQUFLLENBQUM7SUFDbkQsTUFB
TSxHQUFHLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxDQUFDLFNBQVMsRUFBRSxVQUFVLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztJQUVqRSxlQUFlLENBQ1gsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsYUFBYSxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsRUFDOUMsU0FBUyxFQUNULFNBQVMsRUFDVCxVQUFVLEVBQ1YsSUFBSSxFQUNKLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLEdBQUcsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQ3ZDLENBQUM7SUFDRixJQUFJLENBQUMsVUFBVSxFQUFFO1FBQ2YsT0FBTyxDQUFDLFdBQVcsQ0FBQyxhQUFhLENBQUMsTUFBTSxDQUFDLENBQUM7S0FDM0M7SUFDRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxpQkFBaUIsR0FBaUI7SUFDN0MsVUFBVSxFQUFFLFdBQVc7SUFDdkIsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLFdBQW9DO0NBQ2pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBNdWx0aW5vbWlhbCwgTXVsdGlub21pYWxBdHRycywgTXVsdGlub21pYWxJbnB1dHMsIFRlbnNvckluZm99IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5pbXBvcnQge3NvZnRtYXh9IGZyb20gJy4vU29mdG1heCc7XG5cbmxldCB3YXNtTXVsdGlub21pYWw6IChcbiAgICBwcm9iYWJpbGl0aWVzSWQ6IG51bWJlciwg
YmF0Y2hTaXplOiBudW1iZXIsIG51bUV2ZW50czogbnVtYmVyLFxuICAgIG51bVNhbXBsZXM6IG51bWJlciwgc2VlZDogbnVtYmVyLCBvdXRJZDogbnVtYmVyKSA9PiB2b2lkO1xuXG5mdW5jdGlvbiBzZXR1cChiYWNrZW5kOiBCYWNrZW5kV2FzbSkge1xuICB3YXNtTXVsdGlub21pYWwgPSBiYWNrZW5kLndhc20uY3dyYXAoTXVsdGlub21pYWwsIG51bGwsIFtcbiAgICAnbnVtYmVyJywgIC8vIHByb2JhYmlsaXRpZXNJZFxuICAgICdudW1iZXInLCAgLy8gYmF0Y2hTaXplXG4gICAgJ251bWJlcicsICAvLyBudW1FdmVudHNcbiAgICAnbnVtYmVyJywgIC8vIG51bVNhbXBsZXNcbiAgICAnbnVtYmVyJywgIC8vIHNlZWRcbiAgICAnbnVtYmVyJywgIC8vIG91dElkXG4gIF0pO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gbXVsdGlub21pYWwoYXJnczoge1xuICBpbnB1dHM6IE11bHRpbm9taWFsSW5wdXRzLFxuICBhdHRyczogTXVsdGlub21pYWxBdHRycyxcbiAgYmFja2VuZDogQmFja2VuZFdhc20sXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHtsb2dpdHN9ID0gaW5wdXRzO1xuICBjb25zdCB7bnVtU2FtcGxlcywgc2VlZCwgbm9ybWFsaXplZH0gPSBhdHRycztcblxuICBpZiAobG9naXRzLmR0eXBlICE9PSAnZmxvYXQzMicpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGBUZW5zb3IgbG9naXRzIG11c3QgaGF2ZSBkdHlwZSBmbG9hdDMyLCBnb3QgJHtsb2dpdHMuZHR5cGV9YCk7XG4gIH1cblxuICBjb25zdCBwcm9iYWJpbGl0aWVzID0gbm9ybWFsaXplZCA/IGxvZ2l0cyA6IHNvZnRtYXgoe1xuICAgIGlucHV0czoge2xvZ2l0c30sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge2RpbTogbG9naXRzLnNoYXBlLmxlbmd0aCAtIDF9LFxuICB9KTtcblxuICBjb25zdCBbYmF0Y2hTaXplLCBudW1FdmVudHNdID0gcHJvYmFiaWxpdGllcy5zaGFwZTtcbiAgY29uc3Qgb3V0ID0gYmFja2VuZC5tYWtlT3V0cHV0KFtiYXRjaFNpemUsIG51bVNhbXBsZXNdLCAnaW50MzInKTtcblxuICB3YXNtTXVsdGlub21pYWwoXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQocHJvYmFiaWxpdGllcy5kYXRhSWQpLmlkLFxuICAgICAgYmF0Y2hTaXplLFxuICAgICAgbnVtRXZlbnRzLFxuICAgICAgbnVtU2FtcGxlcyxcbiAgICAgIHNlZWQsXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0LmRhdGFJZCkuaWQsXG4gICk7XG4gIGlmICghbm9ybWFsaXplZCkge1xuICAgIGJhY2tlbmQuZGlzcG9zZURhdGEocHJvYmFiaWxpdGllcy5kYXRhSWQpO1xuICB9XG4gIHJldHVybiBvdXQ7XG59XG5cbmV4cG9ydCBjb25zdCBtdWx0aW5vbWlhbENvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBNdWx0aW5vbWlhbCxcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogbXVsdGlub21pYWwgYXMg
dW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19