UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

120 lines 19.5 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Start End Packer implementation based on `tf.layers.Layer`.
 */
/* Original source: keras-nlp/start_end_packer.py */
import { Tensor, concat, serialization, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core';
import { Layer } from '../../../engine/topology';
import { ValueError } from '../../../errors';

/**
 * Builds a length-1 tensor holding a start/end token.
 *
 * A JS number would otherwise always become a float32 tensor; concatenating
 * that with e.g. an int32 input mixes dtypes. When both the token and the
 * input are numeric, the token therefore takes the input's dtype. String
 * tokens keep their inferred ('string') dtype.
 *
 * @param value The token value (number or string).
 * @param like The rank-1 tensor the token will be concatenated with.
 */
function tokenTensor(value, like) {
    const dtype = typeof value === 'number' && like.dtype !== 'string' ?
        like.dtype :
        undefined;
    return tensor([value], [1], dtype);
}

/**
 * Trims or right-pads a rank-1 tensor so it has exactly `length` entries.
 *
 * @param input Rank-1 tensor to resize.
 * @param length Target number of entries.
 * @param padValue Fill value for appended positions. When null/undefined,
 *   `''` is used for string tensors and `0` otherwise.
 * @returns A rank-1 tensor of shape `[length]`.
 */
function ensureLength(input, length, padValue) {
    if (padValue == null) {
        padValue = input.dtype === 'string' ? '' : 0;
    }
    const size = input.shape[0];
    if (size > length) {
        // Overflow is truncated; `pad` cannot take negative paddings.
        return input.slice(0, length);
    }
    if (typeof padValue === 'number') {
        return input.pad([[0, length - size]], padValue);
    }
    // tf.pad does not allow padding on Tensors with dtype='string', so the
    // values are padded in JS and rebuilt with the input's dtype.
    const values = input.arraySync();
    const pads = Array(length - size).fill(padValue);
    return tensor(values.concat(pads), [length], input.dtype);
}

/**
 * Adds start and end tokens to a sequence and pads to a fixed length.
 *
 * This layer is useful when tokenizing inputs for tasks like translation,
 * where each sequence should include a start and end marker. It should
 * be called after tokenization. The layer will first trim inputs to fit, then
 * add start/end tokens, and finally pad, if necessary, to `sequenceLength`.
 *
 * Input should be either a `tf.Tensor[]` of rank-1 tensors or a dense
 * `tf.Tensor` of rank 1 or rank 2 (a rank-2 tensor is packed row by row).
 */
class StartEndPacker extends Layer {
    /**
     * @param args Layer args plus `sequenceLength` (the output length) and the
     *   optional `startValue`, `endValue` and `padValue` tokens.
     */
    constructor(args) {
        super(args);
        this.sequenceLength = args.sequenceLength;
        this.startValue = args.startValue;
        this.endValue = args.endValue;
        this.padValue = args.padValue;
    }

    /**
     * Packs `inputs`; see `callAndReturnPaddingMask` for the options.
     *
     * @returns The packed tensor only (rank 1 for a rank-1 input, else rank 2).
     */
    call(inputs, kwargs = { addStartValue: true, addEndValue: true }) {
        return this.callAndReturnPaddingMask(inputs, kwargs)[0];
    }

    /**
     * Exactly like `call` except it also returns a boolean mask that is `true`
     * at positions holding packed tokens and `false` at positions filled in
     * with the `padValue`.
     *
     * @param inputs A rank-1 or rank-2 Tensor, or an array of rank-1 Tensors.
     * @param kwargs Optional `sequenceLength` override, plus `addStartValue` and
     *   `addEndValue`; each flag defaults to true when omitted (including when
     *   the layer is invoked through `apply`, which passes `{}`).
     * @returns `[outputs, mask]`, both rank 1 for a rank-1 tensor input and
     *   rank 2 (one row per sequence) otherwise.
     * @throws ValueError if any sequence is not rank 1.
     */
    callAndReturnPaddingMask(inputs, kwargs = { addStartValue: true, addEndValue: true }) {
        return tidy(() => {
            const options = kwargs || {};
            const addStartValue = options.addStartValue == null ? true : options.addStartValue;
            const addEndValue = options.addEndValue == null ? true : options.addEndValue;
            const sequenceLength = options.sequenceLength == null ?
                this.sequenceLength :
                options.sequenceLength;

            // Normalize to a list of rank-1 sequences.
            let x;
            if (inputs instanceof Tensor) {
                x = inputs.rank === 2 ? unstack(inputs) : [inputs];
            } else {
                x = inputs;
            }
            const inputIs1d = inputs instanceof Tensor && inputs.rank === 1;
            if (x.some(t => t.rank !== 1)) {
                throw new ValueError('Input must either be a rank 1 or rank 2 Tensor or an array of rank 1 Tensors.');
            }

            // Concatenate start and end tokens.
            if (addStartValue && this.startValue != null) {
                const startValue = this.startValue;
                x = x.map(t => concat([tokenTensor(startValue, t), t]));
            }
            if (addEndValue && this.endValue != null) {
                const endValue = this.endValue;
                // Trim to leave room for the end token.
                x = x.map(t => {
                    const keep = Math.max(0, Math.min(t.shape[0], sequenceLength - 1));
                    return concat([t.slice(0, keep), tokenTensor(endValue, t)]);
                });
            }

            // Mask: true for every kept token, false for padding. Built directly
            // as a bool tensor, so it works regardless of the input dtype.
            const masks = x.map(t => {
                const filled = Math.min(t.shape[0], sequenceLength);
                const flags = Array.from({ length: sequenceLength }, (_, i) => i < filled);
                return tensor(flags, [sequenceLength], 'bool');
            });
            const padded = x.map(t => ensureLength(t, sequenceLength, this.padValue));

            const outputs = inputIs1d ? padded[0] : stack(padded);
            const mask = inputIs1d ? masks[0] : stack(masks);
            return [outputs, mask];
        });
    }

    getConfig() {
        const config = {
            sequenceLength: this.sequenceLength,
            startValue: this.startValue,
            endValue: this.endValue,
            padValue: this.padValue,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
StartEndPacker.className = 'StartEndPacker';
export { StartEndPacker };
serialization.registerClass(StartEndPacker);
// NOTE(review): the inline source map below predates this fix; regenerate it
// from the updated TypeScript source.
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"start_end_packer.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/preprocessing/start_end_packer.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,oDAAoD;AACpD,OAAO,EAAE,MAAM,EAAsB,MAAM,EAAE,aAAa,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAE/G,OAAO,EAAE,KAAK,EAAa,MAAM,0BAA0B,CAAC;AAC5D,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAiD7C;;;;;;;;;;GAUG;AACH,MAAa,cAAe,SAAQ,KAAK;IASvC,YAAY,IAAwB;QAClC,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;QAClC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;IAChC,CAAC;IAEQ,IAAI,CACX,MAAuB,EACvB,SAA8B,EAAC,aAAa,EAAE,IAAI,EAAE,WAAW,EAAE,IAAI,EAAC;QAEtE,OAAO,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;IAC1D,CAAC;IAED;;;OAGG;IACH,wBAAwB,CACtB,MAAuB,EACvB,SAA8B,EAAC,aAAa,EAAE,IAAI,EAAE,WAAW,EAAE,IAAI,EAAC;QAEtE,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,6CAA6C;YAC7C,IAAI,CAAC,GAAG,MAAM,YAAY,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YAErD,MAAM,SAAS,GAAG,MAAM,YAAY,MAAM,IAAI,MAAM,CAAC,IAAI,KAAK,CAAC,CAAC;YAEhE,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,KAAK,CAAC,CAAC,EAAE;gBAC7B,MAAM,IAAI,UAAU,CAClB,qEAAqE,CACtE,CAAC;aACH;YACD,MAAM,cAAc,GAAG,MAAA,MAAM,CAAC,cAAc,mCAAI,IAAI,CAAC,cAAc,CAAC;YAEpE,oCAAoC;YACpC,IAAI,MAAM,CAAC,aAAa,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;gBACnD,MAAM,kBAAkB,GAAG,MAAM,CAAC,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC
;gBACrD,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,kBAAkB,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;aACjD;YACD,IAAI,MAAM,CAAC,WAAW,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;gBAC/C,MAAM,gBAAgB,GAAG,MAAM,CAAC,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;gBACjD,oCAAoC;gBACpC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE;oBACZ,MAAM,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,cAAc,GAAG,CAAC,CAAC,CAAC,CAAC;oBACpE,MAAM,MAAM,GAAG,MAAM,CAAC,CAAC,MAAM,EAAE,gBAAgB,CAAC,CAAC,CAAC;oBAClD,OAAO,MAAM,CAAC;gBAChB,CAAC,CAAC,CAAC;aACJ;YAED,+DAA+D;YAC/D,SAAS,YAAY,CACnB,KAAa,EAAE,MAAc,EAAE,QAAwB;gBACvD,IAAI,QAAQ,KAAK,SAAS,EAAE;oBAC1B,QAAQ,GAAG,KAAK,CAAC,KAAK,KAAK,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;iBAC9C;gBACD,IAAI,OAAO,QAAQ,KAAK,QAAQ,EAAE;oBAChC,OAAO,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAAC,IAAI,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;iBACxD;gBAED,MAAM,QAAQ,GAAG,KAAK,CAAC,SAAS,EAAyB,CAAC;gBAE1D,IAAI,QAAQ,CAAC,MAAM,IAAI,MAAM,EAAE;oBAC7B,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,GAAG,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;oBAC5D,OAAO,MAAM,CAAC,QAAQ,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC;iBACtC;gBAED,OAAO,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,EAAE,QAAQ,CAAC,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC;YAC7D,CAAC;YAED,MAAM,UAAU,GAAa,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE;gBACrC,gEAAgE;gBAChE,MAAM,IAAI,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC/C,OAAO,YAAY,CAAC,IAAI,EAAE,cAAc,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAC5D,CAAC,CAAC,CAAC;YACH,MAAM,IAAI,GAAG,SAAS,CAAC,CAAC;gBACtB,UAAU,CAAC,CAAC,CAAa;gBACzB,CAAC,CAAC,KAAK,CAAC,UAAU,CAAa,CAAC;YAElC,MAAM,aAAa,GACjB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,YAAY,CAAC,CAAC,EAAE,cAAc,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAC7D,MAAM,OAAO,GAAG,SAAS,CAAC,CAAC;gBACzB,aAAa,CAAC,CAAC,CAAa;gBAC5B,CAAC,CAAC,KAAK,CAAC,aAAa,CAAa,CAAC;YAErC,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,CAAC;QACzB,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,cAAc,EAAE,IAAI,CAAC,cAAc
;YACnC,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;SACxB,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA7GD,kBAAkB;AACF,wBAAS,GAAG,gBAAgB,CAAC;SAFlC,cAAc;AAgH3B,aAAa,CAAC,aAAa,CAAC,cAAc,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Start End Packer implementation based on `tf.layers.Layer`.\n */\n\n/* Original source: keras-nlp/start_end_packer.py */\nimport { Tensor, Tensor1D, Tensor2D, concat, serialization, stack, tensor, tidy } from '@tensorflow/tfjs-core';\n\nimport { Layer, LayerArgs } from '../../../engine/topology';\nimport { ValueError } from '../../../errors';\n\nexport declare interface StartEndPackerArgs extends LayerArgs {\n  /**\n   * Integer. The desired output length.\n   */\n  sequenceLength: number;\n\n  /**\n   * Integer or string. The ID or token that is to be placed at the start of\n   * each sequence. The dtype must match the dtype of the input tensors to the\n   * layer. If undefined, no start value will be added.\n   */\n  startValue?: number|string;\n\n  /**\n   * Integer or string. The ID or token that is to be placed at the end of each\n   * input segment. 
The dtype must match the dtype of the input tensors to the\n   * layer. If undefined, no end value will be added.\n   */\n  endValue?: number|string;\n\n  /**\n   * Integer or string. The ID or token that is to be placed into the unused\n   * positions after the last segment in the sequence. If undefined, 0 or ''\n   * will be added depending on the dtype of the input tensor.\n   */\n  padValue?: number|string;\n}\n\nexport declare interface StartEndPackerOptions {\n  /**\n   * Pass to override the configured `sequenceLength` of the layer.\n   */\n  sequenceLength?: number;\n\n  /**\n   * Pass `false` to not append a start value for this input.\n   * Defaults to true.\n   */\n  addStartValue?: boolean;\n\n  /**\n   * Pass `false` to not append an end value for this input.\n   * Defaults to true.\n   */\n  addEndValue?: boolean;\n}\n\n/**\n * Adds start and end tokens to a sequence and pads to a fixed length.\n *\n *  This layer is useful when tokenizing inputs for tasks like translation,\n *  where each sequence should include a start and end marker. It should\n *  be called after tokenization. 
The layer will first trim inputs to fit, then\n *  add start/end tokens, and finally pad, if necessary, to `sequence_length`.\n *\n *  Input should be either a `tf.Tensor[]` or a dense `tf.Tensor`, and\n *  either rank-1 or rank-2.\n */\nexport class StartEndPacker extends Layer {\n  /** @nocollapse */\n  static readonly className = 'StartEndPacker';\n\n  private sequenceLength: number;\n  private startValue?: number|string;\n  private endValue?: number|string;\n  private padValue?: number|string;\n\n  constructor(args: StartEndPackerArgs) {\n    super(args);\n\n    this.sequenceLength = args.sequenceLength;\n    this.startValue = args.startValue;\n    this.endValue = args.endValue;\n    this.padValue = args.padValue;\n  }\n\n  override call(\n    inputs: Tensor|Tensor[],\n    kwargs: StartEndPackerOptions={addStartValue: true, addEndValue: true}\n  ): Tensor|Tensor2D {\n    return this.callAndReturnPaddingMask(inputs, kwargs)[0];\n  }\n\n  /**\n   * Exactly like `call` except also returns a boolean padding mask of all\n   * locations that are filled in with the `padValue`.\n   */\n  callAndReturnPaddingMask(\n    inputs: Tensor|Tensor[],\n    kwargs: StartEndPackerOptions={addStartValue: true, addEndValue: true}\n  ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D] {\n    return tidy(() => {\n      // Add a new axis at the beginning if needed.\n      let x = inputs instanceof Tensor ? [inputs] : inputs;\n\n      const inputIs1d = inputs instanceof Tensor && inputs.rank === 1;\n\n      if (x.some(t => t.rank !== 1)) {\n        throw new ValueError(\n          'Input must either be a rank 1 Tensor or an array of rank 1 Tensors.'\n        );\n      }\n      const sequenceLength = kwargs.sequenceLength ?? 
this.sequenceLength;\n\n      // Concatenate start and end tokens.\n      if (kwargs.addStartValue && this.startValue != null) {\n        const startTokenIdTensor = tensor([this.startValue]);\n        x = x.map(t => concat([startTokenIdTensor, t]));\n      }\n      if (kwargs.addEndValue && this.endValue != null) {\n        const endTokenIdTensor = tensor([this.endValue]);\n        // Trim to leave room for end token.\n        x = x.map(t => {\n          const sliced = t.slice(0, Math.min(t.shape[0], sequenceLength - 1));\n          const padded = concat([sliced, endTokenIdTensor]);\n          return padded;\n        });\n      }\n\n      // tf.pad does not allow padding on Tensors with dtype='string'\n      function ensureLength(\n        input: Tensor, length: number, padValue?: string|number) {\n        if (padValue === undefined) {\n          padValue = input.dtype === 'string' ? '' : 0;\n        }\n        if (typeof padValue === 'number') {\n          return input.pad([[0, length - input.size]], padValue);\n        }\n\n        const strInput = input.arraySync() as unknown as string[];\n\n        if (strInput.length <= length) {\n          const pads = Array(length - strInput.length).fill(padValue);\n          return tensor(strInput.concat(pads));\n        }\n\n        return tensor(strInput.slice(0, strInput.length - length));\n      }\n\n      const paddedMask: Tensor[] = x.map(t => {\n        // `onesLike` not used since it does not support string tensors.\n        const ones = tensor(Array(t.shape[0]).fill(1));\n        return ensureLength(ones, sequenceLength, 0).cast('bool');\n      });\n      const mask = inputIs1d ?\n        paddedMask[0] as Tensor1D\n        : stack(paddedMask) as Tensor2D;\n\n      const paddedTensors: Tensor[] =\n        x.map(t => ensureLength(t, sequenceLength, this.padValue));\n      const outputs = inputIs1d ?\n        paddedTensors[0] as Tensor1D\n        : stack(paddedTensors) as Tensor2D;\n\n      return [outputs, mask];\n   
 });\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      sequenceLength: this.sequenceLength,\n      startValue: this.startValue,\n      endValue: this.endValue,\n      padValue: this.padValue,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(StartEndPacker);\n"]}