UNPKG

@tensorflow-models/coco-ssd

Version:

Object detection model (coco-ssd) in TensorFlow.js

324 lines (279 loc) 10.1 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// tslint:disable-next-line:max-line-length
import {concat, DataType, slice, stack, Tensor, tensor, tidy, unstack, util} from '@tensorflow/tfjs-core';

/**
 * Bookkeeping record for a single slot of a TensorArray: the stored tensor
 * plus flags recording whether the slot has been written, read, or cleared.
 */
export interface TensorWithState {
  tensor?: Tensor;
  written?: boolean;
  read?: boolean;
  cleared?: boolean;
}

/**
 * The TensorArray object keeps an array of Tensors. It
 * allows reading from the array and writing to the array.
 *
 * Each slot may be written at most once, and may not be written after it has
 * been read. When `clearAfterRead` is set, a slot may also be read only once.
 */
export class TensorArray {
  private static nextId = 0;
  // May be sparse: a dynamically sized array can be written at an arbitrary
  // index, leaving lower slots unset until they are written.
  private tensors: TensorWithState[] = [];
  private closed_ = false;
  readonly id: number;

  constructor(
      public readonly name: string, public readonly dtype: DataType,
      private maxSize: number, private elementShape: number[],
      public readonly identicalElementShapes: boolean,
      public readonly dynamicSize: boolean,
      public readonly clearAfterRead: boolean) {
    this.id = TensorArray.nextId++;
  }

  /** Whether `clearAndClose()` has been called on this TensorArray. */
  get closed() {
    return this.closed_;
  }

  /**
   * Close the current TensorArray: dispose every stored tensor, drop all
   * slots, and mark the array as closed so later reads/writes throw.
   */
  clearAndClose() {
    // `forEach` skips never-written holes, so only real entries are disposed.
    this.tensors.forEach(tensor => tensor.tensor.dispose());
    this.tensors = [];
    this.closed_ = true;
  }

  /** Number of slots currently tracked (length of the backing array). */
  size(): number {
    return this.tensors.length;
  }

  /**
   * Read the value at location index in the TensorArray.
   * @param index Number the index to read from.
   * @returns The tensor stored at `index`.
   * @throws Error if the array is closed, `index` is out of range, the slot
   *     was never written, or the slot was cleared by a previous read.
   */
  read(index: number): Tensor {
    if (this.closed_) {
      throw new Error(`TensorArray ${this.name} has already been closed.`);
    }

    if (index < 0 || index >= this.tensors.length) {
      throw new Error(`Tried to read from index ${index}, but array size is: ${
          this.tensors.length}`);
    }

    const tensorWithState = this.tensors[index];
    // A slot inside the current bounds can still be a hole if a higher index
    // was written first; report that explicitly instead of failing with a
    // TypeError on the property access below.
    if (tensorWithState == null || !tensorWithState.written) {
      throw new Error(
          `TensorArray ${this.name}: Could not read from TensorArray index ${
              index}, because it has not yet been written to.`);
    }

    if (tensorWithState.cleared) {
      throw new Error(
          `TensorArray ${this.name}: Could not read index ${
              index} twice because it was cleared after a previous read ` +
          `(perhaps try setting clear_after_read = false?).`);
    }

    // Only the flag is set here; the tensor itself is not disposed.
    if (this.clearAfterRead) {
      tensorWithState.cleared = true;
    }

    tensorWithState.read = true;
    return tensorWithState.tensor;
  }

  /**
   * Helper method to read multiple tensors from the specified indices.
   * Each index goes through `read()`, so the same checks and read/cleared
   * bookkeeping apply.
   */
  readMany(indices: number[]): Tensor[] {
    return indices.map(index => this.read(index));
  }

  /**
   * Write value into the index of the TensorArray.
   * @param index number the index to write to.
   * @param tensor the tensor to store; its dtype must equal the array dtype
   *     and its shape must match the element shape (-1 acts as a wildcard).
   * @throws Error if the array is closed, the index is negative or beyond
   *     `maxSize` on a non-dynamic array, the dtype differs, or the slot has
   *     already been read or written.
   */
  write(index: number, tensor: Tensor) {
    if (this.closed_) {
      throw new Error(`TensorArray ${this.name} has already been closed.`);
    }

    if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
      throw new Error(`Tried to write to index ${
          index}, but array is not resizeable and size is: ${this.maxSize}`);
    }

    const t = this.tensors[index] || {};

    if (tensor.dtype !== this.dtype) {
      throw new Error(`TensorArray ${
          this.name}: Could not write to TensorArray index ${index}, because the value dtype is ${
          tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
    }

    // On the first write to an array with unknown element shape, adopt the
    // shape of the written tensor.
    if (this.size() === 0 &&
        (this.elementShape == null || this.elementShape.length === 0)) {
      this.elementShape = tensor.shape;
    }

    this.assertShapesMatch(
        this.elementShape, tensor.shape,
        `TensorArray ${this.name}: Could not write to TensorArray index ${
            index}.`);

    if (t.read) {
      throw new Error(
          `TensorArray ${this.name}: Could not write to TensorArray index ${
              index}, because it has already been read.`);
    }

    if (t.written) {
      throw new Error(
          `TensorArray ${this.name}: Could not write to TensorArray index ${
              index}, because it has already been written.`);
    }

    t.tensor = tensor;
    t.written = true;

    this.tensors[index] = t;
  }

  /**
   * Helper method to write multiple tensors to the specified indices.
   * `tensors[k]` is written to slot `indices[k]`.
   * @throws Error if `indices` and `tensors` differ in length, or any single
   *     write fails.
   */
  writeMany(indices: number[], tensors: Tensor[]) {
    if (indices.length !== tensors.length) {
      throw new Error(
          `TensorArray ${this.name}: could not write multiple tensors, ` +
          `because the index size: ${
              indices.length} is not the same as tensors size: ${
              tensors.length}.`);
    }

    indices.forEach((i, index) => this.write(i, tensors[index]));
  }

  /**
   * Return selected values in the TensorArray as a packed Tensor. All of
   * selected values must have been written and their shapes must all match.
   * @param [indices] number[] Optional. Taking values in [0, max_value). If the
   *    TensorArray is not dynamic, max_value=size(). If not specified returns
   *    all tensors in the original order.
   * @param [dtype] Optional. If given, must equal the TensorArray dtype.
   * @returns The selected tensors stacked along a new leading axis, or an
   *    empty tensor of shape `[0, ...elementShape]` when nothing is selected.
   */
  gather(indices?: number[], dtype?: DataType): Tensor {
    if (!!dtype && dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${
          this.dtype} but gather requested dtype ${dtype}`);
    }

    if (!indices) {
      indices = [];
      for (let i = 0; i < this.size(); i++) {
        indices.push(i);
      }
    }

    if (indices.length === 0) {
      return tensor([], [0].concat(this.elementShape));
    }

    // Read all the PersistentTensors into a vector to keep track of
    // their memory.
    const tensors = this.readMany(indices);

    // Only the first tensor is checked against the element shape here.
    this.assertShapesMatch(
        this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');

    return stack(tensors, 0);
  }

  /**
   * Return the values in the TensorArray as a concatenated Tensor.
   * @param [dtype] Optional. If given, must equal the TensorArray dtype.
   * @returns All stored tensors concatenated along axis 0, or an empty tensor
   *    of shape `[0, ...elementShape]` when the array holds no slots.
   */
  concat(dtype?: DataType): Tensor {
    if (!!dtype && dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${
          this.dtype} but concat requested dtype ${dtype}`);
    }

    if (this.size() === 0) {
      return tensor([], [0].concat(this.elementShape));
    }

    const indices: number[] = [];
    for (let i = 0; i < this.size(); i++) {
      indices.push(i);
    }
    // Collect all the tensors from the tensors array.
    const tensors = this.readMany(indices);

    this.assertShapesMatch(
        this.elementShape, tensors[0].shape,
        `TensorArray shape mismatch: tensor array shape (${
            this.elementShape}) vs first tensor shape (${tensors[0].shape})`);

    return concat(tensors, 0);
  }

  /**
   * Scatter the values of a Tensor in specific indices of a TensorArray.
   * The tensor is unstacked along axis 0 and slice `k` is written to slot
   * `indices[k]`.
   * @param indices number[] values in [0, max_value). If the
   *    TensorArray is not dynamic, max_value=size().
   * @param tensor Tensor input tensor.
   */
  scatter(indices: number[], tensor: Tensor) {
    if (tensor.dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${
          this.dtype} but tensor has dtype ${tensor.dtype}`);
    }

    if (indices.length !== tensor.shape[0]) {
      throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${
          indices.length} vs. ${tensor.shape[0]}`);
    }

    const maxIndex = Math.max(...indices);

    if (!this.dynamicSize && maxIndex >= this.maxSize) {
      throw new Error(
          `Max index must be < array size (${maxIndex} vs. ${this.maxSize})`);
    }

    this.writeMany(indices, unstack(tensor, 0));
  }

  /**
   * Split the values of a Tensor into the TensorArray.
   * Piece `i` covers `length[i]` consecutive rows of the first dimension and
   * is written to slot `i`.
   * @param length number[] with the lengths to use when splitting value along
   *    its first dimension.
   * @param tensor Tensor, the tensor to split.
   */
  split(length: number[], tensor: Tensor) {
    if (tensor.dtype !== this.dtype) {
      throw new Error(`TensorArray dtype is ${
          this.dtype} but tensor has dtype ${tensor.dtype}`);
    }

    // Running totals give the end offset of each piece.
    let totalLength = 0;
    const cumulativeLengths = length.map(len => {
      totalLength += len;
      return totalLength;
    });

    if (totalLength !== tensor.shape[0]) {
      throw new Error(`Expected sum of lengths to be equal to tensor.shape[0], but sum of lengths is ${totalLength}, and tensor's shape is: ${tensor.shape}`);
    }

    if (!this.dynamicSize && length.length !== this.maxSize) {
      throw new Error(
          `TensorArray's size is not equal to the size of lengths (${
              this.maxSize} vs. ${length.length}), ` +
          'and the TensorArray is not marked as dynamically resizeable');
    }

    const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
    const tensors: Tensor[] = [];
    tidy(() => {
      // Work on a local reshaped view rather than reassigning the parameter.
      const rows = tensor.reshape([1, totalLength, elementPerRow]);
      for (let i = 0; i < length.length; ++i) {
        const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
        const begin = [0, previousLength, 0];
        const sizes = [1, length[i], elementPerRow];
        tensors[i] = slice(rows, begin, sizes).reshape(this.elementShape);
      }
      // Returning the pieces from tidy keeps them alive past the scope.
      return tensors;
    });

    const indices: number[] = [];
    for (let i = 0; i < length.length; i++) {
      indices[i] = i;
    }
    this.writeMany(indices, tensors);
  }

  /**
   * Assert, via `util.assert`, that two shapes are compatible.
   * @param errorMessagePrefix text placed before the standard mismatch
   *    message.
   */
  private assertShapesMatch(
      shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void {
    util.assert(
        this.arraysEqual(shapeA, shapeB),
        errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
  }

  /**
   * Compare two shapes element-wise. Ranks must be equal; a dimension of -1
   * on either side matches any value.
   */
  private arraysEqual(n1: number[], n2: number[]) {
    if (n1.length !== n2.length) {
      return false;
    }

    for (let i = 0; i < n1.length; i++) {
      if (n1[i] !== -1 && n2[i] !== -1 && n1[i] !== n2[i]) {
        return false;
      }
    }

    return true;
  }
}