@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
56 lines • 7.5 kB
JavaScript
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
import { denseBincount, mul } from '@tensorflow/tfjs-core';
import { getExactlyOneTensor } from '../../utils/types_utils';
import { expandDims } from '@tensorflow/tfjs-core';
import { ValueError } from '../../errors';
import * as K from '../../backend/tfjs_backend';
/**
 * Encodes integer category indices into the representation selected by
 * `outputMode`.
 *
 * The input is cast to int32 when it is not already int32. In 'int' mode that
 * cast tensor is returned as-is; every other mode produces bin counts through
 * `denseBincount` with `depth` bins.
 *
 * @param {Tensor|Tensor[]} inputs Exactly one tensor, or an array holding
 *     exactly one tensor (resolved with `getExactlyOneTensor`).
 * @param {'int'|'oneHot'|'multiHot'|'count'|'tfIdf'} outputMode Encoding to
 *     produce. 'oneHot' and 'multiHot' request binary (0/1) bin counts.
 * @param {number} depth Number of bins passed to `denseBincount`.
 * @param {Tensor1D|Tensor2D|TensorLike} [weights] In 'count' mode, optional
 *     per-element weights forwarded to `denseBincount`. In 'tfIdf' mode,
 *     required, and multiplied into the bin counts. Ignored by other modes.
 *     `null` and `undefined` are both treated as "not provided".
 * @returns {Tensor} The int32 input ('int' mode), or the (possibly weighted)
 *     bin counts.
 * @throws {ValueError} If `outputMode` is 'tfIdf' and no weights are given,
 *     or if a non-'int' encoding would have output rank greater than 2.
 */
export function encodeCategoricalInputs(inputs, outputMode, depth, weights) {
  // One nullish check shared by every branch, so that `null` and `undefined`
  // behave the same and falsy-but-valid TensorLike values (e.g. 0) are kept.
  const hasWeights = weights != null;

  let input = getExactlyOneTensor(inputs);

  // tfIdf cannot be computed without weights; reject before any tensor work.
  if (outputMode === 'tfIdf' && !hasWeights) {
    throw new ValueError(`When outputMode is 'tfIdf', weights must be provided.`);
  }

  if (input.dtype !== 'int32') {
    input = K.cast(input, 'int32');
  }
  if (outputMode === 'int') {
    return input;
  }

  // Captured before any reshaping so the rank error reports the caller's shape.
  const originalShape = input.shape;

  // Promote a scalar to rank 1 before binning.
  if (input.rank === 0) {
    input = expandDims(input, -1);
  }
  // oneHot: add a trailing length-1 axis (unless already present) so that each
  // value is binned on its own row.
  if (outputMode === 'oneHot' && input.shape[input.shape.length - 1] !== 1) {
    input = expandDims(input, -1);
  }

  if (input.rank > 2) {
    throw new ValueError(
        `When outputMode is not int, maximum output rank is 2.` +
        ` Received outputMode ${outputMode} and input shape [${originalShape}],` +
        ` which would result in output rank ${input.rank}.`);
  }

  const binaryOutput = outputMode === 'multiHot' || outputMode === 'oneHot';
  // Only 'count' feeds weights into denseBincount; 'tfIdf' applies them below.
  const countWeights = (outputMode === 'count' && hasWeights) ? weights : [];
  const binCounts = denseBincount(input, countWeights, depth, binaryOutput);

  if (outputMode !== 'tfIdf') {
    return binCounts;
  }
  // Presence of weights for tfIdf was already validated above.
  return mul(binCounts, weights);
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoicHJlcHJvY2Vzc2luZ191dGlscy5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uLy4uL3RmanMtbGF5ZXJzL3NyYy9sYXllcnMvcHJlcHJvY2Vzc2luZy9wcmVwcm9jZXNzaW5nX3V0aWxzLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7OztHQVFHO0FBRUgsT0FBTyxFQUFVLGFBQWEsRUFBa0MsR0FBRyxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFDbEcsT0FBTyxFQUFFLG1CQUFtQixFQUFFLE1BQU0seUJBQXlCLENBQUM7QUFDOUQsT0FBTyxFQUFFLFVBQVUsRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBQ2xELE9BQU8sRUFBRSxVQUFVLEVBQUUsTUFBTSxjQUFjLENBQUM7QUFDMUMsT0FBTyxLQUFLLENBQUMsTUFBTSw0QkFBNEIsQ0FBQztBQUloRCxNQUFNLFVBQVUsdUJBQXVCLENBQUMsTUFBdUIsRUFDdkIsVUFBc0IsRUFDdEIsS0FBYSxFQUNiLE9BQXNDO0lBRzVFLElBQUksS0FBSyxHQUFHLG1CQUFtQixDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBRXhDLElBQUcsS0FBSyxDQUFDLEtBQUssS0FBSyxPQUFPLEVBQUU7UUFDMUIsS0FBSyxHQUFHLENBQUMsQ0FBQyxJQUFJLENBQUMsS0FBSyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0tBQzlCO0lBRUgsSUFBRyxVQUFVLEtBQUssS0FBSyxFQUFFO1FBQ3ZCLE9BQU8sS0FBSyxDQUFDO0tBQ2Q7SUFFRCxNQUFNLGFBQWEsR0FBRyxLQUFLLENBQUMsS0FBSyxDQUFDO0lBRWxDLElBQUcsS0FBSyxDQUFDLElBQUksS0FBSyxDQUFDLEVBQUU7UUFDbkIsS0FBSyxHQUFHLFVBQVUsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUMvQjtJQUVELElBQUcsVUFBVSxLQUFLLFFBQVEsRUFBRTtRQUMxQixJQUFHLEtBQUssQ0FBQyxLQUFLLENBQUMsS0FBSyxDQUFDLEtBQUssQ0FBQyxNQUFNLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxFQUFFO1lBQzVDLEtBQUssR0FBRyxVQUFVLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDL0I7S0FDRjtJQUVELElBQUcsS0FBSyxDQUFDLElBQUksR0FBRyxDQUFDLEVBQUU7UUFDakIsTUFBTSxJQUFJLFVBQVUsQ0FBQyxzREFBc0Q7Y0FDekUsd0JBQXdCLFVBQVUsb0JBQW9CLGFBQWEsRUFBRTtjQUNyRSxzQ0FBc0MsS0FBSyxDQUFDLElBQUksR0FBRyxDQUFDLENBQUM7S0FDeEQ7SUFFRCxNQUFNLFlBQVksR0FBRyxDQUFDLFVBQVUsRUFBRSxRQUFRLENBQUMsQ0FBQyxRQUFRLENBQUMsVUFBVSxDQUFDLENBQUM7SUFFakUsTUFBTSxrQkFBa0IsR0FBRyxLQUE0QixDQUFDO0lBRXhELElBQUksU0FBOEIsQ0FBQztJQUVuQyxJQUFJLENBQUMsT0FBTyxPQUFPLENBQUMsS0FBSyxXQUFXLElBQUksVUFBVSxLQUFLLE9BQU8sRUFBRTtRQUM5RCxTQUFTLEdBQUcsYUFBYSxDQUFDLGtCQUFrQixFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsWUFBWSxDQUFDLENBQUM7S0FDNUU7U0FBTTtRQU
NOLFNBQVMsR0FBRyxhQUFhLENBQUMsa0JBQWtCLEVBQUUsRUFBRSxFQUFFLEtBQUssRUFBRSxZQUFZLENBQUMsQ0FBQztLQUN2RTtJQUVGLElBQUcsVUFBVSxLQUFLLE9BQU8sRUFBRTtRQUN6QixPQUFPLFNBQVMsQ0FBQztLQUNsQjtJQUVELElBQUksT0FBTyxFQUFFO1FBQ1gsT0FBTyxHQUFHLENBQUMsU0FBUyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0tBQ2hDO1NBQU07UUFDSCxNQUFNLElBQUksVUFBVSxDQUNsQix1REFBdUQsQ0FDeEQsQ0FBQztLQUNMO0FBQ0gsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIyIENvZGVTbWl0aCBMTENcbiAqXG4gKiBVc2Ugb2YgdGhpcyBzb3VyY2UgY29kZSBpcyBnb3Zlcm5lZCBieSBhbiBNSVQtc3R5bGVcbiAqIGxpY2Vuc2UgdGhhdCBjYW4gYmUgZm91bmQgaW4gdGhlIExJQ0VOU0UgZmlsZSBvciBhdFxuICogaHR0cHM6Ly9vcGVuc291cmNlLm9yZy9saWNlbnNlcy9NSVQuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7IFRlbnNvciwgZGVuc2VCaW5jb3VudCwgVGVuc29yMUQsIFRlbnNvcjJELCBUZW5zb3JMaWtlLCBtdWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5pbXBvcnQgeyBnZXRFeGFjdGx5T25lVGVuc29yIH0gZnJvbSAnLi4vLi4vdXRpbHMvdHlwZXNfdXRpbHMnO1xuaW1wb3J0IHsgZXhwYW5kRGltc30gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcbmltcG9ydCB7IFZhbHVlRXJyb3IgfSBmcm9tICcuLi8uLi9lcnJvcnMnO1xuaW1wb3J0ICogYXMgSyBmcm9tICcuLi8uLi9iYWNrZW5kL3RmanNfYmFja2VuZCc7XG5cbmV4cG9ydCB0eXBlIE91dHB1dE1vZGUgPSAnaW50JyB8ICdvbmVIb3QnIHwgJ211bHRpSG90JyB8ICdjb3VudCcgfCAndGZJZGYnO1xuXG5leHBvcnQgZnVuY3Rpb24gZW5jb2RlQ2F0ZWdvcmljYWxJbnB1dHMoaW5wdXRzOiBUZW5zb3J8VGVuc29yW10sXG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgb3V0cHV0TW9kZTogT3V0cHV0TW9kZSxcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBkZXB0aDogbnVtYmVyLFxuICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHdlaWdodHM/OiBUZW5zb3IxRHxUZW5zb3IyRHxUZW5zb3JMaWtlKTpcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBUZW5zb3J8VGVuc29yW10ge1xuXG4gIGxldCBpbnB1dCA9IGdldEV4YWN0bHlPbmVUZW5zb3IoaW5wdXRzKTtcblxuICBpZihpbnB1dC5kdHlwZSAhPT0gJ2ludDMyJykge1xuICAgIGlucHV0ID0gSy5jYXN0KGlucHV0LCAnaW50MzInKTtcbiAgICB9XG5cbiAgaWYob3V0cHV0TW9kZSA9PT0gJ2ludCcpIHtcbiAgICByZXR1cm4gaW5wdXQ7XG4gIH1cblxuICBjb25zdCBvcmlnaW
5hbFNoYXBlID0gaW5wdXQuc2hhcGU7XG5cbiAgaWYoaW5wdXQucmFuayA9PT0gMCkge1xuICAgIGlucHV0ID0gZXhwYW5kRGltcyhpbnB1dCwgLTEpO1xuICB9XG5cbiAgaWYob3V0cHV0TW9kZSA9PT0gJ29uZUhvdCcpIHtcbiAgICBpZihpbnB1dC5zaGFwZVtpbnB1dC5zaGFwZS5sZW5ndGggLSAxXSAhPT0gMSkge1xuICAgICAgaW5wdXQgPSBleHBhbmREaW1zKGlucHV0LCAtMSk7XG4gICAgfVxuICB9XG5cbiAgaWYoaW5wdXQucmFuayA+IDIpIHtcbiAgICB0aHJvdyBuZXcgVmFsdWVFcnJvcihgV2hlbiBvdXRwdXRNb2RlIGlzIG5vdCBpbnQsIG1heGltdW0gb3V0cHV0IHJhbmsgaXMgMmBcbiAgICArIGAgUmVjZWl2ZWQgb3V0cHV0TW9kZSAke291dHB1dE1vZGV9IGFuZCBpbnB1dCBzaGFwZSAke29yaWdpbmFsU2hhcGV9YFxuICAgICsgYCB3aGljaCB3b3VsZCByZXN1bHQgaW4gb3V0cHV0IHJhbmsgJHtpbnB1dC5yYW5rfS5gKTtcbiAgfVxuXG4gIGNvbnN0IGJpbmFyeU91dHB1dCA9IFsnbXVsdGlIb3QnLCAnb25lSG90J10uaW5jbHVkZXMob3V0cHV0TW9kZSk7XG5cbiAgY29uc3QgZGVuc2VCaW5jb3VudElucHV0ID0gaW5wdXQgYXMgVGVuc29yMUQgfCBUZW5zb3IyRDtcblxuICBsZXQgYmluQ291bnRzOiBUZW5zb3IxRCB8IFRlbnNvcjJEO1xuXG4gIGlmICgodHlwZW9mIHdlaWdodHMpICE9PSAndW5kZWZpbmVkJyAmJiBvdXRwdXRNb2RlID09PSAnY291bnQnKSB7XG4gICAgYmluQ291bnRzID0gZGVuc2VCaW5jb3VudChkZW5zZUJpbmNvdW50SW5wdXQsIHdlaWdodHMsIGRlcHRoLCBiaW5hcnlPdXRwdXQpO1xuICAgfSBlbHNlIHtcbiAgICBiaW5Db3VudHMgPSBkZW5zZUJpbmNvdW50KGRlbnNlQmluY291bnRJbnB1dCwgW10sIGRlcHRoLCBiaW5hcnlPdXRwdXQpO1xuICAgfVxuXG4gIGlmKG91dHB1dE1vZGUgIT09ICd0ZklkZicpIHtcbiAgICByZXR1cm4gYmluQ291bnRzO1xuICB9XG5cbiAgaWYgKHdlaWdodHMpIHtcbiAgICByZXR1cm4gbXVsKGJpbkNvdW50cywgd2VpZ2h0cyk7XG4gIH0gZWxzZSB7XG4gICAgICB0aHJvdyBuZXcgVmFsdWVFcnJvcihcbiAgICAgICAgYFdoZW4gb3V0cHV0TW9kZSBpcyAndGZJZGYnLCB3ZWlnaHRzIG11c3QgYmUgcHJvdmlkZWQuYFxuICAgICAgKTtcbiAgfVxufVxuIl19