UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

120 lines 19.5 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 *  Start End Packer implementation based on `tf.layers.Layer`.
 */
/* Original source: keras-nlp/start_end_packer.py */
import { Tensor, concat, serialization, stack, tensor, tidy } from '@tensorflow/tfjs-core';
import { Layer } from '../../../engine/topology';
import { ValueError } from '../../../errors';

/**
 * Returns a rank-1 tensor holding exactly `length` elements taken from
 * `input`: shorter inputs are right-padded with `padValue`, longer inputs are
 * truncated to their first `length` elements.
 *
 * `tf.pad` does not allow padding on Tensors with dtype='string', so string
 * padding is done on the JS array and a new tensor is built from it.
 *
 * @param input Rank-1 tensor to resize.
 * @param length Desired number of elements.
 * @param padValue Fill value. When `null`/`undefined`, `''` is used for
 *     string tensors and `0` otherwise.
 * @returns A tensor with `length` elements.
 */
function ensureLength(input, length, padValue) {
  if (padValue == null) {
    padValue = input.dtype === 'string' ? '' : 0;
  }
  if (typeof padValue === 'number') {
    // A negative padding amount is invalid, so over-long inputs are cut
    // instead of being handed to `pad`.
    if (input.size > length) {
      return input.slice(0, length);
    }
    return input.pad([[0, length - input.size]], padValue);
  }
  const strInput = input.arraySync();
  if (strInput.length <= length) {
    const pads = Array(length - strInput.length).fill(padValue);
    return tensor(strInput.concat(pads));
  }
  // Keep the first `length` entries.
  return tensor(strInput.slice(0, length));
}

/**
 * Adds start and end tokens to a sequence and pads to a fixed length.
 *
 * This layer is useful when tokenizing inputs for tasks like translation,
 * where each sequence should include a start and end marker. It should
 * be called after tokenization. For every sequence the layer prepends the
 * start value (if enabled), trims to `sequenceLength - 1` elements and appends
 * the end value (if enabled), and finally pads or truncates the result to
 * exactly `sequenceLength` elements.
 *
 * Input must be either a single rank-1 `tf.Tensor` or an array of rank-1
 * `tf.Tensor`s; any other rank raises a `ValueError`.
 */
class StartEndPacker extends Layer {
  /**
   * @param args Layer arguments. In addition to whatever the base `Layer`
   *     constructor consumes, the following fields are stored on the instance:
   *     `sequenceLength`, `startValue`, `endValue` and `padValue`.
   */
  constructor(args) {
    super(args);
    this.sequenceLength = args.sequenceLength;
    this.startValue = args.startValue;
    this.endValue = args.endValue;
    this.padValue = args.padValue;
  }

  /**
   * Packs `inputs` and returns only the packed tensor (the padding mask
   * produced by `callAndReturnPaddingMask` is discarded).
   *
   * @param inputs A rank-1 tensor or an array of rank-1 tensors.
   * @param kwargs Per-call options; see `callAndReturnPaddingMask`.
   * @returns The packed tensor.
   */
  call(inputs, kwargs = { addStartValue: true, addEndValue: true }) {
    return this.callAndReturnPaddingMask(inputs, kwargs)[0];
  }

  /**
   * Exactly like `call` except also returns a boolean padding mask that is
   * `true` at positions holding real tokens (including start/end values) and
   * `false` at positions filled in with the `padValue`.
   *
   * @param inputs A rank-1 tensor or an array of rank-1 tensors.
   * @param kwargs Per-call options:
   *     - `sequenceLength`: overrides the layer's configured length when not
   *       `null`/`undefined`.
   *     - `addStartValue`: set to `false` to skip the start value. Treated as
   *       `true` when omitted.
   *     - `addEndValue`: set to `false` to skip the end value. Treated as
   *       `true` when omitted.
   * @returns `[outputs, mask]`. For a single tensor input both are rank-1;
   *     for an array input both are the per-sequence results stacked.
   * @throws ValueError if any input tensor is not rank 1.
   */
  callAndReturnPaddingMask(inputs, kwargs = { addStartValue: true, addEndValue: true }) {
    return tidy(() => {
      const options = kwargs != null ? kwargs : {};

      // Treat a lone tensor as a batch of one; remember to unwrap it later.
      let x = inputs instanceof Tensor ? [inputs] : inputs;
      const inputIs1d = inputs instanceof Tensor && inputs.rank === 1;
      if (x.some(t => t.rank !== 1)) {
        throw new ValueError('Input must either be a rank 1 Tensor or an array of rank 1 Tensors.');
      }

      const sequenceLength = options.sequenceLength != null ?
          options.sequenceLength :
          this.sequenceLength;
      // Each flag defaults to `true` independently, so supplying only some of
      // the options (e.g. just `sequenceLength`) does not silently disable
      // the start/end values.
      const addStartValue =
          options.addStartValue == null ? true : options.addStartValue;
      const addEndValue =
          options.addEndValue == null ? true : options.addEndValue;

      // Concatenate start and end tokens.
      if (addStartValue && this.startValue != null) {
        const startTokenIdTensor = tensor([this.startValue]);
        x = x.map(t => concat([startTokenIdTensor, t]));
      }
      if (addEndValue && this.endValue != null) {
        const endTokenIdTensor = tensor([this.endValue]);
        // Trim to leave room for end token.
        x = x.map(t => {
          const sliced = t.slice(0, Math.min(t.shape[0], sequenceLength - 1));
          return concat([sliced, endTokenIdTensor]);
        });
      }

      const paddedMask = x.map(t => {
        // `onesLike` not used since it does not support string tensors.
        const ones = tensor(Array(t.shape[0]).fill(1));
        return ensureLength(ones, sequenceLength, 0).cast('bool');
      });
      const mask = inputIs1d ? paddedMask[0] : stack(paddedMask);

      const paddedTensors =
          x.map(t => ensureLength(t, sequenceLength, this.padValue));
      const outputs = inputIs1d ? paddedTensors[0] : stack(paddedTensors);

      return [outputs, mask];
    });
  }

  /**
   * Returns the layer configuration: this layer's `sequenceLength`,
   * `startValue`, `endValue` and `padValue` merged with the base layer
   * config (base config entries win on key collisions).
   */
  getConfig() {
    const config = {
      sequenceLength: this.sequenceLength,
      startValue: this.startValue,
      endValue: this.endValue,
      padValue: this.padValue,
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
/** @nocollapse */
StartEndPacker.className = 'StartEndPacker';
export { StartEndPacker };
serialization.registerClass(StartEndPacker);
//# sourceMappingURL=data:application/json;base64,