UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

81 lines 11.4 kB
/**
 * @license
 * Copyright 2022 CodeSmith LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
import { Layer } from '../../engine/topology';
import { serialization, tidy } from '@tensorflow/tfjs-core';
import { greater, greaterEqual, max, min } from '@tensorflow/tfjs-core';
import { getExactlyOneShape, getExactlyOneTensor } from '../../utils/types_utils';
import { ValueError } from '../../errors';
import * as K from '../../backend/tfjs_backend';
import * as utils from './preprocessing_utils';
// NOTE(review): this file is compiled output. The inline source map at the end
// of the file embeds the original category_encoding.ts and its mappings, so it
// will not line up with these edits; the same changes belong in the .ts source.
/**
 * Preprocessing layer that encodes integer category ids.
 *
 * Inputs are cast to int32 and validated to lie in `[0, numTokens)`; the
 * encoding itself is delegated to `preprocessing_utils.encodeCategoricalInputs`
 * using `outputMode` and `numTokens`.
 */
class CategoryEncoding extends Layer {
    /**
     * @param args Layer args plus `numTokens` (the size of the encoded axis and
     *     the exclusive upper bound for input ids) and an optional `outputMode`.
     *     A missing or falsy `outputMode` falls back to 'multiHot'.
     */
    constructor(args) {
        super(args);
        this.numTokens = args.numTokens;
        this.outputMode = args.outputMode ? args.outputMode : 'multiHot';
    }
    /**
     * Serializable config: `numTokens` and `outputMode`, merged with the base
     * Layer config. Base-config keys are assigned last, so they win on collision.
     */
    getConfig() {
        const config = {
            'numTokens': this.numTokens,
            'outputMode': this.outputMode,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    /**
     * Output shape for a given input shape.
     *
     * - unknown (null/undefined) or rank-0 input -> `[numTokens]`
     * - 'oneHot' whose last axis is not 1        -> input shape + `[numTokens]`
     * - otherwise                                -> last axis replaced by `numTokens`
     *
     * The shape passed in is never mutated; a new array is returned.
     */
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        if (inputShape == null || inputShape.length === 0) {
            return [this.numTokens];
        }
        // Work on a copy so the caller's shape array is left untouched.
        const outputShape = inputShape.slice();
        if (this.outputMode === 'oneHot' &&
            outputShape[outputShape.length - 1] !== 1) {
            outputShape.push(this.numTokens);
            return outputShape;
        }
        outputShape[outputShape.length - 1] = this.numTokens;
        return outputShape;
    }
    /**
     * Encodes `inputs` after casting them to int32 and checking their range.
     *
     * @param inputs A single tensor (or a one-element array of tensors).
     * @param kwargs Optional; `kwargs['countWeights']` is only accepted when
     *     `outputMode === 'count'`.
     * @throws ValueError if countWeights is given for a non-'count' mode, or if
     *     any input id falls outside `0 <= id < numTokens`.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = getExactlyOneTensor(inputs);
            if (inputs.dtype !== 'int32') {
                inputs = K.cast(inputs, 'int32');
            }
            // Tolerate call() being invoked directly without a kwargs object.
            const rawCountWeights = kwargs == null ? undefined : kwargs['countWeights'];
            let countWeights;
            if (typeof rawCountWeights !== 'undefined') {
                if (this.outputMode !== 'count') {
                    throw new ValueError('countWeights is not used when outputMode !== count.\n' +
                        `Received countWeights=${rawCountWeights}`);
                }
                countWeights = getExactlyOneTensor(rawCountWeights);
            }
            // Every id must satisfy 0 <= id < numTokens. bufferSync() reads
            // the boolean scalars back synchronously.
            const maxValue = max(inputs);
            const minValue = min(inputs);
            const maxBelowNumTokens = greater(this.numTokens, maxValue)
                .bufferSync().get(0);
            const minNonNegative = greaterEqual(minValue, 0).bufferSync().get(0);
            if (!(maxBelowNumTokens && minNonNegative)) {
                throw new ValueError('Input values must be in the range 0 <= values <' +
                    ` numTokens with numTokens=${this.numTokens}`);
            }
            return utils.encodeCategoricalInputs(inputs, this.outputMode, this.numTokens, countWeights);
        });
    }
}
/** @nocollapse */
CategoryEncoding.className = 'CategoryEncoding';
export { CategoryEncoding };
serialization.registerClass(CategoryEncoding);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"category_encoding.js","sourceRoot":"","sources":["../../../../../../../tfjs-layers/src/layers/preprocessing/category_encoding.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,OAAO,EAAa,KAAK,EAAE,MAAM,uBAAuB,CAAC;AACzD,OAAO,EAAE,aAAa,EAAU,IAAI,EAAqB,MAAM,uBAAuB,CAAC;AACvF,OAAO,EAAE,OAAO,EAAE,YAAY,EAAE,GAAG,EAAE,GAAG,EAAC,MAAM,uBAAuB,CAAC;AAEvE,OAAO,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,MAAM,yBAAyB,CAAC;AAElF,OAAO,EAAE,UAAU,EAAE,MAAM,cAAc,CAAC;AAC1C,OAAO,KAAK,CAAC,MAAM,4BAA4B,CAAC;AAChD,OAAO,KAAK,KAAK,MAAM,uBAAuB,CAAC;AAQ/C,MAAa,gBAAiB,SAAQ,KAAK;IAMzC,YAAY,IAA0B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAEhC,IAAG,IAAI,CAAC,UAAU,EAAE;YACpB,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;SACjC;aAAM;YACL,IAAI,CAAC,UAAU,GAAG,UAAU,CAAC;SAC9B;IACH,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,WAAW,EAAE,IAAI,CAAC,SAAS;YAC3B,YAAY,EAAE,IAAI,CAAC,UAAU;SAC9B,CAAC;QAEF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAE5C,IAAG,UAAU,IAAI,IAAI,EAAE;YACrB,OAAO,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;SACzB;QAED,IAAG,IAAI,CAAC,UAAU,KAAK,QAAQ,IAAI,UAAU,CAAC,UAAU
,CAAC,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC,EAAC;YACzE,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;YAChC,OAAO,UAAU,CAAC;SACnB;QAED,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,SAAS,CAAC;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,OAAO,IAAI,CAAC,GAAG,EAAE;YAEb,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;YACrC,IAAG,MAAM,CAAC,KAAK,KAAK,OAAO,EAAE;gBAC3B,MAAM,GAAG,CAAC,CAAC,IAAI,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;aACpC;YAEC,IAAI,YAAiC,CAAC;YAEtC,IAAG,CAAC,OAAO,MAAM,CAAC,cAAc,CAAC,CAAC,KAAK,WAAW,EAAE;gBAElD,IAAG,IAAI,CAAC,UAAU,KAAK,OAAO,EAAE;oBAC9B,MAAM,IAAI,UAAU,CAClB;sCACwB,MAAM,CAAC,cAAc,CAAC,EAAE,CAAC,CAAC;iBACrD;gBAED,YAAY;sBACP,mBAAmB,CAAC,MAAM,CAAC,cAAc,CAAC,CAAsB,CAAC;aACvE;YAED,MAAM,QAAQ,GAAG,GAAG,CAAC,MAAM,CAAC,CAAC;YAC7B,MAAM,QAAQ,GAAG,GAAG,CAAC,MAAM,CAAC,CAAC;YAC7B,MAAM,eAAe,GAAG,OAAO,CAAC,IAAI,CAAC,SAAS,EAAE,QAAQ,CAAC;iBACZ,UAAU,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;YAEjE,MAAM,UAAU,GAAG,YAAY,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;YAEjE,IAAG,CAAC,CAAC,eAAe,IAAI,UAAU,CAAC,EAAE;gBAEnC,MAAM,IAAI,UAAU,CAAC,4CAA4C;sBAC7D,6BAA6B,IAAI,CAAC,SAAS,EAAE,CAAC,CAAC;aACpD;YAED,OAAO,KAAK,CAAC,uBAAuB,CAAC,MAAM,EACzC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,SAAS,EAAE,YAAY,CAAC,CAAC;QACrD,CAAC,CAAC,CAAC;IACL,CAAC;;AAjFD,kBAAkB;AACX,0BAAS,GAAG,kBAAkB,CAAC;SAF3B,gBAAgB;AAqF7B,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2022 CodeSmith LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\nimport { LayerArgs, Layer } from '../../engine/topology';\nimport { serialization, Tensor, tidy, Tensor1D, Tensor2D} from '@tensorflow/tfjs-core';\nimport { greater, greaterEqual, max, min} from '@tensorflow/tfjs-core';\nimport { Shape } from '../../keras_format/common';\nimport { getExactlyOneShape, getExactlyOneTensor } from 
'../../utils/types_utils';\nimport { Kwargs } from '../../types';\nimport { ValueError } from '../../errors';\nimport * as K from '../../backend/tfjs_backend';\nimport * as utils from './preprocessing_utils';\nimport { OutputMode } from './preprocessing_utils';\n\nexport declare interface CategoryEncodingArgs extends LayerArgs {\n  numTokens: number;\n  outputMode?: OutputMode;\n }\n\nexport class CategoryEncoding extends Layer {\n  /** @nocollapse */\n  static className = 'CategoryEncoding';\n  private readonly numTokens: number;\n  private readonly outputMode: OutputMode;\n\n  constructor(args: CategoryEncodingArgs) {\n    super(args);\n    this.numTokens = args.numTokens;\n\n    if(args.outputMode) {\n    this.outputMode = args.outputMode;\n    } else {\n      this.outputMode = 'multiHot';\n    }\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      'numTokens': this.numTokens,\n      'outputMode': this.outputMode,\n    };\n\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    inputShape = getExactlyOneShape(inputShape);\n\n    if(inputShape == null) {\n      return [this.numTokens];\n    }\n\n    if(this.outputMode === 'oneHot' && inputShape[inputShape.length - 1] !== 1){\n      inputShape.push(this.numTokens);\n      return inputShape;\n    }\n\n    inputShape[inputShape.length - 1] = this.numTokens;\n    return inputShape;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor[]|Tensor {\n    return tidy(() => {\n\n        inputs = getExactlyOneTensor(inputs);\n        if(inputs.dtype !== 'int32') {\n          inputs = K.cast(inputs, 'int32');\n      }\n\n        let countWeights: Tensor1D | Tensor2D;\n\n        if((typeof kwargs['countWeights']) !== 'undefined') {\n\n          if(this.outputMode !== 'count') {\n            throw new ValueError(\n  
            `countWeights is not used when outputMode !== count.\n              Received countWeights=${kwargs['countWeights']}`);\n          }\n\n          countWeights\n            =  getExactlyOneTensor(kwargs['countWeights']) as Tensor1D|Tensor2D;\n        }\n\n        const maxValue = max(inputs);\n        const minValue = min(inputs);\n        const greaterEqualMax = greater(this.numTokens, maxValue)\n                                                    .bufferSync().get(0);\n\n        const greaterMin = greaterEqual(minValue, 0).bufferSync().get(0);\n\n        if(!(greaterEqualMax && greaterMin)) {\n\n          throw new ValueError('Input values must be between 0 < values <='\n            + ` numTokens with numTokens=${this.numTokens}`);\n        }\n\n        return utils.encodeCategoricalInputs(inputs,\n          this.outputMode, this.numTokens, countWeights);\n    });\n  }\n}\n\nserialization.registerClass(CategoryEncoding);\n"]}