// @tensorflow/tfjs-converter
// Version:
// Tensorflow model converter for javascript
// 333 lines • 48.3 kB
// JavaScript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { concat, keep, reshape, scalar, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core';
import { assertShapesMatchAllowUndefinedSize, inferElementShape, mergeElementShape } from './tensor_utils';
/**
* TensorList stores a container of `tf.Tensor` objects, which are accessible
* via tensors field.
*
* In order to get a copy of the underlying list, use the copy method:
* ```
* const b = a.copy();
* b.pushBack(t); // This does not modify a.tensors.
* ```
*
* Note that this is not a deep copy: the memory locations of the underlying
* tensors will still point to the same locations of the corresponding tensors
* in the original.
*/
export class TensorList {
  /**
   * Identifier of this list: the id of the scalar `idTensor` allocated in the
   * constructor.
   */
  get id() {
    return this.idTensor.id;
  }

  /**
   * @param tensors list of tensors. Each non-null entry must have dtype
   *     `elementDtype` and a shape compatible with `elementShape`, and is
   *     `keep()`-ed. Unset (null/undefined or hole) entries are skipped.
   * @param elementShape shape of each tensor, this can be a single number (any
   *     shape is allowed) or partial shape (dim = -1).
   * @param elementDtype data type of each tensor
   * @param maxNumElements The maximum allowed size of `tensors`. Defaults to -1
   *     meaning that the size of `tensors` is unbounded.
   * @throws Error if an element's dtype or shape does not match.
   */
  constructor(tensors, elementShape, elementDtype, maxNumElements = -1) {
    this.tensors = tensors;
    this.elementShape = elementShape;
    this.elementDtype = elementDtype;
    if (tensors != null) {
      tensors.forEach(t => {
        // Lists may contain unset slots (e.g. after `resize` or a sparse
        // `setItem`); there is nothing to validate or keep for those.
        if (t == null) {
          return;
        }
        if (elementDtype !== t.dtype) {
          throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${t.dtype}`);
        }
        assertShapesMatchAllowUndefinedSize(elementShape, t.shape, 'TensorList shape mismatch: ');
        keep(t);
      });
    }
    // The id of this scalar is used as the list's id (see `get id`).
    this.idTensor = scalar(0);
    this.maxNumElements = maxNumElements;
    keep(this.idTensor);
  }

  /**
   * Get a new TensorList containing a shallow copy of the underlying tensor
   * container. Element tensors are shared with this list; element shape,
   * dtype and the `maxNumElements` bound are carried over.
   */
  copy() {
    // `slice()` (unlike spread) preserves holes, so unset slots stay unset
    // instead of becoming explicit `undefined` entries.
    return new TensorList(this.tensors.slice(), this.elementShape, this.elementDtype, this.maxNumElements);
  }

  /**
   * Dispose the tensors and idTensor and clear the tensor list.
   * @param keepIds optional set of tensor ids that must NOT be disposed.
   */
  clearAndClose(keepIds) {
    this.tensors.forEach(t => {
      if (keepIds == null || !keepIds.has(t.id)) {
        t.dispose();
      }
    });
    this.tensors.length = 0;
    this.idTensor.dispose();
  }

  /**
   * The number of slots in the tensor list (`tensors.length`).
   */
  size() {
    return this.tensors.length;
  }

  /**
   * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)
   * tf.Tensor. An empty list yields a tensor of shape `[0, ...elementShape]`.
   * @param elementShape shape of each tensor
   * @param elementDtype data type of each tensor
   * @param numElements the number of elements to stack (-1 = any)
   * @throws Error on dtype mismatch, element-count mismatch or shape mismatch.
   */
  stack(elementShape, elementDtype, numElements = -1) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    if (numElements !== -1 && this.tensors.length !== numElements) {
      throw new Error(`Operation expected a list with ${numElements} elements but got a list with ${this.tensors.length} elements.`);
    }
    assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, 'TensorList shape mismatch: ');
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    // Match `gather`/`concat`: `tf.stack` rejects an empty input list.
    if (this.tensors.length === 0) {
      return tensor([], [0].concat(outputElementShape), this.elementDtype);
    }
    return tidy(() => {
      const reshapedTensors = this.tensors.map(t => reshape(t, outputElementShape));
      return stack(reshapedTensors, 0);
    });
  }

  /**
   * Pop a tensor from the end of the list.
   * @param elementShape shape of the tensor
   * @param elementDtype data type of the tensor
   * @returns the popped element reshaped to the inferred element shape.
   * @throws Error on dtype mismatch, empty list, or shape mismatch.
   */
  popBack(elementShape, elementDtype) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    if (this.size() === 0) {
      throw new Error('Trying to pop from an empty list.');
    }
    // Inferred before popping, while the element is still in `this.tensors`.
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    const popped = this.tensors.pop();
    // The list no longer holds a reference, so clear the `kept` flag.
    popped.kept = false;
    assertShapesMatchAllowUndefinedSize(popped.shape, elementShape, 'TensorList shape mismatch: ');
    return reshape(popped, outputElementShape);
  }

  /**
   * Push a tensor to the end of the list.
   * @param tensor Tensor to be pushed.
   * @throws Error on dtype/shape mismatch or if the list is full.
   */
  pushBack(tensor) {
    if (tensor.dtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
    }
    assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, 'TensorList shape mismatch: ');
    if (this.maxNumElements === this.size()) {
      throw new Error(`Trying to push element into a full list.`);
    }
    keep(tensor);
    this.tensors.push(tensor);
  }

  /**
   * Return a new list of the given size that shares this list's elements.
   * Slots beyond this list's size are left unset; extra elements are dropped.
   * @param size the new size of the list.
   * @throws Error if size is negative or exceeds `maxNumElements`.
   */
  resize(size) {
    if (size < 0) {
      throw new Error(`TensorListResize expects size to be non-negative. Got: ${size}`);
    }
    if (this.maxNumElements !== -1 && size > this.maxNumElements) {
      throw new Error(`TensorListResize input size ${size} is greater maxNumElement ${this.maxNumElements}.`);
    }
    const destTensorList = new TensorList([], this.elementShape, this.elementDtype, this.maxNumElements);
    destTensorList.tensors.length = size;
    for (let i = 0; i < Math.min(this.tensors.length, size); ++i) {
      destTensorList.tensors[i] = this.tensors[i];
    }
    return destTensorList;
  }

  /**
   * Retrieve the element at the provided index
   * @param elementIndex index of the tensor; must be in [0, size()).
   * @param elementShape shape of the tensor
   * @param elementDtype dtype of the tensor
   * @throws Error on dtype mismatch, out-of-range index, unset slot, or shape
   *     mismatch.
   */
  getItem(elementIndex, elementShape, elementDtype) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    // `>=`: index === length is already out of range.
    if (elementIndex < 0 || elementIndex >= this.tensors.length) {
      throw new Error(`Trying to access element ${elementIndex} in a list with ${this.tensors.length} elements.`);
    }
    if (this.tensors[elementIndex] == null) {
      throw new Error(`element at index ${elementIndex} is null.`);
    }
    assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, 'TensorList shape mismatch: ');
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    return reshape(this.tensors[elementIndex], outputElementShape);
  }

  /**
   * Set the tensor at the index
   * @param elementIndex index of the tensor; may be past the current end
   *     (the list grows), but must be below `maxNumElements` when bounded.
   * @param tensor the tensor to be inserted into the list
   * @throws Error on dtype mismatch, invalid index, or shape mismatch.
   */
  setItem(elementIndex, tensor) {
    if (tensor.dtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
    }
    if (elementIndex < 0 ||
        this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
      throw new Error(`Trying to set element ${elementIndex} in a list with max ${this.maxNumElements} elements.`);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
    keep(tensor);
    // Clear the `kept` flag of the tensor being replaced (it is not disposed
    // here).
    if (this.tensors[elementIndex] != null) {
      this.tensors[elementIndex].kept = false;
    }
    this.tensors[elementIndex] = tensor;
  }

  /**
   * Return selected values in the TensorList as a stacked Tensor. All of
   * selected values must have been written and their shapes must all match.
   * @param indices indices of tensors to gather
   * @param elementDtype output tensor dtype
   * @param elementShape output tensor element shape
   */
  gather(indices, elementDtype, elementShape) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
    // Only the first `size()` entries of `indices` are used; any extra entries
    // are dropped. Individual index values are not range-checked here.
    indices = indices.slice(0, this.size());
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    if (indices.length === 0) {
      // Keep the list's dtype; `tensor()` would otherwise default to float32.
      return tensor([], [0].concat(outputElementShape), this.elementDtype);
    }
    return tidy(() => {
      const selected = indices.map(i => reshape(this.tensors[i], outputElementShape));
      return stack(selected, 0);
    });
  }

  /**
   * Return the values in the TensorList as a concatenated Tensor.
   * @param elementDtype output tensor dtype (optional; checked when given)
   * @param elementShape output tensor element shape
   */
  concat(elementDtype, elementShape) {
    if (!!elementDtype && elementDtype !== this.elementDtype) {
      throw new Error(`TensorList dtype is ${this.elementDtype} but concat requested dtype ${elementDtype}`);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    if (this.size() === 0) {
      // Keep the list's dtype; `tensor()` would otherwise default to float32.
      return tensor([], [0].concat(outputElementShape), this.elementDtype);
    }
    return tidy(() => {
      const reshaped = this.tensors.map(t => reshape(t, outputElementShape));
      return concat(reshaped, 0);
    });
  }
}
/**
* Creates a TensorList which, when stacked, has the value of tensor.
* @param tensor from tensor
* @param elementShape output tensor element shape
*/
/**
 * Splits `tensor` along axis 0 into a TensorList whose elements, when
 * stacked, reproduce `tensor`.
 * @param tensor source tensor; must have rank >= 1.
 * @param elementShape expected element shape, checked against
 *     `tensor.shape.slice(1)`.
 * @param elementDtype expected dtype; must equal `tensor.dtype`.
 * @throws Error if `tensor` is a scalar, its dtype differs from
 *     `elementDtype`, or its trailing shape is incompatible with
 *     `elementShape`.
 */
export function fromTensor(tensor, elementShape, elementDtype) {
  const { dtype, shape } = tensor;
  // Guard clauses, in the same order as before: rank, then dtype, then shape.
  if (shape.length === 0) {
    throw new Error(`Tensor must be at least a vector, but saw shape: ${tensor.shape}`);
  }
  if (dtype !== elementDtype) {
    throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${elementDtype}`);
  }
  assertShapesMatchAllowUndefinedSize(shape.slice(1), elementShape, 'TensorList shape mismatch: ');
  return new TensorList(unstack(tensor), elementShape, dtype);
}
/**
* Return a TensorList of the given size with empty elements.
* @param elementShape the shape of the future elements of the list
* @param elementDtype the desired type of elements in the list
* @param numElements the number of elements to reserve
* @param maxNumElements the maximum number of elements in th list
*/
/**
 * Create an empty TensorList.
 * @param elementShape the shape of the future elements of the list
 * @param elementDtype the desired type of elements in the list
 * @param numElements accepted for signature compatibility but currently not
 *     used: the returned list always starts with size 0.
 * @param maxNumElements the maximum number of elements in the list (passed
 *     through to the TensorList bound).
 */
export function reserve(elementShape, elementDtype, numElements, maxNumElements) {
  const emptyList = new TensorList([], elementShape, elementDtype, maxNumElements);
  return emptyList;
}
/**
* Put tensors at specific indices of a stacked tensor into a TensorList.
* @param indices list of indices on how to scatter the tensor.
* @param tensor input tensor.
* @param elementShape the shape of the future elements of the list
* @param numElements the number of elements to scatter
*/
/**
 * Put tensors at specific indices of a stacked tensor into a TensorList.
 * Row `k` of `tensor` (along axis 0) is stored at list index `indices[k]`.
 * @param tensor input tensor.
 * @param indices list of indices on how to scatter the tensor; must have
 *     length `tensor.shape[0]`.
 * @param elementShape the shape of the future elements of the list
 * @param numElements the list's maximum size; null/undefined or -1 means
 *     unbounded.
 * @throws Error if `indices.length !== tensor.shape[0]` or an index is out of
 *     range for a bounded list.
 */
export function scatter(tensor, indices, elementShape, numElements) {
  if (indices.length !== tensor.shape[0]) {
    throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
  }
  // A reduce avoids `Math.max(...indices)`, which can exceed the engine's
  // argument-count limit for very long index arrays. The result for empty
  // input (-Infinity) is unchanged.
  const maxIndex = indices.reduce((acc, idx) => Math.max(acc, idx), -Infinity);
  if (numElements != null && numElements !== -1 && maxIndex >= numElements) {
    throw new Error(`Max index must be < array size (${maxIndex} vs. ${numElements})`);
  }
  // Normalize a missing bound to -1 ("unbounded"). Passing `null` through
  // would make `setItem` evaluate `index >= null` and reject every index.
  const maxNumElements = numElements == null ? -1 : numElements;
  const list = new TensorList([], elementShape, tensor.dtype, maxNumElements);
  const rows = unstack(tensor, 0);
  indices.forEach((value, index) => {
    list.setItem(value, rows[index]);
  });
  return list;
}
/**
* Split the values of a Tensor into a TensorList.
* @param length the lengths to use when splitting value along
* its first dimension.
* @param tensor the tensor to split.
* @param elementShape the shape of the future elements of the list
*/
/**
 * Split the values of a Tensor into a TensorList.
 * @param tensor the tensor to split.
 * @param length the lengths to use when splitting value along
 *     its first dimension; must sum to `tensor.shape[0]`.
 * @param elementShape the shape of the future elements of the list
 * @returns a TensorList with `length.length` elements, bounded to that size.
 * @throws Error if the lengths do not sum to `tensor.shape[0]`.
 */
export function split(tensor, length, elementShape) {
  let totalLength = 0;
  // cumulativeLengths[i] is the end offset (exclusive) of piece i.
  const cumulativeLengths = length.map(len => {
    totalLength += len;
    return totalLength;
  });
  if (totalLength !== tensor.shape[0]) {
    throw new Error(`Expected sum of lengths to be equal to
tensor.shape[0], but sum of lengths is
${totalLength}, and tensor's shape is: ${tensor.shape}`);
  }
  const shapeWithoutFirstDim = tensor.shape.slice(1);
  const outputElementShape = mergeElementShape(shapeWithoutFirstDim, elementShape);
  const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
  const tensors = tidy(() => {
    const pieces = [];
    // View the input as [1, totalLength, elementPerRow] so each piece is one
    // contiguous slice along axis 1. A separate local keeps the `tensor`
    // parameter bound to the caller's tensor (its dtype is read below).
    const flattened = reshape(tensor, [1, totalLength, elementPerRow]);
    for (let i = 0; i < length.length; ++i) {
      const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
      const begin = [0, previousLength, 0];
      const sizes = [1, length[i], elementPerRow];
      // NOTE(review): the slice holds length[i] * elementPerRow values but is
      // reshaped to `outputElementShape`; this presumably only succeeds when
      // length[i] === 1 or the merged shape accounts for it — confirm against
      // TF TensorListSplit semantics.
      pieces[i] = reshape(slice(flattened, begin, sizes), outputElementShape);
    }
    flattened.dispose();
    return pieces;
  });
  const list = new TensorList([], elementShape, tensor.dtype, length.length);
  for (let i = 0; i < tensors.length; i++) {
    list.setItem(i, tensors[i]);
  }
  return list;
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"tensor_list.js","sourceRoot":"","sources":["../../../../../../tfjs-converter/src/executor/tensor_list.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAY,IAAI,EAAE,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,EAAE,IAAI,EAAE,OAAO,EAAC,MAAM,uBAAuB,CAAC;AAE3H,OAAO,EAAC,mCAAmC,EAAE,iBAAiB,EAAE,iBAAiB,EAAC,MAAM,gBAAgB,CAAC;AAEzG;;;;;;;;;;;;;GAaG;AAEH,MAAM,OAAO,UAAU;IAIrB,IAAI,EAAE;QACJ,OAAO,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC;IAC1B,CAAC;IACD;;;;;;;;OAQG;IACH,YACa,OAAiB,EAAW,YAA6B,EACzD,YAAsB,EAAE,cAAc,GAAG,CAAC,CAAC;QAD3C,YAAO,GAAP,OAAO,CAAU;QAAW,iBAAY,GAAZ,YAAY,CAAiB;QACzD,iBAAY,GAAZ,YAAY,CAAU;QACjC,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;gBACvB,IAAI,YAAY,KAAK,MAAM,CAAC,KAAK,EAAE;oBACjC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;iBACxD;gBACD,mCAAmC,CAC/B,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,6BAA6B,CAAC,CAAC;gBAE/D,IAAI,CAAC,MAAM,CAAC,CAAC;YACf,CAAC,CAAC,CAAC;SACJ;QACD,IAAI,CAAC,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QAC1B,IAAI,CAAC,cAAc,GAAG,cAAc,CAAC;QACrC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;IACtB,CAAC;IAED;;OAEG;IACH,IAAI;QACF,OAAO,IAAI,UAAU,CACjB,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,YAAY,CAAC,CAAC;IAC/D,CAAC;IAED;;OAEG;IACH,aAAa,CAAC,OAAqB;QACjC,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC5B,IAAI,OAAO,IAAI,IAAI,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE;gBAC9C,MAAM,CAAC,OAAO,EAAE,CAAC;aAClB;QACH,CAAC,CAAC,CAAC;QACH,IAAI,CAAC,OAAO,CAAC,MAAM,GAAG,CAAC,CAAC;QACxB,IAAI,CAAC,QAAQ,CAAC,OAAO,EAAE,CAAC;IAC1B,CAAC;IACD;;OAEG;IACH,IAAI;QACF,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC;IAC7B,CAAC;IAED;;;;;;OAMG;IACH,KAAK,CAAC,YAAsB,EAAE,YAAsB,EAAE,WAAW,GAAG,CAAC,CAAC;QAEpE,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QACD,IAAI,WAAW,KAAK,CAAC,CAAC,IAAI,IAAI,CAAC,OAAO,CAAC,MAAM,KAAK,WAAW,EAAE;YAC7D,MAAM,IAAI,KAAK,CAAC,kCACZ,WAAW,iCACX,IAAI,CAAC,OAAO,C
AAC,MAAM,YAAY,CAAC,CAAC;SACtC;QACD,mCAAmC,CAC/B,YAAY,EAAE,IAAI,CAAC,YAAY,EAAE,6BAA6B,CAAC,CAAC;QACpE,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,eAAe,GACjB,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC,CAAC;YACpE,OAAO,KAAK,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC;QACnC,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;OAIG;IACH,OAAO,CAAC,YAAsB,EAAE,YAAsB;QACpD,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC,EAAE;YACrB,MAAM,IAAI,KAAK,CAAC,mCAAmC,CAAC,CAAC;SACtD;QACD,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,MAAM,MAAM,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,EAAE,CAAC;QAClC,MAAM,CAAC,IAAI,GAAG,KAAK,CAAC;QAEpB,mCAAmC,CAC/B,MAAM,CAAC,KAAK,EAAE,YAAY,EAAE,6BAA6B,CAAC,CAAC;QAE/D,OAAO,OAAO,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC;IAC7C,CAAC;IAED;;;OAGG;IACH,QAAQ,CAAC,MAAc;QACrB,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,MAAM,CAAC,KAAK,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,mCAAmC,CAC/B,MAAM,CAAC,KAAK,EAAE,IAAI,CAAC,YAAY,EAAE,6BAA6B,CAAC,CAAC;QAEpE,IAAI,IAAI,CAAC,cAAc,KAAK,IAAI,CAAC,IAAI,EAAE,EAAE;YACvC,MAAM,IAAI,KAAK,CAAC,0CAA0C,CAAC,CAAC;SAC7D;QACD,IAAI,CAAC,MAAM,CAAC,CAAC;QACb,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;IAC5B,CAAC;IAED;;;OAGG;IACH,MAAM,CAAC,IAAY;QACjB,IAAI,IAAI,GAAG,CAAC,EAAE;YACZ,MAAM,IAAI,KAAK,CACX,0DAA0D,IAAI,EAAE,CAAC,CAAC;SACvE;QAED,IAAI,IAAI,CAAC,cAAc,KAAK,CAAC,CAAC,IAAI,IAAI,GAAG,IAAI,CAAC,cAAc,EAAE;YAC5D,MAAM,IAAI,KAAK,CAAC,+BACZ,IAAI,6BAA6B,IAAI,CAAC,cAAc,GAAG,CAAC,CAAC;SAC9D;QAED,MAAM,cAAc,GAAe,IAAI,UAAU,CAC7C,EAAE,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,cAAc,CAAC,CAAC;QACnE,cAAc,CAAC,OAAO,CAAC,MAAM,GAAG,IAAI,CAAC;QACrC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE,IAAI,CAAC,EAAE,EAAE,CAAC,EAAE;YAC5D,cAAc,CAAC,OAAO,CA
AC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;SAC7C;QACD,OAAO,cAAc,CAAC;IACxB,CAAC;IAED;;;;;OAKG;IACH,OAAO,CAAC,YAAoB,EAAE,YAAsB,EAAE,YAAsB;QAE1E,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QACD,IAAI,YAAY,GAAG,CAAC,IAAI,YAAY,GAAG,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE;YAC1D,MAAM,IAAI,KAAK,CAAC,4BACZ,YAAY,mBAAmB,IAAI,CAAC,OAAO,CAAC,MAAM,YAAY,CAAC,CAAC;SACrE;QAED,IAAI,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,IAAI,IAAI,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,oBAAoB,YAAY,WAAW,CAAC,CAAC;SAC9D;QAED,mCAAmC,CAC/B,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,KAAK,EAAE,YAAY,EAC9C,6BAA6B,CAAC,CAAC;QACnC,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,OAAO,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE,kBAAkB,CAAC,CAAC;IACjE,CAAC;IAED;;;;OAIG;IACH,OAAO,CAAC,YAAoB,EAAE,MAAc;QAC1C,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,MAAM,CAAC,KAAK,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,IAAI,YAAY,GAAG,CAAC;YAChB,IAAI,CAAC,cAAc,KAAK,CAAC,CAAC,IAAI,YAAY,IAAI,IAAI,CAAC,cAAc,EAAE;YACrE,MAAM,IAAI,KAAK,CAAC,yBACZ,YAAY,uBAAuB,IAAI,CAAC,cAAc,YAAY,CAAC,CAAC;SACzE;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,6BAA6B,CAAC,CAAC;QACpE,IAAI,CAAC,MAAM,CAAC,CAAC;QAEb,iDAAiD;QACjD,IAAI,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,IAAI,IAAI,EAAE;YACtC,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,IAAI,GAAG,KAAK,CAAC;SACzC;QAED,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,GAAG,MAAM,CAAC;IACtC,CAAC;IAED;;;;;;OAMG;IACH,MAAM,CAAC,OAAiB,EAAE,YAAsB,EAAE,YAAsB;QAEtE,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,YAAY,EAAE,6BAA6B,CAAC,CAAC;QAEpE,wEAAwE;QACxE,gCAAgC;QAChC,OAAO,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;QACxC,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,
MAAM,CAAC,kBAAkB,CAAC,CAAC,CAAC;SACnD;QAED,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,OAAO,GACT,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,kBAAkB,CAAC,CAAC,CAAC;YACnE,OAAO,KAAK,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;QAC3B,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;OAIG;IACH,MAAM,CAAC,YAAsB,EAAE,YAAsB;QACnD,IAAI,CAAC,CAAC,YAAY,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACxD,MAAM,IAAI,KAAK,CAAC,uBACZ,IAAI,CAAC,YAAY,+BAA+B,YAAY,EAAE,CAAC,CAAC;SACrE;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,YAAY,EAAE,6BAA6B,CAAC,CAAC;QACpE,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QAErE,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC,EAAE;YACrB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,kBAAkB,CAAC,CAAC,CAAC;SACnD;QACD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,CAAC,EAAE,kBAAkB,CAAC,CAAC,CAAC;YACtE,OAAO,MAAM,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;QAC5B,CAAC,CAAC,CAAC;IACL,CAAC;CACF;AAED;;;;GAIG;AACH,MAAM,UAAU,UAAU,CACtB,MAAc,EAAE,YAAsB,EAAE,YAAsB;IAChE,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC;IAC3B,IAAI,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE;QAC3B,MAAM,IAAI,KAAK,CACX,oDAAoD,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;KACzE;IACD,IAAI,MAAM,CAAC,KAAK,KAAK,YAAY,EAAE;QACjC,MAAM,IAAI,KAAK,CAAC,mCACZ,MAAM,CAAC,KAAK,uBAAuB,YAAY,EAAE,CAAC,CAAC;KACxD;IACD,MAAM,kBAAkB,GAAG,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACjD,mCAAmC,CAC/B,kBAAkB,EAAE,YAAY,EAAE,6BAA6B,CAAC,CAAC;IACrE,MAAM,UAAU,GAAa,OAAO,CAAC,MAAM,CAAC,CAAC;IAC7C,OAAO,IAAI,UAAU,CAAC,UAAU,EAAE,YAAY,EAAE,KAAK,CAAC,CAAC;AACzD,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,OAAO,CACnB,YAAsB,EAAE,YAAsB,EAAE,WAAmB,EACnE,cAAsB;IACxB,OAAO,IAAI,UAAU,CAAC,EAAE,EAAE,YAAY,EAAE,YAAY,EAAE,cAAc,CAAC,CAAC;AACxE,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,OAAO,CACnB,MAAc,EAAE,OAAiB,EAAE,YAAsB,EACzD,WAAoB;IACtB,IAAI,OAAO,CAAC,MAAM,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;QACtC,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,MAAM,QAAQ,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;KAC9C;IAED,MAAM,QAAQ,GAAG,IAAI,
CAAC,GAAG,CAAC,GAAG,OAAO,CAAC,CAAC;IAEtC,IAAI,WAAW,IAAI,IAAI,IAAI,WAAW,KAAK,CAAC,CAAC,IAAI,QAAQ,IAAI,WAAW,EAAE;QACxE,MAAM,IAAI,KAAK,CACX,mCAAmC,QAAQ,SAAS,WAAW,GAAG,CAAC,CAAC;KACzE;IAED,MAAM,IAAI,GAAG,IAAI,UAAU,CAAC,EAAE,EAAE,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,WAAW,CAAC,CAAC;IACzE,MAAM,OAAO,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;IACnC,OAAO,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,EAAE;QAC/B,IAAI,CAAC,OAAO,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IACH,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,KAAK,CACjB,MAAc,EAAE,MAAgB,EAAE,YAAsB;IAC1D,IAAI,WAAW,GAAG,CAAC,CAAC;IACpB,MAAM,iBAAiB,GAAG,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;QACzC,WAAW,IAAI,GAAG,CAAC;QACnB,OAAO,WAAW,CAAC;IACrB,CAAC,CAAC,CAAC;IAEH,IAAI,WAAW,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;QACnC,MAAM,IAAI,KAAK,CAAC;;UAEV,WAAW,4BAA4B,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;KAC9D;IAED,MAAM,oBAAoB,GAAG,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACnD,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,oBAAoB,EAAE,YAAY,CAAC,CAAC;IAC1D,MAAM,aAAa,GAAG,WAAW,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,GAAG,WAAW,CAAC;IACxE,MAAM,OAAO,GAAa,IAAI,CAAC,GAAG,EAAE;QAClC,MAAM,OAAO,GAAG,EAAE,CAAC;QACnB,MAAM,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC,CAAC;QAC1D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACtC,MAAM,cAAc,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,iBAAiB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;YAChE,MAAM,OAAO,GAAG,CAAC,CAAC,EAAE,cAAc,EAAE,CAAC,CAAC,CAAC;YACvC,MAAM,KAAK,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC;YAC5C,OAAO,CAAC,CAAC,CAAC,GAAG,OAAO,CAChB,KAAK,CAAC,MAAM,EAAE,OAAO,EAAE,KAAK,CAAC,EAAE,kBAA8B,CAAC,CAAC;SACpE;QACD,MAAM,CAAC,OAAO,EAAE,CAAC;QACjB,OAAO,OAAO,CAAC;IACjB,CAAC,CAAC,CAAC;IAEH,MAAM,IAAI,GAAG,IAAI,UAAU,CAAC,EAAE,EAAE,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,MAAM,CAAC,MAAM,CAAC,CAAC;IAE3E,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;QACvC,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC,CA
AC,CAAC;KAC7B;IACD,OAAO,IAAI,CAAC;AACd,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';\n\nimport {assertShapesMatchAllowUndefinedSize, inferElementShape, mergeElementShape} from './tensor_utils';\n\n/**\n * TensorList stores a container of `tf.Tensor` objects, which are accessible\n * via tensors field.\n *\n * In order to get a copy of the underlying list, use the copy method:\n * ```\n *    TensorList b = a.copy();\n *    b.tensors().pushBack(t);  // This does not modify a.tensors().\n * ```\n *\n * Note that this is not a deep copy: the memory locations of the underlying\n * tensors will still point to the same locations of the corresponding tensors\n * in the original.\n */\n\nexport class TensorList {\n  readonly idTensor: Tensor;\n  maxNumElements: number;\n\n  get id() {\n    return this.idTensor.id;\n  }\n  /**\n   *\n   * @param tensors list of tensors\n   * @param elementShape shape of each tensor, this can be a single number (any\n   * shape is allowed) or partial shape (dim = -1).\n   * @param elementDtype data type of each tensor\n   * @param maxNumElements The maximum allowed size of `tensors`. 
Defaults to -1\n   *   meaning that the size of `tensors` is unbounded.\n   */\n  constructor(\n      readonly tensors: Tensor[], readonly elementShape: number|number[],\n      readonly elementDtype: DataType, maxNumElements = -1) {\n    if (tensors != null) {\n      tensors.forEach(tensor => {\n        if (elementDtype !== tensor.dtype) {\n          throw new Error(`Invalid data types; op elements ${\n              elementDtype}, but list elements ${tensor.dtype}`);\n        }\n        assertShapesMatchAllowUndefinedSize(\n            elementShape, tensor.shape, 'TensorList shape mismatch: ');\n\n        keep(tensor);\n      });\n    }\n    this.idTensor = scalar(0);\n    this.maxNumElements = maxNumElements;\n    keep(this.idTensor);\n  }\n\n  /**\n   * Get a new TensorList containing a copy of the underlying tensor container.\n   */\n  copy(): TensorList {\n    return new TensorList(\n        [...this.tensors], this.elementShape, this.elementDtype);\n  }\n\n  /**\n   * Dispose the tensors and idTensor and clear the tensor list.\n   */\n  clearAndClose(keepIds?: Set<number>) {\n    this.tensors.forEach(tensor => {\n      if (keepIds == null || !keepIds.has(tensor.id)) {\n        tensor.dispose();\n      }\n    });\n    this.tensors.length = 0;\n    this.idTensor.dispose();\n  }\n  /**\n   * The size of the tensors in the tensor list.\n   */\n  size() {\n    return this.tensors.length;\n  }\n\n  /**\n   * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)\n   * tf.Tensor.\n   * @param elementShape shape of each tensor\n   * @param elementDtype data type of each tensor\n   * @param numElements the number of elements to stack\n   */\n  stack(elementShape: number[], elementDtype: DataType, numElements = -1):\n      Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n    if (numElements !== -1 && 
this.tensors.length !== numElements) {\n      throw new Error(`Operation expected a list with ${\n          numElements} elements but got a list with ${\n          this.tensors.length} elements.`);\n    }\n    assertShapesMatchAllowUndefinedSize(\n        elementShape, this.elementShape, 'TensorList shape mismatch: ');\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    return tidy(() => {\n      const reshapedTensors =\n          this.tensors.map(tensor => reshape(tensor, outputElementShape));\n      return stack(reshapedTensors, 0);\n    });\n  }\n\n  /**\n   * Pop a tensor from the end of the list.\n   * @param elementShape shape of the tensor\n   * @param elementDtype data type of the tensor\n   */\n  popBack(elementShape: number[], elementDtype: DataType): Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n\n    if (this.size() === 0) {\n      throw new Error('Trying to pop from an empty list.');\n    }\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    const tensor = this.tensors.pop();\n    tensor.kept = false;\n\n    assertShapesMatchAllowUndefinedSize(\n        tensor.shape, elementShape, 'TensorList shape mismatch: ');\n\n    return reshape(tensor, outputElementShape);\n  }\n\n  /**\n   * Push a tensor to the end of the list.\n   * @param tensor Tensor to be pushed.\n   */\n  pushBack(tensor: Tensor) {\n    if (tensor.dtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          tensor.dtype}, but list elements ${this.elementDtype}`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        tensor.shape, this.elementShape, 'TensorList shape mismatch: ');\n\n    if (this.maxNumElements === this.size()) {\n      throw new Error(`Trying to push element into a 
full list.`);\n    }\n    keep(tensor);\n    this.tensors.push(tensor);\n  }\n\n  /**\n   * Update the size of the list.\n   * @param size the new size of the list.\n   */\n  resize(size: number) {\n    if (size < 0) {\n      throw new Error(\n          `TensorListResize expects size to be non-negative. Got: ${size}`);\n    }\n\n    if (this.maxNumElements !== -1 && size > this.maxNumElements) {\n      throw new Error(`TensorListResize input size ${\n          size} is greater maxNumElement ${this.maxNumElements}.`);\n    }\n\n    const destTensorList: TensorList = new TensorList(\n        [], this.elementShape, this.elementDtype, this.maxNumElements);\n    destTensorList.tensors.length = size;\n    for (let i = 0; i < Math.min(this.tensors.length, size); ++i) {\n      destTensorList.tensors[i] = this.tensors[i];\n    }\n    return destTensorList;\n  }\n\n  /**\n   * Retrieve the element at the provided index\n   * @param elementShape shape of the tensor\n   * @param elementDtype dtype of the tensor\n   * @param elementIndex index of the tensor\n   */\n  getItem(elementIndex: number, elementShape: number[], elementDtype: DataType):\n      Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n    if (elementIndex < 0 || elementIndex > this.tensors.length) {\n      throw new Error(`Trying to access element ${\n          elementIndex} in a list with ${this.tensors.length} elements.`);\n    }\n\n    if (this.tensors[elementIndex] == null) {\n      throw new Error(`element at index ${elementIndex} is null.`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.tensors[elementIndex].shape, elementShape,\n        'TensorList shape mismatch: ');\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    return reshape(this.tensors[elementIndex], outputElementShape);\n  
}\n\n  /**\n   * Set the tensor at the index\n   * @param elementIndex index of the tensor\n   * @param tensor the tensor to be inserted into the list\n   */\n  setItem(elementIndex: number, tensor: Tensor) {\n    if (tensor.dtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          tensor.dtype}, but list elements ${this.elementDtype}`);\n    }\n\n    if (elementIndex < 0 ||\n        this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {\n      throw new Error(`Trying to set element ${\n          elementIndex} in a list with max ${this.maxNumElements} elements.`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensor.shape, 'TensorList shape mismatch: ');\n    keep(tensor);\n\n    // dispose the previous value if it is replacing.\n    if (this.tensors[elementIndex] != null) {\n      this.tensors[elementIndex].kept = false;\n    }\n\n    this.tensors[elementIndex] = tensor;\n  }\n\n  /**\n   * Return selected values in the TensorList as a stacked Tensor. 
All of\n   * selected values must have been written and their shapes must all match.\n   * @param indices indices of tensors to gather\n   * @param elementDtype output tensor dtype\n   * @param elementShape output tensor element shape\n   */\n  gather(indices: number[], elementDtype: DataType, elementShape: number[]):\n      Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, elementShape, 'TensorList shape mismatch: ');\n\n    // When indices is greater than the size of the list, indices beyond the\n    // size of the list are ignored.\n    indices = indices.slice(0, this.size());\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    if (indices.length === 0) {\n      return tensor([], [0].concat(outputElementShape));\n    }\n\n    return tidy(() => {\n      const tensors =\n          indices.map(i => reshape(this.tensors[i], outputElementShape));\n      return stack(tensors, 0);\n    });\n  }\n\n  /**\n   * Return the values in the TensorList as a concatenated Tensor.\n   * @param elementDtype output tensor dtype\n   * @param elementShape output tensor element shape\n   */\n  concat(elementDtype: DataType, elementShape: number[]): Tensor {\n    if (!!elementDtype && elementDtype !== this.elementDtype) {\n      throw new Error(`TensorList dtype is ${\n          this.elementDtype} but concat requested dtype ${elementDtype}`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, elementShape, 'TensorList shape mismatch: ');\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n\n    if (this.size() === 0) {\n      return tensor([], [0].concat(outputElementShape));\n    }\n    return tidy(() => {\n      const tensors = 
this.tensors.map(t => reshape(t, outputElementShape));\n      return concat(tensors, 0);\n    });\n  }\n}\n\n/**\n * Creates a TensorList which, when stacked, has the value of tensor.\n * @param tensor from tensor\n * @param elementShape output tensor element shape\n */\nexport function fromTensor(\n    tensor: Tensor, elementShape: number[], elementDtype: DataType) {\n  const dtype = tensor.dtype;\n  if (tensor.shape.length < 1) {\n    throw new Error(\n        `Tensor must be at least a vector, but saw shape: ${tensor.shape}`);\n  }\n  if (tensor.dtype !== elementDtype) {\n    throw new Error(`Invalid data types; op elements ${\n        tensor.dtype}, but list elements ${elementDtype}`);\n  }\n  const tensorElementShape = tensor.shape.slice(1);\n  assertShapesMatchAllowUndefinedSize(\n      tensorElementShape, elementShape, 'TensorList shape mismatch: ');\n  const tensorList: Tensor[] = unstack(tensor);\n  return new TensorList(tensorList, elementShape, dtype);\n}\n\n/**\n * Return a TensorList of the given size with empty elements.\n * @param elementShape the shape of the future elements of the list\n * @param elementDtype the desired type of elements in the list\n * @param numElements the number of elements to reserve\n * @param maxNumElements the maximum number of elements in th list\n */\nexport function reserve(\n    elementShape: number[], elementDtype: DataType, numElements: number,\n    maxNumElements: number) {\n  return new TensorList([], elementShape, elementDtype, maxNumElements);\n}\n\n/**\n * Put tensors at specific indices of a stacked tensor into a TensorList.\n * @param indices list of indices on how to scatter the tensor.\n * @param tensor input tensor.\n * @param elementShape the shape of the future elements of the list\n * @param numElements the number of elements to scatter\n */\nexport function scatter(\n    tensor: Tensor, indices: number[], elementShape: number[],\n    numElements?: number): TensorList {\n  if (indices.length !== 
tensor.shape[0]) {\n    throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${\n        indices.length} vs. ${tensor.shape[0]}`);\n  }\n\n  const maxIndex = Math.max(...indices);\n\n  if (numElements != null && numElements !== -1 && maxIndex >= numElements) {\n    throw new Error(\n        `Max index must be < array size (${maxIndex}  vs. ${numElements})`);\n  }\n\n  const list = new TensorList([], elementShape, tensor.dtype, numElements);\n  const tensors = unstack(tensor, 0);\n  indices.forEach((value, index) => {\n    list.setItem(value, tensors[index]);\n  });\n  return list;\n}\n\n/**\n * Split the values of a Tensor into a TensorList.\n * @param length the lengths to use when splitting value along\n *    its first dimension.\n * @param tensor the tensor to split.\n * @param elementShape the shape of the future elements of the list\n */\nexport function split(\n    tensor: Tensor, length: number[], elementShape: number[]) {\n  let totalLength = 0;\n  const cumulativeLengths = length.map(len => {\n    totalLength += len;\n    return totalLength;\n  });\n\n  if (totalLength !== tensor.shape[0]) {\n    throw new Error(`Expected sum of lengths to be equal to\n          tensor.shape[0], but sum of lengths is\n        ${totalLength}, and tensor's shape is: ${tensor.shape}`);\n  }\n\n  const shapeWithoutFirstDim = tensor.shape.slice(1);\n  const outputElementShape =\n      mergeElementShape(shapeWithoutFirstDim, elementShape);\n  const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;\n  const tensors: Tensor[] = tidy(() => {\n    const tensors = [];\n    tensor = reshape(tensor, [1, totalLength, elementPerRow]);\n    for (let i = 0; i < length.length; ++i) {\n      const previousLength = (i === 0) ? 
0 : cumulativeLengths[i - 1];\n      const indices = [0, previousLength, 0];\n      const sizes = [1, length[i], elementPerRow];\n      tensors[i] = reshape(\n          slice(tensor, indices, sizes), outputElementShape as number[]);\n    }\n    tensor.dispose();\n    return tensors;\n  });\n\n  const list = new TensorList([], elementShape, tensor.dtype, length.length);\n\n  for (let i = 0; i < tensors.length; i++) {\n    list.setItem(i, tensors[i]);\n  }\n  return list;\n}\n"]}