UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

81 lines 11.4 kB
/**
 * @license
 * Copyright 2022 CodeSmith LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
import { Layer } from '../../engine/topology';
import { serialization, tidy } from '@tensorflow/tfjs-core';
import { greater, greaterEqual, max, min } from '@tensorflow/tfjs-core';
import { getExactlyOneShape, getExactlyOneTensor } from '../../utils/types_utils';
import { ValueError } from '../../errors';
import * as K from '../../backend/tfjs_backend';
import * as utils from './preprocessing_utils';

/**
 * Preprocessing layer that encodes integer category ids into a representation
 * whose last dimension has length `numTokens`.
 *
 * The actual encoding is delegated to `utils.encodeCategoricalInputs` with the
 * configured `outputMode`. This class special-cases 'oneHot' (shape inference)
 * and 'count' (the only mode that accepts `countWeights`), and defaults to
 * 'multiHot' when no mode is given.
 */
class CategoryEncoding extends Layer {
  /**
   * @param {Object} args Layer arguments, forwarded to `Layer`.
   * @param {number} args.numTokens Number of categories; valid inputs are
   *     integers in [0, numTokens).
   * @param {string} [args.outputMode] Encoding mode; falls back to
   *     'multiHot' when falsy.
   */
  constructor(args) {
    super(args);
    this.numTokens = args.numTokens;
    this.outputMode = args.outputMode ? args.outputMode : 'multiHot';
  }

  /**
   * Serializable config: this layer's own fields merged with the base
   * `Layer` config (base keys overwrite on collision, as `Object.assign`
   * copies them last).
   */
  getConfig() {
    const config = {
      'numTokens': this.numTokens,
      'outputMode': this.outputMode,
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  /**
   * Infers the output shape for a given input shape.
   *
   * - Unknown (null) or rank-0 input -> `[numTokens]`.
   * - 'oneHot' whose last dim is not 1 -> a new trailing `numTokens` axis.
   * - Otherwise -> the last dim is replaced by `numTokens`.
   *
   * The incoming shape is never modified; a fresh array is returned.
   */
  computeOutputShape(inputShape) {
    const shape = getExactlyOneShape(inputShape);
    if (shape == null || shape.length === 0) {
      return [this.numTokens];
    }
    // getExactlyOneShape may return the caller's own array (for example a
    // symbolic tensor's shape), so build the result on a copy instead of
    // mutating the argument in place.
    const outputShape = shape.slice();
    const lastIndex = outputShape.length - 1;
    if (this.outputMode === 'oneHot' && outputShape[lastIndex] !== 1) {
      outputShape.push(this.numTokens);
      return outputShape;
    }
    outputShape[lastIndex] = this.numTokens;
    return outputShape;
  }

  /**
   * Encodes the input category ids.
   *
   * @param inputs A tensor (or single-element list of tensors) of category
   *     ids; cast to int32 when it has another dtype.
   * @param kwargs Optional keyword args. `kwargs.countWeights` is only
   *     accepted when `outputMode === 'count'`.
   * @returns The result of `utils.encodeCategoricalInputs`.
   * @throws {ValueError} If `countWeights` is given for a non-'count' mode,
   *     or if any input value falls outside [0, numTokens).
   */
  call(inputs, kwargs) {
    return tidy(() => {
      inputs = getExactlyOneTensor(inputs);
      if (inputs.dtype !== 'int32') {
        inputs = K.cast(inputs, 'int32');
      }

      // kwargs can be absent when call() is invoked directly rather than via
      // Layer.apply(); guard before indexing into it.
      const rawCountWeights =
          kwargs == null ? undefined : kwargs['countWeights'];
      let countWeights;
      if (typeof rawCountWeights !== 'undefined') {
        if (this.outputMode !== 'count') {
          throw new ValueError(
              'countWeights is not used when outputMode !== count. \n' +
              `Received countWeights=${rawCountWeights}`);
        }
        countWeights = getExactlyOneTensor(rawCountWeights);
      }

      // Valid category ids satisfy 0 <= value < numTokens.
      const maxValue = max(inputs);
      const minValue = min(inputs);
      const maxBelowNumTokens =
          greater(this.numTokens, maxValue).bufferSync().get(0);
      const minNonNegative = greaterEqual(minValue, 0).bufferSync().get(0);
      if (!(maxBelowNumTokens && minNonNegative)) {
        throw new ValueError('Input values must be in the range 0 <= values <' +
            ` numTokens with numTokens=${this.numTokens}`);
      }

      return utils.encodeCategoricalInputs(
          inputs, this.outputMode, this.numTokens, countWeights);
    });
  }
}
/** @nocollapse */
CategoryEncoding.className = 'CategoryEncoding';
export { CategoryEncoding };
serialization.registerClass(CategoryEncoding);