UNPKG

@tensorflow/tfjs-converter

Version:

Tensorflow model converter for javascript

242 lines 33.9 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { concat, keep, reshape, scalar, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core';
import { assertShapesMatchAllowUndefinedSize } from './tensor_utils';

/**
 * The TensorArray object keeps an array of Tensors. It
 * allows reading from the array and writing to the array.
 *
 * Each slot holds a small state record `{tensor, written, read, cleared}`.
 * A slot may be written at most once, and may not be written after it has
 * been read.
 */
export class TensorArray {
  /**
   * @param name Name used in error messages.
   * @param dtype Dtype every written tensor must have.
   * @param maxSize Exclusive upper bound for write indices; only enforced
   *     when `dynamicSize` is false.
   * @param elementShape Shape every written tensor is checked against. When
   *     it is null or empty, the first write into an empty array sets it.
   * @param identicalElementShapes Stored on the instance; not consulted by
   *     any method of this class.
   * @param dynamicSize When true, writes beyond `maxSize` are allowed.
   * @param clearAfterRead When true, a slot is marked cleared on its first
   *     read and any later read of it throws.
   */
  constructor(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
    this.name = name;
    this.dtype = dtype;
    this.maxSize = maxSize;
    this.elementShape = elementShape;
    this.identicalElementShapes = identicalElementShapes;
    this.dynamicSize = dynamicSize;
    this.clearAfterRead = clearAfterRead;
    this.tensors = [];
    this.closed_ = false;
    // A kept scalar whose tensor id doubles as this TensorArray's id.
    this.idTensor = scalar(0);
    keep(this.idTensor);
  }

  /** Id of this TensorArray (the id of its `idTensor`). */
  get id() {
    return this.idTensor.id;
  }

  /** Whether `clearAndClose` has been called. */
  get closed() {
    return this.closed_;
  }

  /**
   * Dispose the tensors and idTensor and mark the TensorArray as closed.
   * @param [keepIds] Optional Set of tensor ids that must not be disposed.
   */
  clearAndClose(keepIds) {
    // `forEach` skips never-written holes, so every visited entry has a
    // tensor.
    this.tensors.forEach(entry => {
      if (keepIds == null || !keepIds.has(entry.tensor.id)) {
        entry.tensor.dispose();
      }
    });
    this.tensors = [];
    this.closed_ = true;
    this.idTensor.dispose();
  }

  /**
   * Number of slots currently in the array (length of the backing array,
   * i.e. highest written index + 1), not `maxSize`.
   */
  size() {
    return this.tensors.length;
  }

  /**
   * Read the value at location index in the TensorArray.
   * @param index Number the index to read from.
   * @returns The tensor stored at `index`.
   * @throws Error if the array is closed, the index is out of range, the
   *     slot was never written, or the slot was cleared by a previous read.
   */
  read(index) {
    if (this.closed_) {
      throw new Error(`TensorArray ${this.name} has already been closed.`);
    }
    if (index < 0 || index >= this.size()) {
      throw new Error(`Tried to read from index ${index}, but array size is: ${this.size()}`);
    }
    const tensorWithState = this.tensors[index];
    // An in-range index can still be a hole when a higher index was written
    // first; report that explicitly instead of failing on `undefined`.
    if (tensorWithState == null) {
      throw new Error(`TensorArray ${this.name}: Could not read index ${index}, because it has not been written.`);
    }
    if (tensorWithState.cleared) {
      throw new Error(`TensorArray ${this.name}: Could not read index ${index} twice because it was cleared after a previous read ` +
          `(perhaps try setting clear_after_read = false?).`);
    }
    // Only the flag is set here; the tensor itself is not disposed.
    if (this.clearAfterRead) {
      tensorWithState.cleared = true;
    }
    tensorWithState.read = true;
    return tensorWithState.tensor;
  }

  /**
   * Helper method to read multiple tensors from the specified indices.
   * @param indices number[] the indices to read, in order.
   * @returns The tensors at those indices, in the same order.
   */
  readMany(indices) {
    return indices.map(index => this.read(index));
  }

  /**
   * Write value into the index of the TensorArray.
   * @param index number the index to write to.
   * @param tensor Tensor to store; it is passed to `keep`.
   * @throws Error if the array is closed, the index is out of range, the
   *     dtype or shape does not match, or the slot was already read or
   *     written.
   */
  write(index, tensor) {
    if (this.closed_) {
      throw new Error(`TensorArray ${this.name} has already been closed.`);
    }
    if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
      throw new Error(`Tried to write to index ${index}, but array is not resizeable and size is: ${this.maxSize}`);
    }
    const t = this.tensors[index] || {};
    if (tensor.dtype !== this.dtype) {
      throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because the value dtype is ${tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
    }
    // Set the shape on the first write to a tensor array of unknown shape.
    if (this.size() === 0 && (this.elementShape == null || this.elementShape.length === 0)) {
      this.elementShape = tensor.shape;
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`);
    if (t.read) {
      throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`);
    }
    if (t.written) {
      throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been written.`);
    }
    t.tensor = tensor;
    keep(tensor);
    t.written = true;
    this.tensors[index] = t;
  }

  /**
   * Helper method to write multiple tensors to the specified indices.
   * @param indices number[] target indices.
   * @param tensors Tensor[] values; `tensors[k]` is written to `indices[k]`.
   * @throws Error if the two arrays differ in length.
   */
  writeMany(indices, tensors) {
    if (indices.length !== tensors.length) {
      throw new Error(`TensorArray ${this.name}: could not write multiple tensors, ` +
          `because the index size: ${indices.length} is not the same as tensors size: ${tensors.length}.`);
    }
    indices.forEach((i, index) => this.write(i, tensors[index]));
  }

  /**
   * Return selected values in the TensorArray as a packed Tensor. All of
   * selected values must have been written and their shapes must all match.
   * @param [indices] number[] Optional. Taking values in [0, max_value). If the
   *     TensorArray is not dynamic, max_value=size(). If not specified returns
   *     all tensors in the original order. A provided list is truncated to
   *     its first `size()` entries.
   * @param [dtype] Optional; if given it must equal the TensorArray dtype.
   * @returns The selected tensors stacked along a new axis 0, or an empty
   *     tensor of the TensorArray dtype when nothing is selected.
   */
  gather(indices, dtype) {
    if (!!dtype && dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${this.dtype} but gather requested dtype ${dtype}`);
    }
    if (!indices) {
      indices = [];
      for (let i = 0; i < this.size(); i++) {
        indices.push(i);
      }
    } else {
      indices = indices.slice(0, this.size());
    }
    if (indices.length === 0) {
      // Keep the declared dtype, and tolerate a still-unknown element shape.
      const elementShape = this.elementShape == null ? [] : this.elementShape;
      return tensor([], [0].concat(elementShape), this.dtype);
    }
    // Read all the selected tensors; this also updates their read/cleared
    // state.
    const tensors = this.readMany(indices);
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');
    return stack(tensors, 0);
  }

  /**
   * Return the values in the TensorArray as a concatenated Tensor.
   * @param [dtype] Optional; if given it must equal the TensorArray dtype.
   * @returns All tensors concatenated along axis 0, or an empty tensor of
   *     the TensorArray dtype when the array is empty.
   */
  concat(dtype) {
    if (!!dtype && dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${this.dtype} but concat requested dtype ${dtype}`);
    }
    if (this.size() === 0) {
      // Keep the declared dtype, and tolerate a still-unknown element shape.
      const elementShape = this.elementShape == null ? [] : this.elementShape;
      return tensor([], [0].concat(elementShape), this.dtype);
    }
    const indices = [];
    for (let i = 0; i < this.size(); i++) {
      indices.push(i);
    }
    // Collect all the tensors from the tensors array.
    const tensors = this.readMany(indices);
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, `TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`);
    return concat(tensors, 0);
  }

  /**
   * Scatter the values of a Tensor in specific indices of a TensorArray.
   * @param indices number[] values in [0, max_value). If the
   *     TensorArray is not dynamic, max_value=size().
   * @param tensor Tensor input tensor; it is unstacked along axis 0 and
   *     slice k is written to `indices[k]`.
   * @throws Error on dtype mismatch, when `indices.length` differs from
   *     `tensor.shape[0]`, or when an index exceeds a fixed-size array.
   */
  scatter(indices, tensor) {
    if (tensor.dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
    }
    if (indices.length !== tensor.shape[0]) {
      throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
    }
    const maxIndex = Math.max(...indices);
    if (!this.dynamicSize && maxIndex >= this.maxSize) {
      throw new Error(`Max index must be < array size (${maxIndex} vs. ${this.maxSize})`);
    }
    this.writeMany(indices, unstack(tensor, 0));
  }

  /**
   * Split the values of a Tensor into the TensorArray.
   * @param length number[] with the lengths to use when splitting value along
   *     its first dimension.
   * @param tensor Tensor, the tensor to split. Piece i is written to index i.
   * @throws Error on dtype mismatch, when the lengths do not sum to
   *     `tensor.shape[0]`, or when a fixed-size array's `maxSize` differs
   *     from `length.length`.
   */
  split(length, tensor) {
    if (tensor.dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
    }
    let totalLength = 0;
    // cumulativeLengths[i] is the end offset (exclusive) of piece i.
    const cumulativeLengths = length.map(len => {
      totalLength += len;
      return totalLength;
    });
    if (totalLength !== tensor.shape[0]) {
      throw new Error(`Expected sum of lengths to be equal to tensor.shape[0], but sum of lengths is ${totalLength}, and tensor's shape is: ${tensor.shape}`);
    }
    if (!this.dynamicSize && length.length !== this.maxSize) {
      throw new Error(`TensorArray's size is not equal to the size of lengths (${this.maxSize} vs. ${length.length}), ` +
          'and the TensorArray is not marked as dynamically resizeable');
    }
    const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
    const tensors = [];
    // Intermediates are created inside tidy; the pieces are returned from
    // the tidy callback and then written (and kept) by writeMany below.
    tidy(() => {
      // Flatten to [1, rows, elementsPerRow] so each piece is one slice.
      const flattened = reshape(tensor, [1, totalLength, elementPerRow]);
      for (let i = 0; i < length.length; ++i) {
        const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
        const begin = [0, previousLength, 0];
        const sizes = [1, length[i], elementPerRow];
        tensors[i] = reshape(slice(flattened, begin, sizes), this.elementShape);
      }
      return tensors;
    });
    const indices = [];
    for (let i = 0; i < length.length; i++) {
      indices[i] = i;
    }
    this.writeMany(indices, tensors);
  }
}
,IAAI,KAAK,CAAC,eAAe,IAAI,CAAC,IAAI,2BAA2B,CAAC,CAAC;SACtE;QAED,IAAI,KAAK,GAAG,CAAC,IAAI,KAAK,IAAI,IAAI,CAAC,IAAI,EAAE,EAAE;YACrC,MAAM,IAAI,KAAK,CAAC,4BAA4B,KAAK,wBAC7C,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC;SACpB;QAED,MAAM,eAAe,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;QAC5C,IAAI,eAAe,CAAC,OAAO,EAAE;YAC3B,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,0BACpB,KAAK,sDAAsD;gBAC/D,kDAAkD,CAAC,CAAC;SACzD;QAED,IAAI,IAAI,CAAC,cAAc,EAAE;YACvB,eAAe,CAAC,OAAO,GAAG,IAAI,CAAC;SAChC;QAED,eAAe,CAAC,IAAI,GAAG,IAAI,CAAC;QAC5B,OAAO,eAAe,CAAC,MAAM,CAAC;IAChC,CAAC;IAED;;OAEG;IACH,QAAQ,CAAC,OAAiB;QACxB,OAAO,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;IAChD,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,KAAa,EAAE,MAAc;QACjC,IAAI,IAAI,CAAC,OAAO,EAAE;YAChB,MAAM,IAAI,KAAK,CAAC,eAAe,IAAI,CAAC,IAAI,2BAA2B,CAAC,CAAC;SACtE;QAED,IAAI,KAAK,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,WAAW,IAAI,KAAK,IAAI,IAAI,CAAC,OAAO,EAAE;YAC3D,MAAM,IAAI,KAAK,CAAC,2BACZ,KAAK,8CAA8C,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;SACxE;QAED,MAAM,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,EAAE,CAAC;QAEpC,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,eACZ,IAAI,CAAC,IAAI,0CAA0C,KAAK;uCAExD,MAAM,CAAC,KAAK,8BAA8B,IAAI,CAAC,KAAK,GAAG,CAAC,CAAC;SAC9D;QAED,sEAAsE;QACtE,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC;YACjB,CAAC,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,IAAI,CAAC,YAAY,CAAC,MAAM,KAAK,CAAC,CAAC,EAAE;YACjE,IAAI,CAAC,YAAY,GAAG,MAAM,CAAC,KAAK,CAAC;SAClC;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,MAAM,CAAC,KAAK,EAC/B,eAAe,IAAI,CAAC,IAAI,0CACpB,KAAK,GAAG,CAAC,CAAC;QAElB,IAAI,CAAC,CAAC,IAAI,EAAE;YACV,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,0CACpB,KAAK,qCAAqC,CAAC,CAAC;SACrD;QAED,IAAI,CAAC,CAAC,OAAO,EAAE;YACb,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,0CACpB,KAAK,wCAAwC,CAAC,CAAC;SACxD;QAED,CAAC,CAAC,MAAM,GAAG,MAAM,CAAC;QAClB,IAAI,CAAC,MAAM,CAAC,CAAC;QACb,CAAC,CAAC,OAAO,GAAG,IAAI,CAAC;QAEjB,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;IAC1B,CAAC;IAED;;OAEG;IACH,SAAS,CAAC,OAAiB,EAAE,OAAiB;QAC5C,IAAI,OAAO,CAAC,MAAM,KAAK,OAAO,CAAC,MAAM,EAAE;YACr
C,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,qCAAqC;gBAC7D,2BACI,OAAO,CAAC,MAAM,qCACd,OAAO,CAAC,MAAM,GAAG,CAAC,CAAC;SAC5B;QAED,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,KAAK,EAAE,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC/D,CAAC;IAED;;;;;;;OAOG;IACH,MAAM,CAAC,OAAkB,EAAE,KAAgB;QACzC,IAAI,CAAC,CAAC,KAAK,IAAI,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,+BAA+B,KAAK,EAAE,CAAC,CAAC;SACvD;QAED,IAAI,CAAC,OAAO,EAAE;YACZ,OAAO,GAAG,EAAE,CAAC;YACb,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE;gBACpC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aACjB;SACF;aAAM;YACL,OAAO,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;SACzC;QAED,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC;SAClD;QAED,gEAAgE;QAChE,gBAAgB;QAChB,MAAM,OAAO,GAAG,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC,CAAC;QAEvC,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,EAAE,8BAA8B,CAAC,CAAC;QAEzE,OAAO,KAAK,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;IAC3B,CAAC;IAED;;OAEG;IACH,MAAM,CAAC,KAAgB;QACrB,IAAI,CAAC,CAAC,KAAK,IAAI,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,+BAA+B,KAAK,EAAE,CAAC,CAAC;SACvD;QAED,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC,EAAE;YACrB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC;SAClD;QAED,MAAM,OAAO,GAAG,EAAE,CAAC;QACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE;YACpC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACjB;QACD,kDAAkD;QAClD,MAAM,OAAO,GAAG,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC,CAAC;QAEvC,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,EACnC,mDACI,IAAI,CAAC,YAAY,4BAA4B,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,GAAG,CAAC,CAAC;QAE1E,OAAO,MAAM,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;IAC5B,CAAC;IAED;;;;;OAKG;IACH,OAAO,CAAC,OAAiB,EAAE,MAAc;QACvC,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,yBAAyB,MAAM,CAAC,KAAK,
EAAE,CAAC,CAAC;SACxD;QAED,IAAI,OAAO,CAAC,MAAM,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,MAAM,QAAQ,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;SAC9C;QAED,MAAM,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,OAAO,CAAC,CAAC;QAEtC,IAAI,CAAC,IAAI,CAAC,WAAW,IAAI,QAAQ,IAAI,IAAI,CAAC,OAAO,EAAE;YACjD,MAAM,IAAI,KAAK,CACX,mCAAmC,QAAQ,SAAS,IAAI,CAAC,OAAO,GAAG,CAAC,CAAC;SAC1E;QAED,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC;IAC9C,CAAC;IAED;;;;;OAKG;IACH,KAAK,CAAC,MAAgB,EAAE,MAAc;QACpC,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,yBAAyB,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;SACxD;QACD,IAAI,WAAW,GAAG,CAAC,CAAC;QACpB,MAAM,iBAAiB,GAAG,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;YACzC,WAAW,IAAI,GAAG,CAAC;YACnB,OAAO,WAAW,CAAC;QACrB,CAAC,CAAC,CAAC;QAEH,IAAI,WAAW,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC;;UAEZ,WAAW,4BAA4B,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;SAC5D;QAED,IAAI,CAAC,IAAI,CAAC,WAAW,IAAI,MAAM,CAAC,MAAM,KAAK,IAAI,CAAC,OAAO,EAAE;YACvD,MAAM,IAAI,KAAK,CACX,2DACI,IAAI,CAAC,OAAO,QAAQ,MAAM,CAAC,MAAM,KAAK;gBAC1C,6DAA6D,CAAC,CAAC;SACpE;QAED,MAAM,aAAa,GAAG,WAAW,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,GAAG,WAAW,CAAC;QACxE,MAAM,OAAO,GAAa,EAAE,CAAC;QAC7B,IAAI,CAAC,GAAG,EAAE;YACR,MAAM,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC,CAAC;YAC1D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBACtC,MAAM,cAAc,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,iBAAiB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBAChE,MAAM,OAAO,GAAG,CAAC,CAAC,EAAE,cAAc,EAAE,CAAC,CAAC,CAAC;gBACvC,MAAM,KAAK,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC;gBAC5C,OAAO,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,KAAK,CAAC,MAAM,EAAE,OAAO,EAAE,KAAK,CAAC,EAAE,IAAI,CAAC,YAAY,CAAC,CAAC;aACxE;YACD,OAAO,OAAO,CAAC;QACjB,CAAC,CAAC,CAAC;QACH,MAAM,OAAO,GAAG,EAAE,CAAC;QACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YACtC,OAAO,CAAC,CAAC
,CAAC,GAAG,CAAC,CAAC;SAChB;QACD,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;IACnC,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';\n\nimport {assertShapesMatchAllowUndefinedSize} from './tensor_utils';\n\nexport interface TensorWithState {\n  tensor?: Tensor;\n  written?: boolean;\n  read?: boolean;\n  cleared?: boolean;\n}\n/**\n * The TensorArray object keeps an array of Tensors.  
It\n * allows reading from the array and writing to the array.\n */\nexport class TensorArray {\n  private tensors: TensorWithState[] = [];\n  private closed_ = false;\n  readonly idTensor: Tensor;\n  constructor(\n      readonly name: string, readonly dtype: DataType, private maxSize: number,\n      private elementShape: number[], readonly identicalElementShapes: boolean,\n      readonly dynamicSize: boolean, readonly clearAfterRead: boolean) {\n    this.idTensor = scalar(0);\n    keep(this.idTensor);\n  }\n\n  get id() {\n    return this.idTensor.id;\n  }\n\n  get closed() {\n    return this.closed_;\n  }\n\n  /**\n   * Dispose the tensors and idTensor and mark the TensoryArray as closed.\n   */\n  clearAndClose(keepIds?: Set<number>) {\n    this.tensors.forEach(tensor => {\n      if (keepIds == null || !keepIds.has(tensor.tensor.id)) {\n        tensor.tensor.dispose();\n      }\n    });\n    this.tensors = [];\n    this.closed_ = true;\n    this.idTensor.dispose();\n  }\n\n  size(): number {\n    return this.tensors.length;\n  }\n\n  /**\n   * Read the value at location index in the TensorArray.\n   * @param index Number the index to read from.\n   */\n  read(index: number): Tensor {\n    if (this.closed_) {\n      throw new Error(`TensorArray ${this.name} has already been closed.`);\n    }\n\n    if (index < 0 || index >= this.size()) {\n      throw new Error(`Tried to read from index ${index}, but array size is: ${\n          this.size()}`);\n    }\n\n    const tensorWithState = this.tensors[index];\n    if (tensorWithState.cleared) {\n      throw new Error(\n          `TensorArray ${this.name}: Could not read index ${\n              index} twice because it was cleared after a previous read ` +\n          `(perhaps try setting clear_after_read = false?).`);\n    }\n\n    if (this.clearAfterRead) {\n      tensorWithState.cleared = true;\n    }\n\n    tensorWithState.read = true;\n    return tensorWithState.tensor;\n  }\n\n  /**\n   * Helper method to read 
multiple tensors from the specified indices.\n   */\n  readMany(indices: number[]): Tensor[] {\n    return indices.map(index => this.read(index));\n  }\n\n  /**\n   * Write value into the index of the TensorArray.\n   * @param index number the index to write to.\n   * @param tensor\n   */\n  write(index: number, tensor: Tensor) {\n    if (this.closed_) {\n      throw new Error(`TensorArray ${this.name} has already been closed.`);\n    }\n\n    if (index < 0 || !this.dynamicSize && index >= this.maxSize) {\n      throw new Error(`Tried to write to index ${\n          index}, but array is not resizeable and size is: ${this.maxSize}`);\n    }\n\n    const t = this.tensors[index] || {};\n\n    if (tensor.dtype !== this.dtype) {\n      throw new Error(`TensorArray ${\n          this.name}: Could not write to TensorArray index ${index},\n          because the value dtype is ${\n          tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);\n    }\n\n    // Set the shape for the first time write to unknow shape tensor array\n    if (this.size() === 0 &&\n        (this.elementShape == null || this.elementShape.length === 0)) {\n      this.elementShape = tensor.shape;\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensor.shape,\n        `TensorArray ${this.name}: Could not write to TensorArray index ${\n            index}.`);\n\n    if (t.read) {\n      throw new Error(\n          `TensorArray ${this.name}: Could not write to TensorArray index ${\n              index}, because it has already been read.`);\n    }\n\n    if (t.written) {\n      throw new Error(\n          `TensorArray ${this.name}: Could not write to TensorArray index ${\n              index}, because it has already been written.`);\n    }\n\n    t.tensor = tensor;\n    keep(tensor);\n    t.written = true;\n\n    this.tensors[index] = t;\n  }\n\n  /**\n   * Helper method to write multiple tensors to the specified indices.\n   */\n  writeMany(indices: number[], tensors: 
Tensor[]) {\n    if (indices.length !== tensors.length) {\n      throw new Error(\n          `TensorArray ${this.name}: could not write multiple tensors,` +\n          `because the index size: ${\n              indices.length} is not the same as tensors size: ${\n              tensors.length}.`);\n    }\n\n    indices.forEach((i, index) => this.write(i, tensors[index]));\n  }\n\n  /**\n   * Return selected values in the TensorArray as a packed Tensor. All of\n   * selected values must have been written and their shapes must all match.\n   * @param [indices] number[] Optional. Taking values in [0, max_value). If the\n   *    TensorArray is not dynamic, max_value=size(). If not specified returns\n   *    all tensors in the original order.\n   * @param [dtype]\n   */\n  gather(indices?: number[], dtype?: DataType): Tensor {\n    if (!!dtype && dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but gather requested dtype ${dtype}`);\n    }\n\n    if (!indices) {\n      indices = [];\n      for (let i = 0; i < this.size(); i++) {\n        indices.push(i);\n      }\n    } else {\n      indices = indices.slice(0, this.size());\n    }\n\n    if (indices.length === 0) {\n      return tensor([], [0].concat(this.elementShape));\n    }\n\n    // Read all the PersistentTensors into a vector to keep track of\n    // their memory.\n    const tensors = this.readMany(indices);\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');\n\n    return stack(tensors, 0);\n  }\n\n  /**\n   * Return the values in the TensorArray as a concatenated Tensor.\n   */\n  concat(dtype?: DataType): Tensor {\n    if (!!dtype && dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but concat requested dtype ${dtype}`);\n    }\n\n    if (this.size() === 0) {\n      return tensor([], [0].concat(this.elementShape));\n    }\n\n    const indices = [];\n    
for (let i = 0; i < this.size(); i++) {\n      indices.push(i);\n    }\n    // Collect all the tensors from the tensors array.\n    const tensors = this.readMany(indices);\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensors[0].shape,\n        `TensorArray shape mismatch: tensor array shape (${\n            this.elementShape}) vs first tensor shape (${tensors[0].shape})`);\n\n    return concat(tensors, 0);\n  }\n\n  /**\n   * Scatter the values of a Tensor in specific indices of a TensorArray.\n   * @param indices number[] values in [0, max_value). If the\n   *    TensorArray is not dynamic, max_value=size().\n   * @param tensor Tensor input tensor.\n   */\n  scatter(indices: number[], tensor: Tensor) {\n    if (tensor.dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but tensor has dtype ${tensor.dtype}`);\n    }\n\n    if (indices.length !== tensor.shape[0]) {\n      throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${\n          indices.length} vs. ${tensor.shape[0]}`);\n    }\n\n    const maxIndex = Math.max(...indices);\n\n    if (!this.dynamicSize && maxIndex >= this.maxSize) {\n      throw new Error(\n          `Max index must be < array size (${maxIndex}  vs. 
${this.maxSize})`);\n    }\n\n    this.writeMany(indices, unstack(tensor, 0));\n  }\n\n  /**\n   * Split the values of a Tensor into the TensorArray.\n   * @param length number[] with the lengths to use when splitting value along\n   *    its first dimension.\n   * @param tensor Tensor, the tensor to split.\n   */\n  split(length: number[], tensor: Tensor) {\n    if (tensor.dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but tensor has dtype ${tensor.dtype}`);\n    }\n    let totalLength = 0;\n    const cumulativeLengths = length.map(len => {\n      totalLength += len;\n      return totalLength;\n    });\n\n    if (totalLength !== tensor.shape[0]) {\n      throw new Error(`Expected sum of lengths to be equal to\n          tensor.shape[0], but sum of lengths is\n        ${totalLength}, and tensor's shape is: ${tensor.shape}`);\n    }\n\n    if (!this.dynamicSize && length.length !== this.maxSize) {\n      throw new Error(\n          `TensorArray's size is not equal to the size of lengths (${\n              this.maxSize} vs. ${length.length}), ` +\n          'and the TensorArray is not marked as dynamically resizeable');\n    }\n\n    const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;\n    const tensors: Tensor[] = [];\n    tidy(() => {\n      tensor = reshape(tensor, [1, totalLength, elementPerRow]);\n      for (let i = 0; i < length.length; ++i) {\n        const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];\n        const indices = [0, previousLength, 0];\n        const sizes = [1, length[i], elementPerRow];\n        tensors[i] = reshape(slice(tensor, indices, sizes), this.elementShape);\n      }\n      return tensors;\n    });\n    const indices = [];\n    for (let i = 0; i < length.length; i++) {\n      indices[i] = i;\n    }\n    this.writeMany(indices, tensors);\n  }\n}\n"]}