UNPKG

@tensorflow-models/coco-ssd

Version:

Object detection model (coco-ssd) in TensorFlow.js

167 lines (148 loc) 5.22 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {Tensor} from '@tensorflow/tfjs-core';

import {NamedTensorsMap, TensorArrayMap} from '../data/types';

import {TensorArray} from './tensor_array';

/** One entry in the stack of control-flow frames tracked by ExecutionContext. */
export interface ExecutionContextInfo {
  id: number;           // the unique id of the context info
  frameName: string;    // The frame name of the loop, this comes from
                        // the TensorFlow NodeDef.
  iterationId: number;  // The iteration id of the loop
}

/**
 * ExecutionContext captures the runtime environment of the node. It keeps
 * track of the current frame and iteration for the control flow ops.
 *
 * For example, a typical dynamic RNN model may contain loops, for which
 * TensorFlow will generate graphs with Enter/Exit nodes to control the
 * current execution frame, and NextIteration nodes for iteration id increment.
 * For models with branch logic, TensorFlow will generate Switch/Merge ops.
 *
 * Invariant: the `contexts` array is never mutated in place. Every mutator
 * (`enterFrame`, `exitFrame`, `nextIteration`) first copies it with `slice()`.
 * As a result, an array obtained from the `currentContext` getter is a stable
 * snapshot that can later be restored through the `currentContext` setter.
 */
export class ExecutionContext {
  // The outermost frame. With id 0 and iteration 0, it renders as the empty
  // context id string.
  private readonly rootContext: ExecutionContextInfo =
      {id: 0, frameName: '', iterationId: 0};
  private contexts: ExecutionContextInfo[] = [this.rootContext];
  // Monotonic counter; every new frame or iteration gets a fresh id.
  private lastId = 0;
  // Context ids for the current frame and all of its parents, innermost
  // first. The initializer only makes the field definitely assigned under
  // `strictPropertyInitialization`; the constructor regenerates it.
  private _currentContextIds: string[] = [];

  constructor(
      public readonly weightMap: NamedTensorsMap,
      public readonly tensorArrayMap: TensorArrayMap) {
    this.generateCurrentContextIds();
  }

  /** Creates the context entry for a freshly entered frame (iteration 0). */
  private newFrame(id: number, frameName: string): ExecutionContextInfo {
    return {id, frameName, iterationId: 0};
  }

  /**
   * Set the current context.
   * @param contexts The current path of execution frames, outermost first.
   *     The context ids are only regenerated when a different array instance
   *     is supplied.
   */
  set currentContext(contexts: ExecutionContextInfo[]) {
    if (this.contexts !== contexts) {
      this.contexts = contexts;
      this.generateCurrentContextIds();
    }
  }

  /** The current path of execution frames, outermost (root) first. */
  get currentContext(): ExecutionContextInfo[] {
    return this.contexts;
  }

  /**
   * Returns the current context in string format.
   */
  get currentContextId(): string {
    return this._currentContextIds[0];
  }

  /**
   * Returns the current context and all parent contexts in string format,
   * innermost first and ending with '' for the root. This allows access to
   * the nodes in the current and parent frames.
   */
  get currentContextIds(): string[] {
    return this._currentContextIds;
  }

  /**
   * Rebuilds `_currentContextIds` from `contexts`. Builds one id per prefix
   * of the context path, longest prefix first; the one-element root prefix
   * is represented by the trailing ''.
   */
  private generateCurrentContextIds() {
    const names: string[] = [];
    for (let i = 0; i < this.contexts.length - 1; i++) {
      const contexts = this.contexts.slice(0, this.contexts.length - i);
      names.push(this.contextIdforContexts(contexts));
    }
    names.push('');
    this._currentContextIds = names;
  }

  /**
   * Renders a context path as a string id. The untouched root context
   * (id 0, iteration 0) renders as ''. Every other entry renders as
   * `${frameName}-${iterationId}`, and entries are joined with '/', so an id
   * under the root starts with '/' (e.g. '/loop-0').
   */
  private contextIdforContexts(contexts: ExecutionContextInfo[]): string {
    return contexts ?
        contexts
            .map(
                context => (context.id === 0 && context.iterationId === 0) ?
                    '' :
                    `${context.frameName}-${context.iterationId}`)
            .join('/') :
        '';
  }

  /**
   * Enter a new frame; a new context is pushed on the current context list.
   * @param frameId Name of the new frame, used as its `frameName`.
   */
  enterFrame(frameId: string) {
    // Defensive guard: `contexts` is always assigned by the field initializer.
    if (this.contexts) {
      this.lastId++;
      // Copy before mutating so previously returned snapshots stay intact.
      this.contexts = this.contexts.slice();
      this.contexts.push(this.newFrame(this.lastId, frameId));
      this._currentContextIds.unshift(this.contextIdforContexts(this.contexts));
    }
  }

  /**
   * Exit the current frame; the last context is removed from the current
   * context list.
   * @throws Error if only the root context remains.
   */
  exitFrame() {
    if (this.contexts && this.contexts.length > 1) {
      this.contexts = this.contexts.slice();
      this.contexts.splice(-1);
      // The innermost id corresponds to the frame just removed.
      this._currentContextIds.shift();
    } else {
      throw new Error('Cannot exit frame, the context is empty');
    }
  }

  /**
   * Enter the next iteration of a loop; the iteration id of the last context
   * is increased and the context receives a fresh id.
   * NOTE(review): when only the root context is present, this also advances
   * the root. The root then gets a non-zero id and renders as '-<n>' instead
   * of ''; confirm callers never do this outside a frame.
   * @throws Error if the context list is empty.
   */
  nextIteration() {
    if (this.contexts && this.contexts.length > 0) {
      this.contexts = this.contexts.slice();
      this.lastId++;
      // Copy the entry too, so snapshots holding the old object are unaffected.
      const context =
          Object.assign({}, this.contexts[this.contexts.length - 1]) as
          ExecutionContextInfo;
      context.iterationId += 1;
      context.id = this.lastId;
      this.contexts.splice(-1, 1, context);
      // Only the innermost id changes; parent ids are unaffected.
      this._currentContextIds.splice(
          0, 1, this.contextIdforContexts(this.contexts));
    } else {
      throw new Error('Cannot increase frame iteration, the context is empty');
    }
  }

  /**
   * Looks up a weight by name. Returns undefined at runtime when the name is
   * not present in `weightMap`.
   */
  getWeight(name: string): Tensor[] {
    return this.weightMap[name];
  }

  /** Registers a tensor array under its own `id`, replacing any previous entry. */
  addTensorArray(tensorArray: TensorArray) {
    this.tensorArrayMap[tensorArray.id] = tensorArray;
  }

  /**
   * Looks up a tensor array by id. Returns undefined at runtime when no
   * array was registered under that id.
   */
  getTensorArray(id: number): TensorArray {
    return this.tensorArrayMap[id];
  }
}