@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
56 lines • 7.46 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { Multinomial } from '@tensorflow/tfjs-core';
import { softmax } from './Softmax';
/**
 * Handle to the WASM `Multinomial` export, populated by `setup`.
 * Called as (probabilitiesId, batchSize, numEvents, numSamples, seed, outId)
 * and returns nothing.
 */
let wasmMultinomial;

/**
 * Binds the WASM `Multinomial` function through `backend.wasm.cwrap`.
 *
 * @param {BackendWasm} backend - backend whose `wasm` module exposes `cwrap`.
 */
function setup(backend) {
  // One entry per argument passed by `multinomial`; return type is `null`
  // because the native routine writes into the output buffer instead.
  const argTypes = [
    'number', // probabilitiesId
    'number', // batchSize
    'number', // numEvents
    'number', // numSamples
    'number', // seed
    'number', // outId
  ];
  wasmMultinomial = backend.wasm.cwrap(Multinomial, null, argTypes);
}
/**
 * WASM kernel for `Multinomial`.
 *
 * Unless `normalized` is true, `logits` is first passed through `softmax`
 * over its last axis. The resulting probabilities are handed to the native
 * `Multinomial` routine, which fills an int32 output of shape
 * `[batchSize, numSamples]`. The output presumably holds sampled event
 * indices per row; that is not visible here.
 *
 * The softmax intermediate is owned by this kernel. It is released in a
 * `finally`, so it is freed even if `makeOutput` or the WASM call throws.
 *
 * NOTE(review): the destructuring below reads only the first two dims, so
 * it assumes `logits` is 2-D `[batchSize, numEvents]`. Confirm that the core
 * op reshapes 1-D input before dispatching.
 *
 * @param {{inputs: {logits: TensorInfo}, backend: BackendWasm,
 *     attrs: {numSamples: number, seed: number, normalized: boolean}}} args
 * @returns {TensorInfo} int32 tensor of shape `[batchSize, numSamples]`.
 * @throws {Error} if `logits.dtype` is not `'float32'`.
 */
export function multinomial(args) {
  const { inputs, backend, attrs } = args;
  const { logits } = inputs;
  const { numSamples, seed, normalized } = attrs;
  if (logits.dtype !== 'float32') {
    throw new Error(`Tensor logits must have dtype float32, got ${logits.dtype}`);
  }
  const probabilities = normalized ? logits : softmax({
    inputs: { logits },
    backend,
    attrs: { dim: logits.shape.length - 1 },
  });
  try {
    const [batchSize, numEvents] = probabilities.shape;
    const out = backend.makeOutput([batchSize, numSamples], 'int32');
    wasmMultinomial(
      backend.dataIdMap.get(probabilities.dataId).id,
      batchSize,
      numEvents,
      numSamples,
      seed,
      backend.dataIdMap.get(out.dataId).id,
    );
    return out;
  } finally {
    // Only the softmax intermediate belongs to this kernel; a caller-supplied
    // normalized `logits` must not be freed here.
    if (!normalized) {
      backend.disposeData(probabilities.dataId);
    }
  }
}
/**
 * Kernel registration entry for `Multinomial` on the `wasm` backend.
 * `setupFunc` is `setup`, which binds the native export via `cwrap`.
 * `kernelFunc` is `multinomial`.
 */
export const multinomialConfig = {
  kernelName: Multinomial,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: multinomial,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTXVsdGlub21pYWwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2FzbS9zcmMva2VybmVscy9NdWx0aW5vbWlhbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQTJCLFdBQVcsRUFBa0QsTUFBTSx1QkFBdUIsQ0FBQztBQUc3SCxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRWxDLElBQUksZUFFd0QsQ0FBQztBQUU3RCxTQUFTLEtBQUssQ0FBQyxPQUFvQjtJQUNqQyxlQUFlLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsV0FBVyxFQUFFLElBQUksRUFBRTtRQUN0RCxRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVEsRUFBRyxRQUFRO0tBQ3BCLENBQUMsQ0FBQztBQUNMLENBQUM7QUFFRCxNQUFNLFVBQVUsV0FBVyxDQUFDLElBSTNCO0lBQ0MsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxNQUFNLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDeEIsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsVUFBVSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRTdDLElBQUksTUFBTSxDQUFDLEtBQUssS0FBSyxTQUFTLEVBQUU7UUFDOUIsTUFBTSxJQUFJLEtBQUssQ0FDWCw4Q0FBOEMsTUFBTSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDbkU7SUFFRCxNQUFNLGFBQWEsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsT0FBTyxDQUFDO1FBQ2xELE1BQU0sRUFBRSxFQUFDLE1BQU0sRUFBQztRQUNoQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsR0FBRyxFQUFFLE1BQU0sQ0FBQyxLQUFLLENBQUMsTUFBTSxHQUFHLENBQUMsRUFBQztLQUN0QyxDQUFDLENBQUM7SUFFSCxNQUFNLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxHQUFHLGFBQWEsQ0FBQyxLQUFLLENBQUM7SUFDbkQsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxDQUFDLFNBQVMsRUFBRSxVQUFVLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztJQUVqRSxlQUFlLENBQ1gsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsYUFBYSxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsRUFDOUMsU0FBUyxFQUNULFNBQVMsRUFDVCxVQUFVLEVBQ1YsSUFBSSxFQUNKLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLEdBQUcsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQ3ZDLENBQUM7SUFDRixJQUFJLENBQUMsVUFBVSxFQUFFO1FBQ2YsT0FBTyxDQUFDLFdBQVcsQ0FBQyxhQUFhLENBQUMsTUFBTSxDQUFDLENBQUM7S0FDM0M7SUFDRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxpQkFBaUIsR0FBaUI7SUFDN0
MsVUFBVSxFQUFFLFdBQVc7SUFDdkIsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLFdBQW9DO0NBQ2pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBNdWx0aW5vbWlhbCwgTXVsdGlub21pYWxBdHRycywgTXVsdGlub21pYWxJbnB1dHMsIFRlbnNvckluZm99IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5pbXBvcnQge3NvZnRtYXh9IGZyb20gJy4vU29mdG1heCc7XG5cbmxldCB3YXNtTXVsdGlub21pYWw6IChcbiAgICBwcm9iYWJpbGl0aWVzSWQ6IG51bWJlciwgYmF0Y2hTaXplOiBudW1iZXIsIG51bUV2ZW50czogbnVtYmVyLFxuICAgIG51bVNhbXBsZXM6IG51bWJlciwgc2VlZDogbnVtYmVyLCBvdXRJZDogbnVtYmVyKSA9PiB2b2lkO1xuXG5mdW5jdGlvbiBzZXR1cChiYWNrZW5kOiBCYWNrZW5kV2FzbSkge1xuICB3YXNtTXVsdGlub21pYWwgPSBiYWNrZW5kLndhc20uY3dyYXAoTXVsdGlub21pYWwsIG51bGwsIFtcbiAgICAnbnVtYmVyJywgIC8vIHByb2JhYmlsaXRpZXNJZFxuICAgICdudW1iZXInLCAgLy8gYmF0Y2hTaXplXG4gICAgJ251bWJlcicsICAvLyBudW1FdmVudHNcbiAgICAnbnVtYmVyJywgIC8vIG51bVNhbXBsZXNcbiAgICAnbnVtYmVyJywgIC8vIHNlZWRcbiAgICAnbnVtYmVyJywgIC8vIG91dElkXG4gIF0pO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gbXVsdGlub21pYWwoYX
Jnczoge1xuICBpbnB1dHM6IE11bHRpbm9taWFsSW5wdXRzLFxuICBhdHRyczogTXVsdGlub21pYWxBdHRycyxcbiAgYmFja2VuZDogQmFja2VuZFdhc20sXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHtsb2dpdHN9ID0gaW5wdXRzO1xuICBjb25zdCB7bnVtU2FtcGxlcywgc2VlZCwgbm9ybWFsaXplZH0gPSBhdHRycztcblxuICBpZiAobG9naXRzLmR0eXBlICE9PSAnZmxvYXQzMicpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGBUZW5zb3IgbG9naXRzIG11c3QgaGF2ZSBkdHlwZSBmbG9hdDMyLCBnb3QgJHtsb2dpdHMuZHR5cGV9YCk7XG4gIH1cblxuICBjb25zdCBwcm9iYWJpbGl0aWVzID0gbm9ybWFsaXplZCA/IGxvZ2l0cyA6IHNvZnRtYXgoe1xuICAgIGlucHV0czoge2xvZ2l0c30sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge2RpbTogbG9naXRzLnNoYXBlLmxlbmd0aCAtIDF9LFxuICB9KTtcblxuICBjb25zdCBbYmF0Y2hTaXplLCBudW1FdmVudHNdID0gcHJvYmFiaWxpdGllcy5zaGFwZTtcbiAgY29uc3Qgb3V0ID0gYmFja2VuZC5tYWtlT3V0cHV0KFtiYXRjaFNpemUsIG51bVNhbXBsZXNdLCAnaW50MzInKTtcblxuICB3YXNtTXVsdGlub21pYWwoXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQocHJvYmFiaWxpdGllcy5kYXRhSWQpLmlkLFxuICAgICAgYmF0Y2hTaXplLFxuICAgICAgbnVtRXZlbnRzLFxuICAgICAgbnVtU2FtcGxlcyxcbiAgICAgIHNlZWQsXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0LmRhdGFJZCkuaWQsXG4gICk7XG4gIGlmICghbm9ybWFsaXplZCkge1xuICAgIGJhY2tlbmQuZGlzcG9zZURhdGEocHJvYmFiaWxpdGllcy5kYXRhSWQpO1xuICB9XG4gIHJldHVybiBvdXQ7XG59XG5cbmV4cG9ydCBjb25zdCBtdWx0aW5vbWlhbENvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBNdWx0aW5vbWlhbCxcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogbXVsdGlub21pYWwgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19